From 660be3fa79b7a5b15214c2a25141a25f531f76c8 Mon Sep 17 00:00:00 2001 From: eigmax Date: Sun, 15 Oct 2023 00:53:50 +0800 Subject: [PATCH] chore: submit memory prototype --- Cargo.toml | 1 + src/all_stark.rs | 24 ++-- src/arithmetic/mod.rs | 270 ++++++++++++++++++++++++++++++++++++++++++ src/prover.rs | 4 +- src/util.rs | 150 ----------------------- 5 files changed, 285 insertions(+), 164 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index da5cbcab..0ab7d53c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ tiny-keccak = "2.0.2" rand = "0.8.5" rand_chacha = "0.3.1" once_cell = "1.13.0" +static_assertions = "1.1.0" [dev-dependencies] env_logger = { version = "0.9.0", default-features = false } diff --git a/src/all_stark.rs b/src/all_stark.rs index 90a00c4d..7c4001cb 100644 --- a/src/all_stark.rs +++ b/src/all_stark.rs @@ -4,8 +4,8 @@ use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; -//use crate::arithmetic::arithmetic_stark; -//use crate::arithmetic::arithmetic_stark::ArithmeticStark; +use crate::arithmetic::arithmetic_stark; +use crate::arithmetic::arithmetic_stark::ArithmeticStark; //use crate::byte_packing::byte_packing_stark::{self, BytePackingStark}; use crate::config::StarkConfig; use crate::cpu::cpu_stark; @@ -25,7 +25,7 @@ use crate::stark::Stark; #[derive(Clone)] pub struct AllStark, const D: usize> { - // pub arithmetic_stark: ArithmeticStark, + pub arithmetic_stark: ArithmeticStark, // pub byte_packing_stark: BytePackingStark, pub cpu_stark: CpuStark, pub keccak_stark: KeccakStark, @@ -38,7 +38,7 @@ pub struct AllStark, const D: usize> { impl, const D: usize> Default for AllStark { fn default() -> Self { Self { - // arithmetic_stark: ArithmeticStark::default(), + arithmetic_stark: ArithmeticStark::default(), // byte_packing_stark: BytePackingStark::default(), cpu_stark: CpuStark::default(), keccak_stark: KeccakStark::default(), @@ -66,13 +66,13 @@ impl, const D: usize> AllStark { #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Table { - // Arithmetic = 0, + Arithmetic = 0, // BytePacking = 1, - Cpu = 0, - Keccak = 1, - KeccakSponge = 2, - Logic = 3, - Memory = 4, + Cpu = 1, + Keccak = 2, + KeccakSponge = 3, + Logic = 4, + Memory = 5, } pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; @@ -80,7 +80,7 @@ pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; impl Table { pub(crate) fn all() -> [Self; NUM_TABLES] { [ - // Self::Arithmetic, + Self::Arithmetic, // Self::BytePacking, Self::Cpu, Self::Keccak, @@ -103,7 +103,6 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { ] } -/* fn ctl_arithmetic() -> CrossTableLookup { CrossTableLookup::new( vec![cpu_stark::ctl_arithmetic_base_rows()], @@ -111,6 +110,7 @@ fn ctl_arithmetic() -> CrossTableLookup { ) } +/* fn ctl_byte_packing() -> CrossTableLookup { let cpu_packing_looking = TableWithColumns::new( Table::Cpu, diff --git a/src/arithmetic/mod.rs b/src/arithmetic/mod.rs index 8b137891..d53be0db 100644 --- a/src/arithmetic/mod.rs +++ b/src/arithmetic/mod.rs @@ -1 +1,271 @@ +pub mod arithmetic_stark; +pub mod columns; +pub mod shift; +pub mod addcy; +pub mod divmod; +pub mod mul; +pub mod utils; +use num::Zero; +use plonky2::field::types::PrimeField64; +use crate::util::*; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum BinaryOperator { + Add, + Mul, + Sub, + Div, + Mod, + Lt, + Gt, + //Byte, + Shl, // simulated with MUL + Shr, // simulated with DIV +} + +impl BinaryOperator { + pub(crate) fn result(&self, input0: u32, input1: u32) -> u32 { + match self { + BinaryOperator::Add => input0.overflowing_add(input1).0, + BinaryOperator::Mul => input0.overflowing_mul(input1).0, + BinaryOperator::Shl => { + if input0 < 256 { + input1 << input0 + } else { + u32::zero() + } + } + BinaryOperator::Sub => input0.overflowing_sub(input1).0, + BinaryOperator::Div => { + if input1.is_zero() { + u32::zero() + } else { + input0 / input1 + } + } + BinaryOperator::Shr => { + if input0 < 256 { + input1 >> input0 + } else { + u32::zero() + } + } + BinaryOperator::Mod => { + if input1.is_zero() { + u32::zero() + } else { + input0 % input1 + } + } + BinaryOperator::Lt => u32::from((input0 < input1) as u8), + BinaryOperator::Gt => u32::from((input0 > input1) as u8), + /* + BinaryOperator::Byte => { + if input0 >= 32.into() { + u32::zero() + } else { + input1.byte(31 - input0.as_usize()).into() + } + } + */ + } + } + + pub(crate) fn row_filter(&self) -> usize { + match self { + BinaryOperator::Add => columns::IS_ADD, + BinaryOperator::Mul => columns::IS_MUL, + BinaryOperator::Sub => columns::IS_SUB, + BinaryOperator::Div => columns::IS_DIV, + BinaryOperator::Mod => columns::IS_MOD, + BinaryOperator::Lt => columns::IS_LT, + BinaryOperator::Gt => columns::IS_GT, + //BinaryOperator::Byte => columns::IS_BYTE, + BinaryOperator::Shl => columns::IS_SHL, + BinaryOperator::Shr => columns::IS_SHR, + } + } +} + +#[allow(clippy::enum_variant_names)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum TernaryOperator { + AddMod, + MulMod, + SubMod, +} + +impl TernaryOperator { + pub(crate) fn result(&self, input0: u32, input1: u32, input2: u32) -> u32 { + match self { + TernaryOperator::AddMod => ((input0 + input1) % input2), + TernaryOperator::MulMod => ((input0 * input1) % input2), + TernaryOperator::SubMod => ((input0 - input1) % input2), + } + } + + pub(crate) fn row_filter(&self) -> usize { + match self { + TernaryOperator::AddMod => columns::IS_ADDMOD, + TernaryOperator::MulMod => columns::IS_MULMOD, + TernaryOperator::SubMod => columns::IS_SUBMOD, + } + } +} + +/// An enum representing arithmetic operations that can be either binary or ternary. +#[derive(Debug)] +pub(crate) enum Operation { + BinaryOperation { + operator: BinaryOperator, + input0: u32, + input1: u32, + result: u32, + }, + TernaryOperation { + operator: TernaryOperator, + input0: u32, + input1: u32, + input2: u32, + result: u32, + }, +} + +impl Operation { + /// Create a binary operator with given inputs. + /// + /// NB: This works as you would expect, EXCEPT for SHL and SHR, + /// whose inputs need a small amount of preprocessing. Specifically, + /// to create `SHL(shift, value)`, call (note the reversal of + /// argument order): + /// + /// `Operation::binary(BinaryOperator::Shl, value, 1 << shift)` + /// + /// Similarly, to create `SHR(shift, value)`, call + /// + /// `Operation::binary(BinaryOperator::Shr, value, 1 << shift)` + /// + /// See witness/operation.rs::append_shift() for an example (indeed + /// the only call site for such inputs). + pub(crate) fn binary(operator: BinaryOperator, input0: u32, input1: u32) -> Self { + let result = operator.result(input0, input1); + Self::BinaryOperation { + operator, + input0, + input1, + result, + } + } + + pub(crate) fn ternary( + operator: TernaryOperator, + input0: u32, + input1: u32, + input2: u32, + ) -> Self { + let result = operator.result(input0, input1, input2); + Self::TernaryOperation { + operator, + input0, + input1, + input2, + result, + } + } + + pub(crate) fn result(&self) -> u32 { + match self { + Operation::BinaryOperation { result, .. } => *result, + Operation::TernaryOperation { result, .. } => *result, + } + } + + /// Convert operation into one or two rows of the trace. + /// + /// Morally these types should be [F; NUM_ARITH_COLUMNS], but we + /// use vectors because that's what utils::transpose (who consumes + /// the result of this function as part of the range check code) + /// expects. + /// + /// The `is_simulated` bool indicates whether we use a native arithmetic + /// operation or simulate one with another. This is used to distinguish + /// SHL and SHR operations that are simulated through MUL and DIV respectively. + fn to_rows(&self) -> (Vec, Option>) { + match *self { + Operation::BinaryOperation { + operator, + input0, + input1, + result, + } => binary_op_to_rows(operator, input0, input1, result), + Operation::TernaryOperation { + operator, + input0, + input1, + input2, + result, + } => ternary_op_to_rows(operator.row_filter(), input0, input1, input2, result), + } + } +} + +fn ternary_op_to_rows( + row_filter: usize, + input0: u32, + input1: u32, + input2: u32, + _result: u32, +) -> (Vec, Option>) { + let mut row1 = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + let mut row2 = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + + row1[row_filter] = F::ONE; + + // FIXME + // modular::generate(&mut row1, &mut row2, row_filter, input0, input1, input2); + + (row1, Some(row2)) +} + +fn binary_op_to_rows( + op: BinaryOperator, + input0: u32, + input1: u32, + result: u32, +) -> (Vec, Option>) { + let mut row = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + row[op.row_filter()] = F::ONE; + + match op { + BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Lt | BinaryOperator::Gt => { + addcy::generate(&mut row, op.row_filter(), input0, input1); + (row, None) + } + BinaryOperator::Mul => { + mul::generate(&mut row, input0, input1); + (row, None) + } + BinaryOperator::Shl => { + let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + shift::generate(&mut row, &mut nv, true, input0, input1, result); + (row, None) + } + BinaryOperator::Div | BinaryOperator::Mod => { + let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result); + (row, Some(nv)) + } + BinaryOperator::Shr => { + let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS]; + shift::generate(&mut row, &mut nv, false, input0, input1, result); + (row, Some(nv)) + } + /* + BinaryOperator::Byte => { + byte::generate(&mut row, input0, input1); + (row, None) + } + */ + } +} diff --git a/src/prover.rs b/src/prover.rs index efd8a7bf..2cb6b530 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -176,7 +176,6 @@ where F: RichField + Extendable, C: GenericConfig, { - /* let arithmetic_proof = timed!( timing, "prove Arithmetic STARK", @@ -191,6 +190,7 @@ where timing, )? ); + /* let byte_packing_proof = timed!( timing, "prove byte packing STARK", @@ -278,7 +278,7 @@ where ); Ok([ - //arithmetic_proof, + arithmetic_proof, //byte_packing_proof, cpu_proof, keccak_proof, diff --git a/src/util.rs b/src/util.rs index 08233056..f0ede6bd 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,6 +1,5 @@ use std::mem::{size_of, transmute_copy, ManuallyDrop}; -use ethereum_types::{H160, H256, U256}; use itertools::Itertools; use num::BigUint; use plonky2::field::extension::Extendable; @@ -47,76 +46,6 @@ pub fn trace_rows_to_poly_values( .collect() } -/// Returns the lowest LE 32-bit limb of a `U256` as a field element, -/// and errors if the integer is actually greater. -pub(crate) fn u256_to_u32(u256: U256) -> Result { - if TryInto::::try_into(u256).is_err() { - return Err(ProgramError::IntegerTooLarge); - } - - Ok(F::from_canonical_u32(u256.low_u32())) -} - -/// Returns the lowest LE 64-bit word of a `U256` as two field elements -/// each storing a 32-bit limb, and errors if the integer is actually greater. -pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> { - if TryInto::::try_into(u256).is_err() { - return Err(ProgramError::IntegerTooLarge); - } - - Ok(( - F::from_canonical_u32(u256.low_u64() as u32), - F::from_canonical_u32((u256.low_u64() >> 32) as u32), - )) -} - -/// Safe alternative to `U256::as_usize()`, which errors in case of overflow instead of panicking. -pub(crate) fn u256_to_usize(u256: U256) -> Result { - u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) -} - -#[allow(unused)] // TODO: Remove? -/// Returns the 32-bit little-endian limbs of a `U256`. -pub(crate) fn u256_limbs(u256: U256) -> [F; 8] { - u256.0 - .into_iter() - .flat_map(|limb_64| { - let lo = limb_64 as u32; - let hi = (limb_64 >> 32) as u32; - [lo, hi] - }) - .map(F::from_canonical_u32) - .collect_vec() - .try_into() - .unwrap() -} - -#[allow(unused)] -/// Returns the 32-bit little-endian limbs of a `H256`. -pub(crate) fn h256_limbs(h256: H256) -> [F; 8] { - let mut temp_h256 = h256.0; - temp_h256.reverse(); - temp_h256 - .chunks(4) - .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) - .map(F::from_canonical_u32) - .collect_vec() - .try_into() - .unwrap() -} - -#[allow(unused)] -/// Returns the 32-bit limbs of a `U160`. -pub(crate) fn h160_limbs(h160: H160) -> [F; 5] { - h160.0 - .chunks(4) - .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) - .map(F::from_canonical_u32) - .collect_vec() - .try_into() - .unwrap() -} - pub(crate) const fn indices_arr() -> [usize; N] { let mut indices_arr = [0; N]; let mut i = 0; @@ -134,82 +63,3 @@ pub(crate) unsafe fn transmute_no_compile_time_size_checks(value: T) -> U // Copy the bit pattern. The original value is no longer safe to use. transmute_copy(&value) } - -pub(crate) fn addmod(x: U256, y: U256, m: U256) -> U256 { - if m.is_zero() { - return m; - } - let x = u256_to_biguint(x); - let y = u256_to_biguint(y); - let m = u256_to_biguint(m); - biguint_to_u256((x + y) % m) -} - -pub(crate) fn mulmod(x: U256, y: U256, m: U256) -> U256 { - if m.is_zero() { - return m; - } - let x = u256_to_biguint(x); - let y = u256_to_biguint(y); - let m = u256_to_biguint(m); - biguint_to_u256(x * y % m) -} - -pub(crate) fn submod(x: U256, y: U256, m: U256) -> U256 { - if m.is_zero() { - return m; - } - let mut x = u256_to_biguint(x); - let y = u256_to_biguint(y); - let m = u256_to_biguint(m); - while x < y { - x += &m; - } - biguint_to_u256((x - y) % m) -} - -pub(crate) fn u256_to_biguint(x: U256) -> BigUint { - let mut bytes = [0u8; 32]; - x.to_little_endian(&mut bytes); - BigUint::from_bytes_le(&bytes) -} - -pub(crate) fn biguint_to_u256(x: BigUint) -> U256 { - let bytes = x.to_bytes_le(); - // This could panic if `bytes.len() > 32` but this is only - // used here with `BigUint` constructed from `U256`. - U256::from_little_endian(&bytes) -} - -pub(crate) fn mem_vec_to_biguint(x: &[U256]) -> BigUint { - BigUint::from_slice( - &x.iter() - .map(|&n| n.try_into().unwrap()) - .flat_map(|a: u128| { - [ - (a % (1 << 32)) as u32, - ((a >> 32) % (1 << 32)) as u32, - ((a >> 64) % (1 << 32)) as u32, - ((a >> 96) % (1 << 32)) as u32, - ] - }) - .collect::>(), - ) -} - -pub(crate) fn biguint_to_mem_vec(x: BigUint) -> Vec { - let num_limbs = ((x.bits() + 127) / 128) as usize; - - let mut digits = x.iter_u64_digits(); - - let mut mem_vec = Vec::with_capacity(num_limbs); - while let Some(lo) = digits.next() { - let hi = digits.next().unwrap_or(0); - mem_vec.push(U256::from(lo as u128 | (hi as u128) << 64)); - } - mem_vec -} - -pub(crate) fn h2u(h: H256) -> U256 { - U256::from_big_endian(&h.0) -}