Skip to content

Commit

Permalink
feat: Optimize segment size, use cycle size instead (#168)
Browse files Browse the repository at this point in the history
* feat: use cycle instead of instructions for segment size

* fix return step in split_*_segs function

* fix clippy

* update default segment size to 65536

* fix bugs in step check
  • Loading branch information
weilzkm authored Sep 27, 2024
1 parent 9aed73f commit 80a0f88
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 20 deletions.
4 changes: 2 additions & 2 deletions emulator/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ impl Memory {
}
}

pub fn page_count(&self) -> usize {
self.pages.len()
pub fn page_count(&self) -> u64 {
self.rtrace.len() as u64
}

pub fn for_each_page<T: Fn(u32, &Rc<RefCell<CachedPage>>) -> Result<(), String>>(
Expand Down
33 changes: 32 additions & 1 deletion emulator/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ pub const FD_HINT: u32 = 4;
pub const MIPS_EBADF: u32 = 9;

pub const REGISTERS_START: u32 = 0x81020400u32;
/// Per-page load cost in cycles (presumably for loading one memory page — TODO confirm).
pub const PAGE_LOAD_CYCLES: u64 = 128;
/// Per-page hash cost in cycles (presumably for hashing one memory page — TODO confirm).
pub const PAGE_HASH_CYCLES: u64 = 1;
/// Combined per-page cost: load cost plus hash cost.
pub const PAGE_CYCLES: u64 = PAGE_LOAD_CYCLES + PAGE_HASH_CYCLES;
/// Cycles set aside for the image id (presumably for computing it — TODO confirm).
pub const IMAGE_ID_CYCLES: u64 = 3;
/// Presumably an upper bound on the cycles one instruction can add: six pages' worth.
/// NOTE(review): flagged TOFIX by the author; the factor 6 is not justified here.
pub const MAX_INSTRUCTION_CYCLES: u64 = PAGE_CYCLES * 6; //TOFIX
/// Headroom in cycles: the image-id cost plus `MAX_INSTRUCTION_CYCLES`.
pub const RESERVE_CYCLES: u64 = IMAGE_ID_CYCLES + MAX_INSTRUCTION_CYCLES;

// image_id = keccak(page_hash_root || end_pc)
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Default)]
Expand All @@ -31,6 +37,7 @@ pub struct Segment {
pub image_id: [u8; 32],
pub page_hash_root: [u8; 32],
pub end_pc: u32,
pub step: u64,
pub input_stream: Vec<Vec<u8>>,
pub input_stream_ptr: usize,
pub public_values_stream: Vec<u8>,
Expand Down Expand Up @@ -62,6 +69,11 @@ pub struct State {

/// step tracks the number of steps that have been executed.
pub step: u64,
pub total_step: u64,

/// cycle tracks the number of cycles that have been executed.
pub cycle: u64,
pub total_cycle: u64,

/// A stream of input values (global to the entire program).
pub input_stream: Vec<Vec<u8>>,
Expand Down Expand Up @@ -111,6 +123,9 @@ impl State {
heap: 0,
local_user: 0,
step: 0,
total_step: 0,
cycle: 0,
total_cycle: 0,
brk: 0,
input_stream: Vec::new(),
input_stream_ptr: 0,
Expand All @@ -135,6 +150,9 @@ impl State {
heap: 0x20000000,
local_user: 0,
step: 0,
total_step: 0,
cycle: 0,
total_cycle: 0,
brk: 0,
input_stream: Vec::new(),
input_stream_ptr: 0,
Expand Down Expand Up @@ -406,6 +424,7 @@ impl State {
.set_memory_range(0x31000004, data)
.expect("set memory range failed");

self.cycle += (data_len as u64 + 35) / 32;
let len = data_len & 3;
let end = data_len % POSEIDON_RATE_BYTES;

Expand Down Expand Up @@ -535,6 +554,9 @@ impl InstrumentedState {
);
log::debug!("input: {:?}", vec);
assert_eq!(a0 % 4, 0, "hint read address not aligned to 4 bytes");
if a1 >= 1 {
self.state.cycle += (a1 as u64 + 31) / 32;
}
for i in (0..a1).step_by(4) {
// Get each byte in the chunk
let b1 = vec[i as usize];
Expand Down Expand Up @@ -837,6 +859,7 @@ impl InstrumentedState {
}

self.state.step += 1;
self.state.cycle += 1;

// fetch instruction
let insn = self.state.memory.get_memory(self.state.pc);
Expand Down Expand Up @@ -1229,7 +1252,7 @@ impl InstrumentedState {
);
}

pub fn step(&mut self) {
pub fn step(&mut self) -> u64 {
let dump: bool = self.state.dump_info;
self.state.dump_info = false;

Expand All @@ -1241,6 +1264,8 @@ impl InstrumentedState {
self.state.registers
);
};

self.state.cycle + (self.state.memory.page_count() + 1) * PAGE_CYCLES + RESERVE_CYCLES
}

/// the caller should provide a writer to write the segment if proof is true
Expand All @@ -1250,6 +1275,9 @@ impl InstrumentedState {
output: &str,
new_writer: fn(&str) -> Option<W>,
) {
self.state.total_cycle +=
self.state.cycle + (self.state.memory.page_count() + 1) * PAGE_CYCLES;
self.state.total_step += self.state.step;
self.state.memory.update_page_hash();
let regiters = self.state.get_registers_bytes();

Expand All @@ -1270,6 +1298,7 @@ impl InstrumentedState {
pre_image_id: self.pre_image_id,
image_id,
end_pc: self.state.pc,
step: self.state.step,
page_hash_root,
input_stream: self.pre_input.clone(),
input_stream_ptr: self.pre_input_ptr,
Expand All @@ -1291,6 +1320,8 @@ impl InstrumentedState {
self.pre_pc = self.state.pc;
self.pre_image_id = image_id;
self.pre_hash_root = page_hash_root;
self.state.cycle = 0;
self.state.step = 0;
}

pub fn dump_memory(&mut self) {
Expand Down
17 changes: 9 additions & 8 deletions emulator/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use elf::{endian::AnyEndian, ElfBytes};
use std::fs;
use std::fs::File;

pub const SEGMENT_STEPS: usize = 1024;
pub const SEGMENT_STEPS: usize = 65536;

/// From the minigeth's rule, the `block` starts with `0_`
pub fn get_block_path(basedir: &str, block: &str, file: &str) -> String {
Expand All @@ -30,25 +30,26 @@ pub fn split_prog_into_segs(
std::fs::create_dir_all(seg_path).unwrap();
let new_writer = |_: &str| -> Option<std::fs::File> { None };
instrumented_state.split_segment(false, seg_path, new_writer);
let mut segment_step: usize = seg_size;
let new_writer = |name: &str| -> Option<std::fs::File> { File::create(name).ok() };
loop {
if instrumented_state.state.exited {
break;
}
instrumented_state.step();
segment_step -= 1;
if segment_step == 0 {
segment_step = seg_size;
let cycles = instrumented_state.step();
if cycles >= seg_size as u64 {
instrumented_state.split_segment(true, seg_path, new_writer);
}
}
instrumented_state.split_segment(true, seg_path, new_writer);
log::info!("Split done {}", instrumented_state.state.step);
log::info!(
"Split done {} : {}",
instrumented_state.state.total_step,
instrumented_state.state.total_cycle
);

instrumented_state.dump_memory();
(
instrumented_state.state.step as usize,
instrumented_state.state.total_step as usize,
instrumented_state.state,
)
}
6 changes: 3 additions & 3 deletions prover/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ GOOS=linux GOARCH=mips GOMIPS=softfloat go build hello.go
* Split the ELF hello into segments. Note that the flag `BLOCK_NO` is only necessary for minigeth.

```
BASEDIR=./emulator/test-vectors RUST_LOG=info ELF_PATH=./emulator/test-vectors/minigeth BLOCK_NO=13284491 SEG_OUTPUT=/tmp/output SEG_SIZE=1024 ARGS="" \
BASEDIR=./emulator/test-vectors RUST_LOG=info ELF_PATH=./emulator/test-vectors/minigeth BLOCK_NO=13284491 SEG_OUTPUT=/tmp/output SEG_SIZE=65536 ARGS="" \
cargo run --release --example zkmips split
```

* Generate a proof for a specific segment (set SEG_START_ID to that segment's id and set SEG_NUM to 1)

```
BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=1 SEG_SIZE=1024 \
BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=1 SEG_SIZE=65536 \
cargo run --release --example zkmips prove_segments
```

* Aggregate the proofs of all segments (set SEG_START_ID to 0, and set SEG_NUM to the total number of segments)

```
BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=299 SEG_SIZE=1024 \
BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=299 SEG_SIZE=65536 \
cargo run --release --example zkmips prove_segments
```

Expand Down
8 changes: 4 additions & 4 deletions prover/examples/zkmips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ fn prove_sha2_rust() {

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1) / seg_size;
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}
prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
}
Expand Down Expand Up @@ -319,7 +319,7 @@ fn prove_sha2_go() {

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1) / seg_size;
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}

prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
Expand All @@ -344,7 +344,7 @@ fn prove_revm() {

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1) / seg_size;
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}

if seg_num == 1 {
Expand Down Expand Up @@ -430,7 +430,7 @@ fn prove_add_example() {

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1) / seg_size;
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}

if seg_num == 1 {
Expand Down
8 changes: 7 additions & 1 deletion prover/src/cpu/kernel/assembler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@ pub fn segment_kernel<T: Read>(
let p: Program = Program::load_segment(seg_reader).unwrap();
let blockpath = get_block_path(basedir, block, file);

let mut final_step = steps;
if p.step != 0 {
assert!(p.step <= steps);
final_step = p.step;
}

Kernel {
program: p,
ordered_labels: vec![],
global_labels: HashMap::new(),
blockpath,
steps,
steps: final_step,
}
}

Expand Down
3 changes: 3 additions & 0 deletions prover/src/cpu/kernel/elf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct Program {
pub brk: usize,
pub local_user: usize,
pub end_pc: usize,
pub step: usize,
pub image_id: [u8; 32],
pub pre_image_id: [u8; 32],
pub pre_hash_root: [u8; 32],
Expand Down Expand Up @@ -250,6 +251,7 @@ impl Program {
brk: brk as usize,
local_user: 0,
end_pc: end_pc as usize,
step: 0,
image_id: image_id.try_into().unwrap(),
pre_image_id: pre_image_id.try_into().unwrap(),
pre_hash_root,
Expand Down Expand Up @@ -332,6 +334,7 @@ impl Program {
brk,
local_user,
end_pc,
step: segment.step as usize,
image_id: segment.image_id,
pre_image_id: segment.pre_image_id,
pre_hash_root: segment.pre_hash_root,
Expand Down
2 changes: 1 addition & 1 deletion runtime/entrypoint/src/syscalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub use sys::*;

/// These codes MUST match the codes in `core/src/runtime/syscall.rs`. There is a derived test
/// that checks that the enum is consistent with the syscalls.

///
/// Halts the program.
pub const HALT: u32 = 4246u32;

Expand Down

0 comments on commit 80a0f88

Please sign in to comment.