Skip to content

Commit

Permalink
feat: add support for rust program (#132)
Browse files Browse the repository at this point in the history
* add initial support for rust

* reduce initial stack

* reduce patchGo

* add trace generation for new instructions

* fix bal target

* add log for write mem

* fix register sync

* tmp add dump memory

* add support for powdr-revme

* update constraint for syscall

* fix constraint for sysbrk syscall and fix fmt

* add constraint for bal

* add constraint for sys_set_thread_info syscall

* add constraint for seb/seh/wsbh

* add constraint for teq

* add constraint for rdhwr

* mov condmov/teq to misc

* fix teq

* fix wsbh

* fix constraints for bits

* add constraint for ext inst

* fix ext constraint

* fix typo errors

* remove duplicated constrain

* fix tests after rebase

* fix clippy

* add evm-no-std

* relax restriction for seg size

* fix fmt

* rename blk to brk, change some log level

* remove unused tests
  • Loading branch information
weilzkm authored Jun 7, 2024
1 parent fe386d7 commit c3f0dfc
Show file tree
Hide file tree
Showing 24 changed files with 1,312 additions and 227 deletions.
23 changes: 15 additions & 8 deletions examples/zkmips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,18 @@ fn select_degree_bits(seg_size: usize) -> [std::ops::Range<usize>; 6] {
(65536, 3),
(262144, 4),
]);
match seg_size_to_bits.get(&seg_size) {
Some(s) => DEGREE_BITS_RANGE[*s].clone(),
None => panic!(
"Invalid segment size, supported: {:?}",
seg_size_to_bits.keys()
),

let mut index = -1;
for (key, value) in seg_size_to_bits.iter() {
if *key >= seg_size {
index = *value;
break;
}
}

match index {
-1i32 => panic!("Invalid segment size, supported largest size: 262144"),
_ => DEGREE_BITS_RANGE[index as usize].clone(),
}
}

Expand Down Expand Up @@ -100,9 +106,10 @@ fn split_elf_into_segs() {
instrumented_state.split_segment(true, &seg_path, new_writer);
}
}

instrumented_state.split_segment(true, &seg_path, new_writer);
log::info!("Split done");
log::info!("Split done {}", instrumented_state.state.step);

instrumented_state.dump_memory();
}

fn prove_single_seg() {
Expand Down
139 changes: 139 additions & 0 deletions src/cpu/bits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cpu::columns::CpuColumnsView;
use crate::util::{limb_from_bits_le, limb_from_bits_le_recursive};
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;

pub fn eval_packed<P: PackedField>(
lv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
let filter_seh = lv.op.signext16;
let filter_seb = lv.op.signext8;
let filter_wsbh = lv.op.swaphalf;
let filter = filter_seh + filter_seb + filter_wsbh;

// Check rt Reg
{
let rt_reg = lv.mem_channels[0].addr_virtual;
let rt_src = limb_from_bits_le(lv.rt_bits);
yield_constr.constraint(filter * (rt_reg - rt_src));
}

// Check rd Reg
{
let rd_reg = lv.mem_channels[1].addr_virtual;
let rd_dst = limb_from_bits_le(lv.rd_bits);
yield_constr.constraint(filter * (rd_reg - rd_dst));
}

let rt = lv.mem_channels[0].value;
let bits_le = lv.general.io().rt_le;
for bit in bits_le {
yield_constr.constraint(filter * bit * (P::ONES - bit));
}
let sum = limb_from_bits_le(bits_le);

yield_constr.constraint(filter * (rt - sum));

// check seb result
let rd = lv.mem_channels[1].value;
let mut seb_result = [bits_le[7]; 32];
seb_result[..7].copy_from_slice(&bits_le[..7]);
let sum = limb_from_bits_le(seb_result);
yield_constr.constraint(filter_seb * (rd - sum));

// check seh result
let mut seh_result = [bits_le[15]; 32];
seh_result[..15].copy_from_slice(&bits_le[..15]);
let sum = limb_from_bits_le(seh_result);
yield_constr.constraint(filter_seh * (rd - sum));

// check wsbh result
let mut wsbh_result = [bits_le[0]; 32];
wsbh_result[..8].copy_from_slice(&bits_le[8..16]);
wsbh_result[8..16].copy_from_slice(&bits_le[..8]);
wsbh_result[16..24].copy_from_slice(&bits_le[24..32]);
wsbh_result[24..32].copy_from_slice(&bits_le[16..24]);

let sum = limb_from_bits_le(wsbh_result);
yield_constr.constraint(filter_wsbh * (rd - sum));
}

pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
lv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let filter_seh = lv.op.signext16;
let filter_seb = lv.op.signext8;
let filter_wsbh = lv.op.swaphalf;
let filter = builder.add_extension(filter_seh, filter_seb);
let filter = builder.add_extension(filter_wsbh, filter);

// Check rt Reg
{
let rt_reg = lv.mem_channels[0].addr_virtual;
let rt_src = limb_from_bits_le_recursive(builder, lv.rt_bits);
let constr = builder.sub_extension(rt_reg, rt_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check rd Reg
{
let rd_reg = lv.mem_channels[1].addr_virtual;
let rd_src = limb_from_bits_le_recursive(builder, lv.rd_bits);
let constr = builder.sub_extension(rd_reg, rd_src);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

let one = builder.one_extension();
let rt = lv.mem_channels[0].value;
let bits_le = lv.general.io().rt_le;
for bit in bits_le {
let bit_neg = builder.sub_extension(one, bit);
let t = builder.mul_many_extension([filter, bit, bit_neg]);
yield_constr.constraint(builder, t);
}
let sum = limb_from_bits_le_recursive(builder, bits_le);

let t1 = builder.sub_extension(rt, sum);
let t = builder.mul_extension(filter, t1);
yield_constr.constraint(builder, t);

// check seb result
let rd = lv.mem_channels[1].value;
let mut seb_result = [bits_le[7]; 32];
seb_result[..7].copy_from_slice(&bits_le[..7]);
let sum = limb_from_bits_le_recursive(builder, seb_result);

let t1 = builder.sub_extension(rd, sum);
let t = builder.mul_extension(filter_seb, t1);
yield_constr.constraint(builder, t);

// check seh result
let mut seh_result = [bits_le[15]; 32];
seh_result[..15].copy_from_slice(&bits_le[..15]);
let sum = limb_from_bits_le_recursive(builder, seh_result);

let t1 = builder.sub_extension(rd, sum);
let t = builder.mul_extension(filter_seh, t1);
yield_constr.constraint(builder, t);

// check wsbh result
let mut wsbh_result = [bits_le[0]; 32];
wsbh_result[..8].copy_from_slice(&bits_le[8..16]);
wsbh_result[8..16].copy_from_slice(&bits_le[..8]);
wsbh_result[16..24].copy_from_slice(&bits_le[24..32]);
wsbh_result[24..32].copy_from_slice(&bits_le[16..24]);

let sum = limb_from_bits_le_recursive(builder, wsbh_result);

let t1 = builder.sub_extension(rd, sum);
let t = builder.mul_extension(filter_wsbh, t1);
yield_constr.constraint(builder, t);
}
28 changes: 26 additions & 2 deletions src/cpu/columns/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub(crate) union CpuGeneralColumnsView<T: Copy> {
shift: CpuShiftView<T>,
io: CpuIOAuxView<T>,
hash: CpuHashView<T>,
misc: CpuMiscView<T>,
}

impl<T: Copy> CpuGeneralColumnsView<T> {
Expand Down Expand Up @@ -54,6 +55,16 @@ impl<T: Copy> CpuGeneralColumnsView<T> {
unsafe { &mut self.shift }
}

// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn misc(&self) -> &CpuMiscView<T> {
unsafe { &self.misc }
}

// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn misc_mut(&mut self) -> &mut CpuMiscView<T> {
unsafe { &mut self.misc }
}

// SAFETY: Each view is a valid interpretation of the underlying array.
pub(crate) fn io(&self) -> &CpuIOAuxView<T> {
unsafe { &self.io }
Expand Down Expand Up @@ -96,12 +107,25 @@ impl<T: Copy> BorrowMut<[T; NUM_SHARED_COLUMNS]> for CpuGeneralColumnsView<T> {

#[derive(Copy, Clone)]
pub(crate) struct CpuSyscallView<T: Copy> {
pub(crate) cond: [T; 10],
pub(crate) sysnum: [T; 11],
pub(crate) cond: [T; 12],
pub(crate) sysnum: [T; 12],
pub(crate) a0: [T; 3],
pub(crate) a1: T,
}

#[derive(Copy, Clone)]
pub(crate) struct CpuMiscView<T: Copy> {
pub(crate) rs_bits: [T; 32],
pub(crate) is_msb: [T; 32],
pub(crate) is_lsb: [T; 32],
pub(crate) auxm: T,
pub(crate) auxl: T,
pub(crate) auxs: T,
pub(crate) rd_index: T,
pub(crate) rd_index_eq_0: T,
pub(crate) rd_index_eq_29: T,
}

#[derive(Copy, Clone)]
pub(crate) struct CpuLogicView<T: Copy> {
// Pseudoinverse of `(input0 - input1)`. Used prove that they are unequal. Assumes 32-bit limbs.
Expand Down
2 changes: 2 additions & 0 deletions src/cpu/columns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub struct MemIOView<T: Copy> {
pub(crate) is_swr: T,
pub(crate) is_ll: T,
pub(crate) is_sc: T,
pub(crate) is_sdc1: T,
pub(crate) is_lb: T,
pub(crate) aux_filter: T,
}
Expand All @@ -78,6 +79,7 @@ pub struct CpuColumnsView<T: Copy> {

/// If CPU cycle: The program counter for the current instruction.
pub program_counter: T,
pub next_program_counter: T,

/// If CPU cycle: We're in kernel (privileged) mode.
pub is_kernel_mode: T,
Expand Down
7 changes: 7 additions & 0 deletions src/cpu/columns/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct OpsColumnsView<T: Copy> {
pub keccak_general: T,
pub jumps: T,
pub jumpi: T,
pub jumpdirect: T,
pub branch: T,
pub pc: T,
pub get_context: T,
Expand All @@ -29,6 +30,12 @@ pub struct OpsColumnsView<T: Copy> {
pub m_op_load: T,
pub m_op_store: T,
pub nop: T,
pub ext: T,
pub rdhwr: T,
pub signext8: T,
pub signext16: T,
pub swaphalf: T,
pub teq: T,

pub syscall: T,
}
Expand Down
6 changes: 5 additions & 1 deletion src/cpu/cpu_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer
use crate::cpu::columns::{COL_MAP, NUM_CPU_COLUMNS};
//use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::cpu::{
bootstrap_kernel, count, decode, exit_kernel, jumps, membus, memio, shift, syscall,
bits, bootstrap_kernel, count, decode, exit_kernel, jumps, membus, memio, misc, shift, syscall,
};
use crate::cross_table_lookup::{Column, Filter, TableWithColumns};
use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame};
Expand Down Expand Up @@ -210,6 +210,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
shift::eval_packed(local_values, yield_constr);
count::eval_packed(local_values, yield_constr);
syscall::eval_packed(local_values, yield_constr);
bits::eval_packed(local_values, yield_constr);
misc::eval_packed(local_values, yield_constr);
exit_kernel::eval_exit_kernel_packed(local_values, next_values, yield_constr);
}

Expand Down Expand Up @@ -240,6 +242,8 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for CpuStark<F, D
shift::eval_ext_circuit(builder, local_values, yield_constr);
count::eval_ext_circuit(builder, local_values, yield_constr);
syscall::eval_ext_circuit(builder, local_values, yield_constr);
bits::eval_ext_circuit(builder, local_values, yield_constr);
misc::eval_ext_circuit(builder, local_values, yield_constr);
exit_kernel::eval_exit_kernel_ext_circuit(builder, local_values, next_values, yield_constr);
}

Expand Down
8 changes: 6 additions & 2 deletions src/cpu/exit_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,27 @@ pub(crate) fn generate_exit_kernel<F: RichField>(state: &mut GenerationState<F>,
cpu_row.clock = F::from_canonical_usize(state.traces.clock());
cpu_row.is_kernel_mode = F::ONE;
cpu_row.program_counter = F::from_canonical_usize(state.registers.program_counter);
cpu_row.next_program_counter = F::from_canonical_usize(state.registers.next_pc);

let log_end_pc = reg_zero_write_with_log(0, kernel.program.end_pc, state, &mut cpu_row);
state.traces.push_memory(log_end_pc);
state.traces.push_cpu(cpu_row);

// sync registers to memory
let registers_addr: Vec<_> = (REGISTERS_START..=REGISTERS_START + (36 << 2) - 1)
let registers_addr: Vec<_> = (REGISTERS_START..=REGISTERS_START + (39 << 2) - 1)
.step_by(4)
.collect::<Vec<u32>>();
let mut registers_value: [u32; 36] = [0; 36];
let mut registers_value: [u32; 39] = [0; 39];
for i in 0..32 {
registers_value[i] = state.registers.gprs[i] as u32;
}
registers_value[32] = state.registers.lo as u32;
registers_value[33] = state.registers.hi as u32;
registers_value[34] = state.registers.heap as u32;
registers_value[35] = state.registers.program_counter as u32;
registers_value[36] = state.registers.next_pc as u32;
registers_value[37] = state.registers.brk as u32;
registers_value[38] = state.registers.local_user as u32;

let register_addr_value: Vec<_> = registers_addr.iter().zip(registers_value).collect();
for chunk in &register_addr_value.iter().chunks(8) {
Expand Down
Loading

0 comments on commit c3f0dfc

Please sign in to comment.