Skip to content

Commit

Permalink
feat: impl kernel (#3)
Browse files Browse the repository at this point in the history
* feat: impl kernel

* opt: fix code hash
  • Loading branch information
eigmax authored Oct 24, 2023
1 parent 27c5395 commit 063654c
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ rand_chacha = "0.3.1"
once_cell = "1.13.0"
static_assertions = "1.1.0"
keccak-hash = "0.10.0"
byteorder = "1.5.0"

elf = { version = "0.7", default-features = false }

Expand Down
114 changes: 109 additions & 5 deletions src/cpu/bootstrap_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::KERNEL;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::generation::state::GenerationState;
use crate::keccak_sponge::columns::KECCAK_RATE_BYTES;
use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES;
use crate::memory::segments::Segment;
use crate::witness::memory::MemoryAddress;
use crate::witness::util::mem_write_gp_log_and_fill;
use crate::witness::util::{keccak_sponge_log, mem_write_gp_log_and_fill};

pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>) {
// Iterate through chunks of the code, such that we can write one chunk to memory per row.
Expand Down Expand Up @@ -43,16 +45,118 @@ pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>
final_cpu_row.mem_channels[2].value[0] = F::ZERO; // virt
final_cpu_row.mem_channels[3].value[0] = F::from_canonical_usize(KERNEL.code.len()); // len

// final_cpu_row.mem_channels[4].value = KERNEL.code_hash.map(F::from_canonical_u32);
// final_cpu_row.mem_channels[4].value.reverse();
/*
final_cpu_row.mem_channels[4].value = KERNEL.code_hash.map(F::from_canonical_u32);
final_cpu_row.mem_channels[4].value.reverse();
keccak_sponge_log(
state,
MemoryAddress::new(0, Segment::Code, 0),
KERNEL.code.clone(),
);
*/
state.traces.push_cpu(final_cpu_row);
state.traces.push_cpu(final_cpu_row);
log::info!("Bootstrapping took {} cycles", state.traces.clock());
}

pub(crate) fn eval_bootstrap_kernel_packed<F: Field, P: PackedField<Scalar = F>>(
local_values: &CpuColumnsView<P>,
next_values: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
// IS_BOOTSTRAP_KERNEL must have an init value of 1, a final value of 0, and a delta in {0, -1}.
let local_is_bootstrap = local_values.is_bootstrap_kernel;
let next_is_bootstrap = next_values.is_bootstrap_kernel;
yield_constr.constraint_first_row(local_is_bootstrap - P::ONES);
yield_constr.constraint_last_row(local_is_bootstrap);
let delta_is_bootstrap = next_is_bootstrap - local_is_bootstrap;
yield_constr.constraint_transition(delta_is_bootstrap * (delta_is_bootstrap + P::ONES));

// If this is a bootloading row and the i'th memory channel is used, it must have the right
// address, name context = 0, segment = Code, virt = clock * NUM_GP_CHANNELS + i.
let code_segment = F::from_canonical_usize(Segment::Code as usize);
for (i, channel) in local_values.mem_channels.iter().enumerate() {
let filter = local_is_bootstrap * channel.used;
yield_constr.constraint(filter * channel.addr_context);
yield_constr.constraint(filter * (channel.addr_segment - code_segment));
let expected_virt = local_values.clock * F::from_canonical_usize(NUM_GP_CHANNELS)
+ F::from_canonical_usize(i);
yield_constr.constraint(filter * (channel.addr_virtual - expected_virt));
}

// If this is the final bootstrap row (i.e. delta_is_bootstrap = 1), check that
// - all memory channels are disabled
// - the current kernel hash matches a precomputed one
for channel in local_values.mem_channels.iter() {
yield_constr.constraint_transition(delta_is_bootstrap * channel.used);
}
for (&expected, actual) in KERNEL
.code_hash
.iter()
.rev()
.zip(local_values.mem_channels.last().unwrap().value)
{
let expected = P::from(F::from_canonical_u32(expected));
let diff = expected - actual;
yield_constr.constraint_transition(delta_is_bootstrap * diff);
}
}

pub(crate) fn eval_bootstrap_kernel_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
local_values: &CpuColumnsView<ExtensionTarget<D>>,
next_values: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let one = builder.one_extension();

// IS_BOOTSTRAP_KERNEL must have an init value of 1, a final value of 0, and a delta in {0, -1}.
let local_is_bootstrap = local_values.is_bootstrap_kernel;
let next_is_bootstrap = next_values.is_bootstrap_kernel;
let constraint = builder.sub_extension(local_is_bootstrap, one);
yield_constr.constraint_first_row(builder, constraint);
yield_constr.constraint_last_row(builder, local_is_bootstrap);
let delta_is_bootstrap = builder.sub_extension(next_is_bootstrap, local_is_bootstrap);
let constraint =
builder.mul_add_extension(delta_is_bootstrap, delta_is_bootstrap, delta_is_bootstrap);
yield_constr.constraint_transition(builder, constraint);

// If this is a bootloading row and the i'th memory channel is used, it must have the right
// address, name context = 0, segment = Code, virt = clock * NUM_GP_CHANNELS + i.
let code_segment =
builder.constant_extension(F::Extension::from_canonical_usize(Segment::Code as usize));
for (i, channel) in local_values.mem_channels.iter().enumerate() {
let filter = builder.mul_extension(local_is_bootstrap, channel.used);
let constraint = builder.mul_extension(filter, channel.addr_context);
yield_constr.constraint(builder, constraint);

let segment_diff = builder.sub_extension(channel.addr_segment, code_segment);
let constraint = builder.mul_extension(filter, segment_diff);
yield_constr.constraint(builder, constraint);

let i_ext = builder.constant_extension(F::Extension::from_canonical_usize(i));
let num_gp_channels_f = F::from_canonical_usize(NUM_GP_CHANNELS);
let expected_virt =
builder.mul_const_add_extension(num_gp_channels_f, local_values.clock, i_ext);
let virt_diff = builder.sub_extension(channel.addr_virtual, expected_virt);
let constraint = builder.mul_extension(filter, virt_diff);
yield_constr.constraint(builder, constraint);
}

// If this is the final bootstrap row (i.e. delta_is_bootstrap = 1), check that
// - all memory channels are disabled
// - the current kernel hash matches a precomputed one
for channel in local_values.mem_channels.iter() {
let constraint = builder.mul_extension(delta_is_bootstrap, channel.used);
yield_constr.constraint_transition(builder, constraint);
}
for (&expected, actual) in KERNEL
.code_hash
.iter()
.rev()
.zip(local_values.mem_channels.last().unwrap().value)
{
let expected = builder.constant_extension(F::Extension::from_canonical_u32(expected));
let diff = builder.sub_extension(expected, actual);
let constraint = builder.mul_extension(delta_is_bootstrap, diff);
yield_constr.constraint_transition(builder, constraint);
}
}
6 changes: 3 additions & 3 deletions src/cpu/cpu_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::cpu::{
modfp254, pc, push0, shift, simple_logic, stack, stack_bounds, syscalls_exceptions,
};
*/
use crate::cpu::{control_flow, decode, membus, pc, shift, simple_logic};
use crate::cpu::{bootstrap_kernel, control_flow, decode, membus, pc, shift, simple_logic};
use crate::cross_table_lookup::{Column, TableWithColumns};
use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame};
use crate::memory::segments::Segment;
Expand Down Expand Up @@ -224,8 +224,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
let next_values: &[P; NUM_CPU_COLUMNS] = vars.get_next_values().try_into().unwrap();
let next_values: &CpuColumnsView<P> = next_values.borrow();

/*
bootstrap_kernel::eval_bootstrap_kernel_packed(local_values, next_values, yield_constr);
/*
contextops::eval_packed(local_values, next_values, yield_constr);
*/
control_flow::eval_packed_generic(local_values, next_values, yield_constr);
Expand Down Expand Up @@ -267,13 +267,13 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
vars.get_next_values().try_into().unwrap();
let next_values: &CpuColumnsView<ExtensionTarget<D>> = next_values.borrow();

/*
bootstrap_kernel::eval_bootstrap_kernel_ext_circuit(
builder,
local_values,
next_values,
yield_constr,
);
/*
contextops::eval_ext_circuit(builder, local_values, next_values, yield_constr);
*/
control_flow::eval_ext_circuit(builder, local_values, next_values, yield_constr);
Expand Down
12 changes: 11 additions & 1 deletion src/cpu/kernel/assembler.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
use keccak_hash::keccak;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct Kernel {
pub(crate) code: Vec<u8>,
pub(crate) code_hash: [u32; 8],
pub(crate) ordered_labels: Vec<String>,
pub(crate) global_labels: HashMap<String, usize>,
}

// FIXME: impl the mips vm
pub(crate) fn combined_kernel() -> Kernel {
let code: Vec<u8> = vec![];
let code_hash_bytes = keccak(&code).0;
let code_hash_be = core::array::from_fn(|i| {
u32::from_le_bytes(core::array::from_fn(|j| code_hash_bytes[i * 4 + j]))
});
let code_hash = code_hash_be.map(u32::from_be);

Kernel {
code: vec![],
code,
code_hash,
ordered_labels: vec![],
global_labels: HashMap::new(),
}
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/keccak_util.rs → src/cpu/kernel/keccak_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub(crate) fn keccakf_u8s(state_u8s: &mut [u8; KECCAK_WIDTH_BYTES]) {
mod tests {
use tiny_keccak::keccakf;

use crate::cpu::keccak_util::{keccakf_u32s, keccakf_u8s};
use crate::cpu::kernel::keccak_util::{keccakf_u32s, keccakf_u8s};

#[test]
#[rustfmt::skip]
Expand Down
1 change: 1 addition & 0 deletions src/cpu/kernel/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) mod assembler;
pub(crate) mod constants;
pub mod keccak_util;

pub use assembler::KERNEL;
2 changes: 1 addition & 1 deletion src/cpu/membus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::CpuColumnsView;

/// General-purpose memory channels; they can read and write to all contexts/segments/addresses.
pub const NUM_GP_CHANNELS: usize = 32;
pub const NUM_GP_CHANNELS: usize = 5;

pub mod channel_indices {
use std::ops::Range;
Expand Down
3 changes: 1 addition & 2 deletions src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ pub mod columns;
pub(crate) mod control_flow;
pub mod cpu_stark;
pub(crate) mod decode;
pub mod keccak_util;
pub(crate) mod kernel;
pub(crate) mod membus;
pub(crate) mod pc;
pub(crate) mod simple_logic;
pub(crate) mod shift;
pub(crate) mod simple_logic;
4 changes: 2 additions & 2 deletions src/keccak_sponge/keccak_sponge_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use plonky2::util::timing::TimingTree;
use plonky2_util::ceil_div_usize;

use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::keccak_util::keccakf_u32s;
use crate::cpu::kernel::keccak_util::keccakf_u32s;
use crate::cross_table_lookup::Column;
use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame};
use crate::keccak_sponge::columns::*;
Expand Down Expand Up @@ -754,7 +754,7 @@ mod tests {
let op = KeccakSpongeOp {
base_address: MemoryAddress {
context: 0,
segment: Segment::MainMemory as usize,
segment: Segment::Code as usize,
virt: 0,
},
timestamp: 0,
Expand Down
16 changes: 12 additions & 4 deletions src/memory/memory_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,14 +260,18 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let addr_context = local_values[ADDR_CONTEXT];
let addr_segment = local_values[ADDR_SEGMENT];
let addr_virtual = local_values[ADDR_VIRTUAL];
let value_limbs: Vec<_> = (0..VALUE_LIMBS).map(|i| local_values[value_limb(i)]).collect();
let value_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| local_values[value_limb(i)])
.collect();

let next_timestamp = next_values[TIMESTAMP];
let next_is_read = next_values[IS_READ];
let next_addr_context = next_values[ADDR_CONTEXT];
let next_addr_segment = next_values[ADDR_SEGMENT];
let next_addr_virtual = next_values[ADDR_VIRTUAL];
let next_values_limbs: Vec<_> = (0..VALUE_LIMBS).map(|i| next_values[value_limb(i)]).collect();
let next_values_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| next_values[value_limb(i)])
.collect();

// The filter must be 0 or 1.
let filter = local_values[FILTER];
Expand Down Expand Up @@ -337,13 +341,17 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F
let addr_context = local_values[ADDR_CONTEXT];
let addr_segment = local_values[ADDR_SEGMENT];
let addr_virtual = local_values[ADDR_VIRTUAL];
let value_limbs: Vec<_> = (0..VALUE_LIMBS).map(|i| local_values[value_limb(i)]).collect();
let value_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| local_values[value_limb(i)])
.collect();
let timestamp = local_values[TIMESTAMP];

let next_addr_context = next_values[ADDR_CONTEXT];
let next_addr_segment = next_values[ADDR_SEGMENT];
let next_addr_virtual = next_values[ADDR_VIRTUAL];
let next_values_limbs: Vec<_> = (0..VALUE_LIMBS).map(|i| next_values[value_limb(i)]).collect();
let next_values_limbs: Vec<_> = (0..VALUE_LIMBS)
.map(|i| next_values[value_limb(i)])
.collect();
let next_is_read = next_values[IS_READ];
let next_timestamp = next_values[TIMESTAMP];

Expand Down
2 changes: 1 addition & 1 deletion src/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ pub mod segments;

// TODO: Move to CPU module, now that channels have been removed from the memory table.
pub(crate) const NUM_CHANNELS: usize = crate::cpu::membus::NUM_CHANNELS;
pub(crate) const VALUE_LIMBS: usize = 1;
pub(crate) const VALUE_LIMBS: usize = 8;
Loading

0 comments on commit 063654c

Please sign in to comment.