Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Optimize segment size, use cycle size instead #168

Merged
merged 5 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions emulator/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ impl Memory {
}
}

pub fn page_count(&self) -> usize {
self.pages.len()
pub fn page_count(&self) -> u64 {
self.rtrace.len() as u64
}

pub fn for_each_page<T: Fn(u32, &Rc<RefCell<CachedPage>>) -> Result<(), String>>(
Expand Down
33 changes: 32 additions & 1 deletion emulator/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ pub const FD_HINT: u32 = 4;
pub const MIPS_EBADF: u32 = 9;

pub const REGISTERS_START: u32 = 0x81020400u32;
pub const PAGE_LOAD_CYCLES: u64 = 128;
pub const PAGE_HASH_CYCLES: u64 = 1;
pub const PAGE_CYCLES: u64 = PAGE_LOAD_CYCLES + PAGE_HASH_CYCLES;
pub const IMAGE_ID_CYCLES: u64 = 3;
pub const MAX_INSTRUCTION_CYCLES: u64 = PAGE_CYCLES * 6; //TOFIX
pub const RESERVE_CYCLES: u64 = IMAGE_ID_CYCLES + MAX_INSTRUCTION_CYCLES;

// image_id = keccak(page_hash_root || end_pc)
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Default)]
Expand All @@ -31,6 +37,7 @@ pub struct Segment {
pub image_id: [u8; 32],
pub page_hash_root: [u8; 32],
pub end_pc: u32,
pub step: u64,
pub input_stream: Vec<Vec<u8>>,
pub input_stream_ptr: usize,
pub public_values_stream: Vec<u8>,
Expand Down Expand Up @@ -62,6 +69,11 @@ pub struct State {

/// step tracks the total step has been executed.
pub step: u64,
pub total_step: u64,

/// cycle tracks the total cycle has been executed.
pub cycle: u64,
pub total_cycle: u64,

/// A stream of input values (global to the entire program).
pub input_stream: Vec<Vec<u8>>,
Expand Down Expand Up @@ -111,6 +123,9 @@ impl State {
heap: 0,
local_user: 0,
step: 0,
total_step: 0,
cycle: 0,
total_cycle: 0,
brk: 0,
input_stream: Vec::new(),
input_stream_ptr: 0,
Expand All @@ -135,6 +150,9 @@ impl State {
heap: 0x20000000,
local_user: 0,
step: 0,
total_step: 0,
cycle: 0,
total_cycle: 0,
brk: 0,
input_stream: Vec::new(),
input_stream_ptr: 0,
Expand Down Expand Up @@ -406,6 +424,7 @@ impl State {
.set_memory_range(0x31000004, data)
.expect("set memory range failed");

self.cycle += (data_len as u64 + 35) / 32;
let len = data_len & 3;
let end = data_len % POSEIDON_RATE_BYTES;

Expand Down Expand Up @@ -535,6 +554,9 @@ impl InstrumentedState {
);
log::debug!("input: {:?}", vec);
assert_eq!(a0 % 4, 0, "hint read address not aligned to 4 bytes");
if a1 >= 1 {
self.state.cycle += (a1 as u64 + 31) / 32;
}
for i in (0..a1).step_by(4) {
// Get each byte in the chunk
let b1 = vec[i as usize];
Expand Down Expand Up @@ -837,6 +859,7 @@ impl InstrumentedState {
}

self.state.step += 1;
self.state.cycle += 1;

// fetch instruction
let insn = self.state.memory.get_memory(self.state.pc);
Expand Down Expand Up @@ -1229,7 +1252,7 @@ impl InstrumentedState {
);
}

pub fn step(&mut self) {
pub fn step(&mut self) -> u64 {
let dump: bool = self.state.dump_info;
self.state.dump_info = false;

Expand All @@ -1241,6 +1264,8 @@ impl InstrumentedState {
self.state.registers
);
};

self.state.cycle + (self.state.memory.page_count() + 1) * PAGE_CYCLES + RESERVE_CYCLES
}

/// the caller should provide a write to write segemnt if proof is true
Expand All @@ -1250,6 +1275,9 @@ impl InstrumentedState {
output: &str,
new_writer: fn(&str) -> Option<W>,
) {
self.state.total_cycle +=
self.state.cycle + (self.state.memory.page_count() + 1) * PAGE_CYCLES;
self.state.total_step += self.state.step;
self.state.memory.update_page_hash();
let regiters = self.state.get_registers_bytes();

Expand All @@ -1270,6 +1298,7 @@ impl InstrumentedState {
pre_image_id: self.pre_image_id,
image_id,
end_pc: self.state.pc,
step: self.state.step,
page_hash_root,
input_stream: self.pre_input.clone(),
input_stream_ptr: self.pre_input_ptr,
Expand All @@ -1291,6 +1320,8 @@ impl InstrumentedState {
self.pre_pc = self.state.pc;
self.pre_image_id = image_id;
self.pre_hash_root = page_hash_root;
self.state.cycle = 0;
self.state.step = 0;
}

pub fn dump_memory(&mut self) {
Expand Down
17 changes: 9 additions & 8 deletions emulator/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use elf::{endian::AnyEndian, ElfBytes};
use std::fs;
use std::fs::File;

pub const SEGMENT_STEPS: usize = 1024;
pub const SEGMENT_STEPS: usize = 65536;

/// From the minigeth's rule, the `block` starts with `0_`
pub fn get_block_path(basedir: &str, block: &str, file: &str) -> String {
Expand All @@ -30,25 +30,26 @@ pub fn split_prog_into_segs(
std::fs::create_dir_all(seg_path).unwrap();
let new_writer = |_: &str| -> Option<std::fs::File> { None };
instrumented_state.split_segment(false, seg_path, new_writer);
let mut segment_step: usize = seg_size;
let new_writer = |name: &str| -> Option<std::fs::File> { File::create(name).ok() };
loop {
if instrumented_state.state.exited {
break;
}
instrumented_state.step();
segment_step -= 1;
if segment_step == 0 {
segment_step = seg_size;
let cycles = instrumented_state.step();
if cycles >= seg_size as u64 {
instrumented_state.split_segment(true, seg_path, new_writer);
}
}
instrumented_state.split_segment(true, seg_path, new_writer);
log::info!("Split done {}", instrumented_state.state.step);
log::info!(
"Split done {} : {}",
instrumented_state.state.total_step,
instrumented_state.state.total_cycle
);

instrumented_state.dump_memory();
(
instrumented_state.state.step as usize,
instrumented_state.state.total_step as usize,
instrumented_state.state,
)
}
6 changes: 3 additions & 3 deletions prover/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ GOOS=linux GOARCH=mips GOMIPS=softfloat go build hello.go
* Split the ELF hello into segments. Note that the flag `BLOCK_NO` is only necessary for minigeth.

```
BASEDIR=./emulator/test-vectors RUST_LOG=info ELF_PATH=./emulator/test-vectors/minigeth BLOCK_NO=13284491 SEG_OUTPUT=/tmp/output SEG_SIZE=1024 ARGS="" \
BASEDIR=./emulator/test-vectors RUST_LOG=info ELF_PATH=./emulator/test-vectors/minigeth BLOCK_NO=13284491 SEG_OUTPUT=/tmp/output SEG_SIZE=65536 ARGS="" \
cargo run --release --example zkmips split
```

* Generate proof for specific segment (Set SEG_START_ID to specific segment id and set SEG_NUM to 1)

```
BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=1 SEG_SIZE=1024 \
BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=1 SEG_SIZE=65536 \
cargo run --release --example zkmips prove_segments
```

* Aggregate proof all segments (Set SEG_START_ID to 0, and set SEG_NUM to the total segments number)

```
BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=299 SEG_SIZE=1024 \
BASEDIR=./emulator/test-vectors RUST_LOG=info BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_START_ID=0 SEG_NUM=299 SEG_SIZE=65536 \
cargo run --release --example zkmips prove_segments
```

Expand Down
8 changes: 4 additions & 4 deletions prover/examples/zkmips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ fn prove_sha2_rust() {

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1) / seg_size;
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}
prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
}
Expand Down Expand Up @@ -319,7 +319,7 @@ fn prove_sha2_go() {

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1) / seg_size;
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}

prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
Expand All @@ -344,7 +344,7 @@ fn prove_revm() {

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1) / seg_size;
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}

if seg_num == 1 {
Expand Down Expand Up @@ -430,7 +430,7 @@ fn prove_add_example() {

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1) / seg_size;
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}

if seg_num == 1 {
Expand Down
8 changes: 7 additions & 1 deletion prover/src/cpu/kernel/assembler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@ pub fn segment_kernel<T: Read>(
let p: Program = Program::load_segment(seg_reader).unwrap();
let blockpath = get_block_path(basedir, block, file);

let mut final_step = steps;
if p.step != 0 {
assert!(p.step <= steps);
final_step = p.step;
}

Kernel {
program: p,
ordered_labels: vec![],
global_labels: HashMap::new(),
blockpath,
steps,
steps: final_step,
}
}

Expand Down
3 changes: 3 additions & 0 deletions prover/src/cpu/kernel/elf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct Program {
pub brk: usize,
pub local_user: usize,
pub end_pc: usize,
pub step: usize,
pub image_id: [u8; 32],
pub pre_image_id: [u8; 32],
pub pre_hash_root: [u8; 32],
Expand Down Expand Up @@ -250,6 +251,7 @@ impl Program {
brk: brk as usize,
local_user: 0,
end_pc: end_pc as usize,
step: 0,
image_id: image_id.try_into().unwrap(),
pre_image_id: pre_image_id.try_into().unwrap(),
pre_hash_root,
Expand Down Expand Up @@ -332,6 +334,7 @@ impl Program {
brk,
local_user,
end_pc,
step: segment.step as usize,
image_id: segment.image_id,
pre_image_id: segment.pre_image_id,
pre_hash_root: segment.pre_hash_root,
Expand Down
2 changes: 1 addition & 1 deletion runtime/entrypoint/src/syscalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub use sys::*;

/// These codes MUST match the codes in `core/src/runtime/syscall.rs`. There is a derived test
/// that checks that the enum is consistent with the syscalls.

///
/// Halts the program.
pub const HALT: u32 = 4246u32;

Expand Down
Loading