From 80a0f8813c2974e2b56ffc988becff1c47207c73 Mon Sep 17 00:00:00 2001 From: weilzkm <140377101+weilzkm@users.noreply.github.com> Date: Fri, 27 Sep 2024 08:11:00 +0800 Subject: [PATCH] feat: Optimize segment size, use cycle size instead (#168) * feat: use cycle instead of instructions for segment size * fix return step in split_*_segs function * fix clippy * update default segment size to 65536 * fix bugs in step check --- emulator/src/memory.rs | 4 ++-- emulator/src/state.rs | 33 +++++++++++++++++++++++++- emulator/src/utils.rs | 17 ++++++------- prover/examples/README.md | 6 ++--- prover/examples/zkmips.rs | 8 +++---- prover/src/cpu/kernel/assembler.rs | 8 ++++++- prover/src/cpu/kernel/elf.rs | 3 +++ runtime/entrypoint/src/syscalls/mod.rs | 2 +- 8 files changed, 61 insertions(+), 20 deletions(-) diff --git a/emulator/src/memory.rs b/emulator/src/memory.rs index 8b646add..f3d9f527 100644 --- a/emulator/src/memory.rs +++ b/emulator/src/memory.rs @@ -161,8 +161,8 @@ impl Memory { } } - pub fn page_count(&self) -> usize { - self.pages.len() + pub fn page_count(&self) -> u64 { + self.rtrace.len() as u64 } pub fn for_each_page>) -> Result<(), String>>( diff --git a/emulator/src/state.rs b/emulator/src/state.rs index 8a26b0fd..6dbec34f 100644 --- a/emulator/src/state.rs +++ b/emulator/src/state.rs @@ -19,6 +19,12 @@ pub const FD_HINT: u32 = 4; pub const MIPS_EBADF: u32 = 9; pub const REGISTERS_START: u32 = 0x81020400u32; +pub const PAGE_LOAD_CYCLES: u64 = 128; +pub const PAGE_HASH_CYCLES: u64 = 1; +pub const PAGE_CYCLES: u64 = PAGE_LOAD_CYCLES + PAGE_HASH_CYCLES; +pub const IMAGE_ID_CYCLES: u64 = 3; +pub const MAX_INSTRUCTION_CYCLES: u64 = PAGE_CYCLES * 6; //TOFIX +pub const RESERVE_CYCLES: u64 = IMAGE_ID_CYCLES + MAX_INSTRUCTION_CYCLES; // image_id = keccak(page_hash_root || end_pc) #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Default)] @@ -31,6 +37,7 @@ pub struct Segment { pub image_id: [u8; 32], pub page_hash_root: [u8; 32], pub end_pc: u32, + pub 
step: u64, pub input_stream: Vec>, pub input_stream_ptr: usize, pub public_values_stream: Vec, @@ -62,6 +69,11 @@ pub struct State { /// step tracks the total step has been executed. pub step: u64, + pub total_step: u64, + + /// cycle tracks the number of cycles executed in the current segment. + pub cycle: u64, + pub total_cycle: u64, /// A stream of input values (global to the entire program). pub input_stream: Vec>, @@ -111,6 +123,9 @@ impl State { heap: 0, local_user: 0, step: 0, + total_step: 0, + cycle: 0, + total_cycle: 0, brk: 0, input_stream: Vec::new(), input_stream_ptr: 0, @@ -135,6 +150,9 @@ impl State { heap: 0x20000000, local_user: 0, step: 0, + total_step: 0, + cycle: 0, + total_cycle: 0, brk: 0, input_stream: Vec::new(), input_stream_ptr: 0, @@ -406,6 +424,7 @@ impl State { .set_memory_range(0x31000004, data) .expect("set memory range failed"); + self.cycle += (data_len as u64 + 35) / 32; let len = data_len & 3; let end = data_len % POSEIDON_RATE_BYTES; @@ -535,6 +554,9 @@ impl InstrumentedState { ); log::debug!("input: {:?}", vec); assert_eq!(a0 % 4, 0, "hint read address not aligned to 4 bytes"); + if a1 >= 1 { + self.state.cycle += (a1 as u64 + 31) / 32; + } for i in (0..a1).step_by(4) { // Get each byte in the chunk let b1 = vec[i as usize]; @@ -837,6 +859,7 @@ impl InstrumentedState { } self.state.step += 1; + self.state.cycle += 1; // fetch instruction let insn = self.state.memory.get_memory(self.state.pc); @@ -1229,7 +1252,7 @@ impl InstrumentedState { ); } - pub fn step(&mut self) { + pub fn step(&mut self) -> u64 { let dump: bool = self.state.dump_info; self.state.dump_info = false; @@ -1241,6 +1264,8 @@ impl InstrumentedState { self.state.registers ); }; + + self.state.cycle + (self.state.memory.page_count() + 1) * PAGE_CYCLES + RESERVE_CYCLES } /// the caller should provide a write to write segemnt if proof is true @@ -1250,6 +1275,9 @@ impl InstrumentedState { output: &str, new_writer: fn(&str) -> Option, ) { + self.state.total_cycle += + self.state.cycle
+ (self.state.memory.page_count() + 1) * PAGE_CYCLES; + self.state.total_step += self.state.step; self.state.memory.update_page_hash(); let regiters = self.state.get_registers_bytes(); @@ -1270,6 +1298,7 @@ impl InstrumentedState { pre_image_id: self.pre_image_id, image_id, end_pc: self.state.pc, + step: self.state.step, page_hash_root, input_stream: self.pre_input.clone(), input_stream_ptr: self.pre_input_ptr, @@ -1291,6 +1320,8 @@ impl InstrumentedState { self.pre_pc = self.state.pc; self.pre_image_id = image_id; self.pre_hash_root = page_hash_root; + self.state.cycle = 0; + self.state.step = 0; } pub fn dump_memory(&mut self) { diff --git a/emulator/src/utils.rs b/emulator/src/utils.rs index 15e8f96e..eeb4cd82 100644 --- a/emulator/src/utils.rs +++ b/emulator/src/utils.rs @@ -3,7 +3,7 @@ use elf::{endian::AnyEndian, ElfBytes}; use std::fs; use std::fs::File; -pub const SEGMENT_STEPS: usize = 1024; +pub const SEGMENT_STEPS: usize = 65536; /// From the minigeth's rule, the `block` starts with `0_` pub fn get_block_path(basedir: &str, block: &str, file: &str) -> String { @@ -30,25 +30,26 @@ pub fn split_prog_into_segs( std::fs::create_dir_all(seg_path).unwrap(); let new_writer = |_: &str| -> Option { None }; instrumented_state.split_segment(false, seg_path, new_writer); - let mut segment_step: usize = seg_size; let new_writer = |name: &str| -> Option { File::create(name).ok() }; loop { if instrumented_state.state.exited { break; } - instrumented_state.step(); - segment_step -= 1; - if segment_step == 0 { - segment_step = seg_size; + let cycles = instrumented_state.step(); + if cycles >= seg_size as u64 { instrumented_state.split_segment(true, seg_path, new_writer); } } instrumented_state.split_segment(true, seg_path, new_writer); - log::info!("Split done {}", instrumented_state.state.step); + log::info!( + "Split done {} : {}", + instrumented_state.state.total_step, + instrumented_state.state.total_cycle + ); instrumented_state.dump_memory(); ( - 
instrumented_state.state.step as usize, + instrumented_state.state.total_step as usize, instrumented_state.state, ) } diff --git a/prover/examples/README.md b/prover/examples/README.md index ec062056..8fc94a60 100644 --- a/prover/examples/README.md +++ b/prover/examples/README.md @@ -13,21 +13,21 @@ GOOS=linux GOARCH=mips GOMIPS=softfloat go build hello.go * Split the ELF hello into segments. Note that the flag `BLOCK_NO` is only necessary for minigeth. ``` -BASEDIR=./emulator/test-vectors RUST_LOG=info ELF_PATH=./emulator/test-vectors/minigeth BLOCK_NO=13284491 SEG_OUTPUT=/tmp/output SEG_SIZE=1024 ARGS="" \ +BASEDIR=./emulator/test-vectors RUST_LOG=info ELF_PATH=./emulator/test-vectors/minigeth BLOCK_NO=13284491 SEG_OUTPUT=/tmp/output SEG_SIZE=65536 ARGS="" \ cargo run --release --example zkmips split ``` * Generate proof for specific segment (Set SEG_START_ID to specific segment id and set SEG_NUM to 1) ``` -BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=1 SEG_SIZE=1024 \ +BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=1 SEG_SIZE=65536 \ cargo run --release --example zkmips prove_segments ``` * Aggregate proof all segments (Set SEG_START_ID to 0, and set SEG_NUM to the total segments number) ``` -BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=299 SEG_SIZE=1024 \ +BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=299 SEG_SIZE=65536 \ cargo run --release --example zkmips prove_segments ``` diff --git a/prover/examples/zkmips.rs b/prover/examples/zkmips.rs index 5646a7af..7c4aed55 100644 --- a/prover/examples/zkmips.rs +++ b/prover/examples/zkmips.rs @@ -278,7 +278,7 @@ fn prove_sha2_rust() { let mut seg_num = 1usize; if seg_size != 0 { - seg_num = (total_steps + seg_size - 1) / seg_size; + seg_num = 
total_steps.div_ceil(seg_size); } prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap() } @@ -319,7 +319,7 @@ fn prove_sha2_go() { let mut seg_num = 1usize; if seg_size != 0 { - seg_num = (total_steps + seg_size - 1) / seg_size; + seg_num = total_steps.div_ceil(seg_size); } prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap() @@ -344,7 +344,7 @@ fn prove_revm() { let mut seg_num = 1usize; if seg_size != 0 { - seg_num = (total_steps + seg_size - 1) / seg_size; + seg_num = total_steps.div_ceil(seg_size); } if seg_num == 1 { @@ -430,7 +430,7 @@ fn prove_add_example() { let mut seg_num = 1usize; if seg_size != 0 { - seg_num = (total_steps + seg_size - 1) / seg_size; + seg_num = total_steps.div_ceil(seg_size); } if seg_num == 1 { diff --git a/prover/src/cpu/kernel/assembler.rs b/prover/src/cpu/kernel/assembler.rs index 8edf9220..57517912 100644 --- a/prover/src/cpu/kernel/assembler.rs +++ b/prover/src/cpu/kernel/assembler.rs @@ -29,12 +29,18 @@ pub fn segment_kernel( let p: Program = Program::load_segment(seg_reader).unwrap(); let blockpath = get_block_path(basedir, block, file); + let mut final_step = steps; + if p.step != 0 { + assert!(p.step <= steps); + final_step = p.step; + } + Kernel { program: p, ordered_labels: vec![], global_labels: HashMap::new(), blockpath, - steps, + steps: final_step, } } diff --git a/prover/src/cpu/kernel/elf.rs b/prover/src/cpu/kernel/elf.rs index 66485cab..e276863d 100644 --- a/prover/src/cpu/kernel/elf.rs +++ b/prover/src/cpu/kernel/elf.rs @@ -26,6 +26,7 @@ pub struct Program { pub brk: usize, pub local_user: usize, pub end_pc: usize, + pub step: usize, pub image_id: [u8; 32], pub pre_image_id: [u8; 32], pub pre_hash_root: [u8; 32], @@ -250,6 +251,7 @@ impl Program { brk: brk as usize, local_user: 0, end_pc: end_pc as usize, + step: 0, image_id: image_id.try_into().unwrap(), pre_image_id: pre_image_id.try_into().unwrap(),
pre_hash_root, @@ -332,6 +334,7 @@ impl Program { brk, local_user, end_pc, + step: segment.step as usize, image_id: segment.image_id, pre_image_id: segment.pre_image_id, pre_hash_root: segment.pre_hash_root, diff --git a/runtime/entrypoint/src/syscalls/mod.rs b/runtime/entrypoint/src/syscalls/mod.rs index 4e4020f7..f08ab8e5 100644 --- a/runtime/entrypoint/src/syscalls/mod.rs +++ b/runtime/entrypoint/src/syscalls/mod.rs @@ -12,7 +12,7 @@ pub use sys::*; /// These codes MUST match the codes in `core/src/runtime/syscall.rs`. There is a derived test /// that checks that the enum is consistent with the syscalls. - +/// /// Halts the program. pub const HALT: u32 = 4246u32;