diff --git a/Cargo.lock b/Cargo.lock index 326ad3e7..5b736ab7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1039,9 +1039,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a73e9fe3c49d7afb2ace819fa181a287ce54a0983eda4e0eb05c22f82ffe534" +checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2" [[package]] name = "jobserver" @@ -1467,7 +1467,6 @@ dependencies = [ [[package]] name = "pil-std-lib" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "log", "num-bigint", @@ -1485,7 +1484,6 @@ dependencies = [ [[package]] name = "pilout" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "bytes", "log", @@ -1595,9 +1593,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.89" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -1605,7 +1603,6 @@ dependencies = [ [[package]] name = "proofman" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "colored", "env_logger", @@ -1626,11 +1623,11 @@ dependencies = [ [[package]] name = "proofman-common" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "env_logger", "log", "p3-field", + "p3-goldilocks", "pilout", "proofman-macros", "proofman-starks-lib-c", @@ -1644,7 +1641,6 @@ dependencies = [ [[package]] name = "proofman-hints" version = 
"0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "p3-field", "proofman-common", @@ -1654,7 +1650,6 @@ dependencies = [ [[package]] name = "proofman-macros" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "proc-macro2", "quote", @@ -1664,7 +1659,6 @@ dependencies = [ [[package]] name = "proofman-starks-lib-c" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "log", ] @@ -1672,7 +1666,6 @@ dependencies = [ [[package]] name = "proofman-util" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "colored", "sysinfo 0.31.4", @@ -2242,10 +2235,14 @@ name = "sm-mem" version = "0.1.0" dependencies = [ "log", + "num-bigint", + "num-traits", "p3-field", + "pil-std-lib", "proofman", "proofman-common", "proofman-macros", + "proofman-util", "rayon", "sm-common", "zisk-core", @@ -2316,7 +2313,6 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stark" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "log", "p3-field", @@ -2380,9 +2376,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.87" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", @@ -2391,9 +2387,9 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] @@ -2666,7 +2662,6 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "proofman-starks-lib-c", ] @@ -2898,9 +2893,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.6" +version = "0.26.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" dependencies = [ "rustls-pki-types", ] diff --git a/Cargo.toml b/Cargo.toml index b97f5cb2..1f8e4798 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,19 +26,19 @@ opt-level = 3 opt-level = 3 [workspace.dependencies] -proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# proofman = { git = 
"https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } # Local development -# proofman-common = { path = "../pil2-proofman/common" } -# proofman-macros = { path = "../pil2-proofman/macros" } -# proofman-util = { path = "../pil2-proofman/util" } -# proofman = { path = "../pil2-proofman/proofman" } -# pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } -# stark = { path = "../pil2-proofman/provers/stark" } +proofman-common = { path = "../pil2-proofman/common" } +proofman-macros = { path = "../pil2-proofman/macros" } +proofman-util = { path = "../pil2-proofman/util" } +proofman = { path = "../pil2-proofman/proofman" } +pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } +stark = { path = "../pil2-proofman/provers/stark" } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "c3d754ef77b9fce585b46b972af751fe6e7a9803" } log = "0.4" diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index 59a7aee6..04410a0f 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt}; #[derive(Clone)] pub struct ZiskRequiredOperation { @@ -10,13 +10,30 @@ pub struct ZiskRequiredOperation { #[derive(Clone)] pub struct ZiskRequiredMemory { - pub step: u64, + pub address: u32, pub is_write: bool, - pub address: u64, - pub width: u64, + pub width: u8, + pub step_offset: u8, + pub step: u64, pub value: u64, } +impl fmt::Debug for ZiskRequiredMemory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let label = if self.is_write { "WR" } else { "RD" }; + write!( + f, + "{0} addr:{1:#08X}({1}) offset:{5} with:{2} value:{3:#016X}({3}) step:{4}", + label, + self.address, + 
self.width, + self.value, + self.step, + self.address & 0x07 + ) + } +} + #[derive(Clone, Default)] pub struct ZiskRequired { pub arith: Vec, diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index 76b11fab..9c82aaff 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -9,9 +9,9 @@ use riscv::RiscVRegisters; // #[cfg(feature = "sp")] // use zisk_core::SRC_SP; use zisk_core::{ - InstContext, ZiskInst, ZiskOperationType, ZiskPcHistogram, ZiskRequiredOperation, ZiskRom, - OUTPUT_ADDR, ROM_ENTRY, SRC_C, SRC_IMM, SRC_IND, SRC_MEM, SRC_STEP, STORE_IND, STORE_MEM, - STORE_NONE, SYS_ADDR, ZISK_OPERATION_TYPE_VARIANTS, + InstContext, ZiskInst, ZiskOperationType, ZiskPcHistogram, ZiskRequiredMemory, + ZiskRequiredOperation, ZiskRom, OUTPUT_ADDR, ROM_ENTRY, SRC_C, SRC_IMM, SRC_IND, SRC_MEM, + SRC_STEP, STORE_IND, STORE_MEM, STORE_NONE, SYS_ADDR, ZISK_OPERATION_TYPE_VARIANTS, }; /// ZisK emulator structure, containing the ZisK rom, the list of ZisK operations, and the @@ -92,6 +92,46 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'a' register value based on the source specified by the current instruction + #[inline(always)] + pub fn source_a_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { + match instruction.a_src { + SRC_C => self.ctx.inst_ctx.a = self.ctx.inst_ctx.c, + SRC_MEM => { + let mut addr = instruction.a_offset_imm0; + if instruction.a_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + self.ctx.inst_ctx.a = self.ctx.inst_ctx.mem.read(addr, 8); + + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + step_offset: 0, + is_write: false, + address: addr as u32, + width: 8, + value: self.ctx.inst_ctx.a, + }; + + emu_mem.push(required_memory); + } + SRC_IMM => { + self.ctx.inst_ctx.a = instruction.a_offset_imm0 | (instruction.a_use_sp_imm1 << 32) + } + SRC_STEP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.step, + // #[cfg(feature = "sp")] + // SRC_SP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.sp, + _ => 
panic!( + "Emu::source_a() Invalid a_src={} pc={}", + instruction.a_src, self.ctx.inst_ctx.pc + ), + } + } + /// Calculate the 'b' register value based on the source specified by the current instruction #[inline(always)] pub fn source_b(&mut self, instruction: &ZiskInst) { @@ -128,6 +168,59 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'b' register value based on the source specified by the current instruction + #[inline(always)] + pub fn source_b_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { + match instruction.b_src { + SRC_C => self.ctx.inst_ctx.b = self.ctx.inst_ctx.c, + SRC_MEM => { + let mut addr = instruction.b_offset_imm0; + if instruction.b_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + self.ctx.inst_ctx.b = self.ctx.inst_ctx.mem.read(addr, 8); + + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + step_offset: 1, + is_write: false, + address: addr as u32, + width: 8, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); + } + SRC_IMM => { + self.ctx.inst_ctx.b = instruction.b_offset_imm0 | (instruction.b_use_sp_imm1 << 32) + } + SRC_IND => { + let mut addr = + (self.ctx.inst_ctx.a as i64 + instruction.b_offset_imm0 as i64) as u64; + if instruction.b_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + self.ctx.inst_ctx.b = self.ctx.inst_ctx.mem.read(addr, instruction.ind_width); + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + step_offset: 1, + is_write: false, + address: addr as u32, + width: instruction.ind_width as u8, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); + } + _ => panic!( + "Emu::source_b() Invalid b_src={} pc={}", + instruction.b_src, self.ctx.inst_ctx.pc + ), + } + } + /// Store the 'c' register value based on the storage specified by the current instruction #[inline(always)] pub fn store_c(&mut self, instruction: &ZiskInst) { @@ -171,6 +264,72 @@ impl<'a> Emu<'a> { } } + /// Store the 'c' register value 
based on the storage specified by the current instruction + #[inline(always)] + pub fn store_c_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + ) { + match instruction.store { + STORE_NONE => {} + STORE_MEM => { + let val: i64 = if instruction.store_ra { + self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 + } else { + self.ctx.inst_ctx.c as i64 + }; + let mut addr: i64 = instruction.store_offset; + if instruction.store_use_sp { + addr += self.ctx.inst_ctx.sp as i64; + } + self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, 8); + + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + step_offset: 2, + is_write: true, + address: addr as u32, + width: 8, + value: val as u64, + }; + emu_mem.push(required_memory); + } + STORE_IND => { + let val: i64 = if instruction.store_ra { + self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 + } else { + self.ctx.inst_ctx.c as i64 + }; + let mut addr = instruction.store_offset; + if instruction.store_use_sp { + addr += self.ctx.inst_ctx.sp as i64; + } + addr += self.ctx.inst_ctx.a as i64; + self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, instruction.ind_width); + + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + step_offset: 2, + is_write: true, + address: addr as u32, + width: instruction.ind_width as u8, + value: val as u64, + }; + emu_mem.push(required_memory); + } + _ => panic!( + "Emu::store_c() Invalid store={} pc={}", + instruction.store, self.ctx.inst_ctx.pc + ), + } + } + + #[inline(always)] + fn is_8_aligned(address: u64, width: u64) -> bool { + address & 7 == 0 && width == 8 + } + /// Store the 'c' register value based on the storage specified by the current instruction and /// log memory access if required #[inline(always)] @@ -449,6 +608,22 @@ impl<'a> Emu<'a> { (emu_traces, emu_segments) } + pub fn par_run_memory(&mut self, inputs: Vec) -> Vec { + // Context, where the state of the execution is stored and modified at 
every execution step + self.ctx = self.create_emu_context(inputs); + + // Init pc to the rom entry address + self.ctx.trace.start_state.pc = ROM_ENTRY; + + let mut emu_mem = Vec::new(); + + while !self.ctx.inst_ctx.end { + self.par_step_memory::(&mut emu_mem); + } + + emu_mem + } + /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] @@ -622,6 +797,41 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.step += 1; } + /// Performs one single step of the emulation + #[inline(always)] + #[allow(unused_variables)] + pub fn par_step_memory(&mut self, emu_mem: &mut Vec) { + let last_pc = self.ctx.inst_ctx.pc; + let last_c = self.ctx.inst_ctx.c; + + let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); + + // Build the 'a' register value based on the source specified by the current instruction + self.source_a_memory(instruction, emu_mem); + + // Build the 'b' register value based on the source specified by the current instruction + self.source_b_memory(instruction, emu_mem); + + // Call the operation + (instruction.func)(&mut self.ctx.inst_ctx); + + // Store the 'c' register value based on the storage specified by the current instruction + self.store_c_memory(instruction, emu_mem); + + // Set SP, if specified by the current instruction + // #[cfg(feature = "sp")] + // self.set_sp(instruction); + + // Set PC, based on current PC, current flag and current instruction + self.set_pc(instruction); + + // If this is the last instruction, stop executing + self.ctx.inst_ctx.end = instruction.end; + + // Increment step counter + self.ctx.inst_ctx.step += 1; + } + /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] diff --git a/emulator/src/emulator.rs b/emulator/src/emulator.rs index d7996937..4b24cb27 100644 --- a/emulator/src/emulator.rs +++ b/emulator/src/emulator.rs @@ -11,8 +11,8 @@ use std::{ }; use sysinfo::System; use zisk_core::{ - Riscv2zisk, ZiskOperationType, ZiskPcHistogram, 
ZiskRequiredOperation, ZiskRom, - ZISK_OPERATION_TYPE_VARIANTS, + Riscv2zisk, ZiskOperationType, ZiskPcHistogram, ZiskRequiredMemory, ZiskRequiredOperation, + ZiskRom, ZISK_OPERATION_TYPE_VARIANTS, }; pub trait Emulator { @@ -240,6 +240,22 @@ impl ZiskEmulator { Ok((vec_traces, emu_slices)) } + pub fn par_process_rom_memory( + rom: &ZiskRom, + inputs: &[u8], + ) -> Result, ZiskEmulatorErr> { + let mut emu = Emu::new(rom); + let result = emu.par_run_memory::(inputs.to_owned()); + + if !emu.terminated() { + panic!("Emulation did not complete"); + // TODO! + // return Err(ZiskEmulatorErr::EmulationNoCompleted); + } + + Ok(result) + } + #[inline] pub fn process_slice_required( rom: &ZiskRom, diff --git a/pil/src/lib.rs b/pil/src/lib.rs index aee8bab5..27705cb0 100644 --- a/pil/src/lib.rs +++ b/pil/src/lib.rs @@ -6,8 +6,5 @@ pub use pil_helpers::*; pub const ARITH32_AIR_IDS: &[usize] = &[4, 5]; pub const ARITH64_AIR_IDS: &[usize] = &[6]; pub const ARITH3264_AIR_IDS: &[usize] = &[7]; -pub const MEM_AIRGROUP_ID: usize = 105; -pub const MEM_ALIGN_AIR_IDS: &[usize] = &[1]; -pub const MEM_UNALIGNED_AIR_IDS: &[usize] = &[2, 3]; pub const QUICKOPS_AIRGROUP_ID: usize = 102; pub const QUICKOPS_AIR_IDS: &[usize] = &[10]; diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 9a796335..ee4ac1d7 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -14,21 +14,29 @@ pub const MAIN_AIR_IDS: &[usize] = &[0]; pub const ROM_AIR_IDS: &[usize] = &[1]; -pub const ARITH_AIR_IDS: &[usize] = &[2]; +pub const MEM_AIR_IDS: &[usize] = &[2]; -pub const ARITH_TABLE_AIR_IDS: &[usize] = &[3]; +pub const MEM_ALIGN_AIR_IDS: &[usize] = &[3]; -pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[4]; +pub const MEM_ALIGN_ROM_AIR_IDS: &[usize] = &[4]; -pub const BINARY_AIR_IDS: &[usize] = &[5]; +pub const ARITH_AIR_IDS: &[usize] = &[5]; -pub const BINARY_TABLE_AIR_IDS: &[usize] = &[6]; +pub const ARITH_TABLE_AIR_IDS: &[usize] = &[6]; -pub const 
BINARY_EXTENSION_AIR_IDS: &[usize] = &[7]; +pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[7]; -pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[8]; +pub const BINARY_AIR_IDS: &[usize] = &[8]; -pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[9]; +pub const BINARY_TABLE_AIR_IDS: &[usize] = &[9]; + +pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[10]; + +pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[11]; + +pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[12]; + +pub const U_8_AIR_AIR_IDS: &[usize] = &[13]; pub struct Pilout; @@ -39,7 +47,10 @@ impl Pilout { let air_group = pilout.add_air_group(Some("Zisk")); air_group.add_air(Some("Main"), 2097152); - air_group.add_air(Some("Rom"), 1048576); + air_group.add_air(Some("Rom"), 4194304); + air_group.add_air(Some("Mem"), 2097152); + air_group.add_air(Some("MemAlign"), 2097152); + air_group.add_air(Some("MemAlignRom"), 256); air_group.add_air(Some("Arith"), 2097152); air_group.add_air(Some("ArithTable"), 128); air_group.add_air(Some("ArithRangeTable"), 4194304); @@ -48,6 +59,7 @@ impl Pilout { air_group.add_air(Some("BinaryExtension"), 2097152); air_group.add_air(Some("BinaryExtensionTable"), 4194304); air_group.add_air(Some("SpecifiedRanges"), 16777216); + air_group.add_air(Some("U8Air"), 256); pilout } diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 32cfe09f..e9631826 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -11,6 +11,18 @@ trace!(RomRow, RomTrace { line: F, a_offset_imm0: F, a_imm1: F, b_offset_imm0: F, b_imm1: F, ind_width: F, op: F, store_offset: F, jmp_offset1: F, jmp_offset2: F, flags: F, multiplicity: F, }); +trace!(MemRow, MemTrace { + addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, increment: F, same_value: F, first_addr_access_is_read: F, +}); + +trace!(MemAlignRow, MemAlignTrace { + addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], 
step: F, delta_addr: F, sel_prove: F, value: [F; 2], +}); + +trace!(MemAlignRomRow, MemAlignRomTrace { + multiplicity: F, +}); + trace!(ArithRow, ArithTrace { carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, sext: F, m32: F, div: F, fab: F, na_fb: F, nb_fa: F, debug_main_step: F, main_div: F, main_mul: F, signed: F, div_by_zero: F, div_overflow: F, inv_sum_all_bs: F, op: F, bus_res1: F, multiplicity: F, range_ab: F, range_cd: F, }); @@ -40,5 +52,9 @@ trace!(BinaryExtensionTableRow, BinaryExtensionTableTrace { }); trace!(SpecifiedRangesRow, SpecifiedRangesTrace { - mul: [F; 1], + mul: [F; 2], +}); + +trace!(U8AirRow, U8AirTrace { + mul: F, }); diff --git a/pil/zisk.pil b/pil/zisk.pil index 0e97aeb6..49727ce0 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -1,19 +1,24 @@ - -require "constants.pil" -require "rom/pil/rom.pil" require "main/pil/main.pil" +require "rom/pil/rom.pil" +require "mem/pil/mem.pil" +require "mem/pil/mem_align.pil" +require "mem/pil/mem_align_rom.pil" require "binary/pil/binary.pil" require "binary/pil/binary_table.pil" require "binary/pil/binary_extension.pil" require "binary/pil/binary_extension_table.pil" require "arith/pil/arith.pil" -// require "mem/pil/mem.pil" const int OPERATION_BUS_ID = 5000; + airgroup Zisk { Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); - Rom(N: 2**20); - // Mem(N: 2**21, RC: 2); + Rom(N: 2**22); + + Mem(N: 2**21, RC: 2); + MemAlign(N: 2**21); + MemAlignRom(disable_fixed: 0); + Arith(N: 2**21, operation_bus_id: OPERATION_BUS_ID); ArithTable(); ArithRangeTable(); diff --git a/state-machines/binary/src/binary.rs b/state-machines/binary/src/binary.rs index 7dcec2c8..9b020312 100644 --- a/state-machines/binary/src/binary.rs +++ b/state-machines/binary/src/binary.rs @@ -78,13 +78,12 @@ impl BinarySM { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - /* as Provable>::prove( - self, - &[], - true, - scope, - );*/ 
- //self.threads_controller.wait_for_threads(); + // as Provable>::prove( + // self, + // &[], + // true, + // scope, + // ); self.binary_basic_sm.unregister_predecessor(); self.binary_extension_sm.unregister_predecessor(); diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index 765ed11f..1cafdbd6 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -12,7 +12,7 @@ use rayon::Scope; use sm_common::{create_prover_buffer, OpResult, Provable}; use std::cmp::Ordering as CmpOrdering; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::*; +use zisk_pil::{BinaryRow, BinaryTrace, BINARY_AIR_IDS, BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::{BinaryBasicTableOp, BinaryBasicTableSM}; diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index eee0c226..8de7d459 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -17,7 +17,10 @@ use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use rayon::Scope; use sm_common::{create_prover_buffer, OpResult, Provable}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::*; +use zisk_pil::{ + BinaryExtensionRow, BinaryExtensionTrace, BINARY_EXTENSION_AIR_IDS, + BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID, +}; const MASK_32: u64 = 0xFFFFFFFF; const MASK_64: u64 = 0xFFFFFFFFFFFFFFFF; diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index 027bee94..7ebcfe3a 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -79,7 +79,7 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope col witness air.b_imm1; } col witness b_src_ind; - col witness ind_width; // 8 , 4, 2, 1 + col witness ind_width; // 8, 4, 2, 1 // Operations related @@ -112,8 +112,6 @@ airtemplate Main(int N = 
2**21, int RC = 2, int stack_enabled = 0, const int ope col witness jmp_offset1, jmp_offset2; // if flag, goto2, else goto 1 col witness m32; - const expr addr_step = STEP * 3; - const expr sel_mem_b; sel_mem_b = b_src_mem + b_src_ind; @@ -135,17 +133,18 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope } // Mem.load - //mem_load(sel: a_src_mem, - // step: addr_step, - // addr: addr0, - // value: a); + mem_load(sel: a_src_mem, + step: STEP, + addr: addr0, + value: a); // Mem.load - //mem_load(sel: sel_mem_b, - // step: addr_step + 1, - // bytes: ind_width, - // addr: addr1, - // value: b); + mem_load(sel: sel_mem_b, + step: STEP, + step_offset: 1, + bytes: b_src_ind * (ind_width - 8) + 8, + addr: addr1, + value: b); const expr store_value[2]; @@ -153,11 +152,12 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope store_value[1] = (1 - store_ra) * c[1]; // Mem.store - //mem_store(sel: store_mem + store_ind, - // step: addr_step + 2, - // bytes: ind_width, - // addr: addr2, - // value: store_value); + mem_store(sel: store_mem + store_ind, + step: STEP, + step_offset: 2, + bytes: store_ind * (ind_width - 8) + 8, + addr: addr2, + value: store_value); // Operation.assume => how organize software col witness __debug_operation_bus_enabled; diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index 30240f83..baaa29f8 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -1,5 +1,6 @@ use log::info; use p3_field::PrimeField; +use sm_mem::MemProxy; use crate::InstanceExtensionCtx; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; @@ -12,7 +13,6 @@ use proofman_common::{AirInstance, ProofCtx}; use proofman::WitnessComponent; use sm_arith::ArithSM; -use sm_mem::MemSM; use zisk_pil::{ MainRow, MainTrace, ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, @@ -28,14 +28,14 @@ pub struct MainSM { 
/// Witness computation manager wcm: Arc>, + /// Memory state machine + mem_proxy_sm: Arc>, + /// Arithmetic state machine arith_sm: Arc>, /// Binary state machine binary_sm: Arc>, - - /// Memory state machine - mem_sm: Arc, } impl MainSM { @@ -54,16 +54,16 @@ impl MainSM { /// * Arc to the MainSM state machine pub fn new( wcm: Arc>, + mem_proxy_sm: Arc>, arith_sm: Arc>, binary_sm: Arc>, - mem_sm: Arc, ) -> Arc { - let main_sm = Arc::new(Self { wcm: wcm.clone(), arith_sm, binary_sm, mem_sm }); + let main_sm = Arc::new(Self { wcm: wcm.clone(), mem_proxy_sm, arith_sm, binary_sm }); wcm.register_component(main_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MAIN_AIR_IDS)); // For all the secondary state machines, register the main state machine as a predecessor - main_sm.mem_sm.register_predecessor(); + main_sm.mem_proxy_sm.register_predecessor(); main_sm.binary_sm.register_predecessor(); main_sm.arith_sm.register_predecessor(); @@ -151,6 +151,39 @@ impl MainSM { segment_trace.steps[slice_start..slice_end].iter().enumerate() { partial_trace[i] = emu.step_slice_full_trace(emu_trace_step); + // if partial_trace[i].a_src_mem == F::one() { + // println!( + // "A=MEM_OP_RD({}) [{},{}] PC:{}", + // partial_trace[i].a_offset_imm0, + // partial_trace[i].a[0], + // partial_trace[i].a[1], + // partial_trace[i].pc + // ); + // } + // if partial_trace[i].b_src_mem == F::one() || partial_trace[i].b_src_ind == + // F::one() { + // println!( + // "B=MEM_OP_RD({0}) [{1},{2}] PC:{3}", + // partial_trace[i].addr1, + // partial_trace[i].b[0], + // partial_trace[i].b[1], + // partial_trace[i].pc + // ); + // } + // if partial_trace[i].b_src_mem == F::one() || partial_trace[i].b_src_ind == + // F::one() { + // println!( + // "MEM_OP_WR({}) [{}, {}] PC:{}", + // partial_trace[i].store_offset + // + partial_trace[i].store_ind * partial_trace[i].a[0], + // partial_trace[i].store_ra + // * (partial_trace[i].pc + partial_trace[i].jmp_offset2 + // - partial_trace[i].c[0]) + // + partial_trace[i].c[0], 
+ // (F::one() - partial_trace[i].store_ra) * partial_trace[i].c[1], + // partial_trace[i].pc + // ); + // } } // if there are steps in the chunk update last row if slice_end - slice_start > 0 { diff --git a/state-machines/mem/Cargo.toml b/state-machines/mem/Cargo.toml index 3f8ee914..7cdb344d 100644 --- a/state-machines/mem/Cargo.toml +++ b/state-machines/mem/Cargo.toml @@ -7,14 +7,21 @@ edition = "2021" sm-common = { path = "../common" } zisk-core = { path = "../../core" } zisk-pil = { path = "../../pil" } +num-traits = "0.2" -p3-field = { workspace=true } proofman-common = { workspace = true } proofman-macros = { workspace = true } +proofman-util = { workspace = true } proofman = { workspace = true } +pil-std-lib = { workspace = true } + +p3-field = { workspace=true } log = { workspace = true } rayon = { workspace = true } +num-bigint = { workspace = true } [features] default = [] -no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] \ No newline at end of file +no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] +debug_mem_proxy_engine = [] +debug_mem_align = [] \ No newline at end of file diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index 50bd652e..23740730 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -6,24 +6,26 @@ const int MEMORY_CONT_ID = 11; const int MEMORY_LOAD_OP = 1; const int MEMORY_STORE_OP = 2; -const int MEMORY_MAX_DIFF = 2**22; +const int MEMORY_MAX_DIFF = 2**24; -const int MAX_MEM_STEP_OFFSET = 3; +const int MAX_MEM_STEP_OFFSET = 2; +const int MAX_MEM_OPS_PER_MAIN_STEP = (MAX_MEM_STEP_OFFSET + 1) * 2; -airtemplate Mem (int N = 2**21, int RC = 2, int id = MEMORY_ID, int MAX_STEP = 2 ** 23, int MEM_BYTES = 8 ) { +airtemplate Mem(const int N = 2**21, const int id = MEMORY_ID, const int RC = 2, const int MEM_BYTES = 8, const int INITIAL_ADDRESS = 0xA0000000) { col fixed SEGMENT_L1 = [1,0...]; const expr SEGMENT_LAST = SEGMENT_L1'; airval 
mem_segment; airval mem_last_segment; - col witness addr; // n-byte address, real address = addr * MEM_BYTES + col witness addr; // n-byte address, real address = addr * MEM_BYTES col witness step; - col witness sel, wr; + col witness sel; + col witness wr; col witness value[RC]; col witness addr_changes; - const expr rd = (1 - wr); + const expr rd = 1 - wr; sel * (1 - sel) === 0; wr * (1 - wr) === 0; @@ -36,21 +38,26 @@ airtemplate Mem (int N = 2**21, int RC = 2, int id = MEMORY_ID, int MAX_STEP = 2 addr_changes * (1 - addr_changes) === 0; // check increment of memory - range_check(sel: (1 - SEGMENT_L1), colu: addr_changes * (addr - 'addr - step + 'step) + step - 'step, min: 1, max: MEMORY_MAX_DIFF); + col witness increment; + increment === SEGMENT_L1 * (addr - 1 - INITIAL_ADDRESS) + (1 - SEGMENT_L1) * (addr_changes * (addr - 'addr - step + 'step) + step - 'step); + range_check(colu: increment, min: 1, max: MEMORY_MAX_DIFF); // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 // setting mem_last_segment = 1 // if addr_changes == 0 means that addr and previous address are the same - (1 - addr_changes) * ('addr - addr) === 0; + const expr same_addr = 1 - SEGMENT_L1 - addr_changes; + same_addr * ('addr - addr) === 0; col witness same_value; - (1 - same_value) * (1 - wr) * (1 - addr_changes) === 0; + same_value * (1 - same_value) === 0; + (1 - same_value) * (1 - wr) * same_addr === 0; col witness first_addr_access_is_read; - (1 - first_addr_access_is_read) * rd * (1 - addr_changes) === 0; + first_addr_access_is_read * (1 - first_addr_access_is_read) === 0; + (1 - first_addr_access_is_read) * rd * addr_changes === 0; - for (int index = 0; index < length(value); index = index + 1) { + for (int index = 0; index < length(value); index++) { same_value * (value[index] - 'value[index]) === 0; first_addr_access_is_read * value[index] === 0; } @@ -85,22 +92,26 @@ airtemplate Mem (int N = 2**21, int RC = 2, int id = 
MEMORY_ID, int MAX_STEP = 2 // permutation_proves(MEMORY_CONT_ID, [(mem_segment + 1), addr, step, ...value], sel: mem_last_segment * 'SEGMENT_L1); // last row // permutation_assumes(MEMORY_CONT_ID, [mem_segment, 0, addr, step, ...value], sel: SEGMENT_L1); // first row - permutation_proves(MEMORY_ID, cols: [wr, addr * MEM_BYTES, step, MEM_BYTES, ...value], sel: sel); + // The Memory component is only able to prove aligned memory access, since we force the bus address to be a multiple of MEM_BYTES + // and the width to be exactly MEM_BYTES + // Notice, however, that the main can also use widths of 4, 2, 1 and addresses that are not multiples of MEM_BYTES. + // These are handled with the Memory Align component + permutation_proves(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * MEM_BYTES, step, MEM_BYTES, ...value], sel: sel); } -// TODO: detect non default value but not called, mandatory parameter. -function mem_load(int id = MEMORY_ID, expr sel = 1, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { - if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); - } - // adding one for first continuation - permutation_assumes(id, [MEMORY_LOAD_OP, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step) + step_offset, bytes, ...value], sel:sel); +function mem_load(int id = MEMORY_ID, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[], expr sel = 1) { + mem_assumes(id, MEMORY_LOAD_OP, addr, step, step_offset, bytes, value, sel); } -function mem_store(int id = MEMORY_ID, expr sel = 1, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { +function mem_store(int id = MEMORY_ID, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[], expr sel = 1) { + mem_assumes(id, MEMORY_STORE_OP, addr, step, step_offset, bytes, value, sel); +} + +private function mem_assumes(int id, int mem_op, expr addr, 
expr step, expr step_offset, expr bytes, expr value[], expr sel) { if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); + error("step_offset ${step_offset} is greater than max value allowed ${MAX_MEM_STEP_OFFSET}"); } - // adding one for first continuation - permutation_assumes(id, [MEMORY_STORE_OP, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step), bytes, ...value], sel:sel); + + // adding 1 at step for first continuation + permutation_assumes(id, [mem_op, addr, 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset, bytes, ...value], sel: sel); } \ No newline at end of file diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index e69de29b..8a23ab2a 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -0,0 +1,188 @@ +require "std_permutation.pil" +require "std_lookup.pil" +require "std_range_check.pil" + +// Problem to solve: +// ================= +// We are given an op (rd,wr), an addr, a step and a bytes-width (8,4,2,1) and we should prove that the memory access is correct. +// Note: Either the original addr is not a multiple of 8 or width < 8 to ensure it is a non-aligned access that should be +// handled by this component. + +/* + We will model it as a very specified processor with 8 registers and a very limited instruction set. 
+ + This processor is limited to 4 possible subprograms: + + 1] Read operation that spans one memory word w = [w_0, w_1]: + w_0 w_1 + +---+===+===+===+ +===+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+===+===+===+ +===+---+---+---+ + |<------ v ------>| + + [R] In the first clock cycle, we perform an aligned read to w + [V] In the second clock cycle, we return the demanded value v from w + + 2] Write operation that spans one memory word w = [w_0, w_1]: + w_0 w_1 + +---+---+---+---+ +---+===+===+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+---+---+---+ +---+===+===+---+ + |<- v ->| + + [R] In the first clock cycle, we perform an aligned read to w + [W] In the second clock cycle, we compute an aligned write of v to w + [V] In the third clock cycle, we restore the demanded value from w + + 3] Read operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: + w1_0 w1_1 w2_0 w2_1 + +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ + |<---------------- v ---------------->| + + [R] In the first clock cycle, we perform an aligned read to w1 + [V] In the second clock cycle, we return the demanded value v from w1 and w2 + [R] In the third clock cycle, we perform an aligned read to w2 + + 4] Write operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: + w1_0 w1_1 w2_0 w2_1 + +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ + |<---------------- v ---------------->| + + [R] In the first clock cycle, we perform an aligned read to w1 + [W] In the second clock cycle, we compute an aligned write of v to w1 + [V] In the third clock cycle, we restore the demanded value from w1 and w2 + 
[R] In the fourth clock cycle, we perform an aligned read to w2 + [W] In the fiveth clock cycle, we compute an aligned write of v to w2 + + Example: + ========================================================== + (offset = 6, width = 4) + +----+----+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | [R1] (assume, up_to_down) sel = [1,1,1,1,1,1,0,0] + +----+----+----+----+----+----+----+----+ + ⇓ + +----+----+----+----+----+----+====+====+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | [W1] (assume, up_to_down) sel = [0,0,0,0,0,0,1,1] + +----+----+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+====+====+ + | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | [V] (prove) (shift (offset + width) % 8) sel = [0,0,0,0,0,0,1,0] (*) + +====+====+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+----+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | [W2] (assume, down_to_up) sel = [1,1,0,0,0,0,0,0] + +====+====+----+----+----+----+----+----+ + ⇓ + +----+----+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | [R2] (assume, down_to_up) sel = [0,0,1,1,1,1,1,1] + +----+----+----+----+----+----+----+----+ + + (*) In this step, we use the selectors to indicate the "scanning" needed to form the bus value: + v_0 = sel[0] * [V1,V0,V7,V6] + sel[1] * [V0,V7,V6,V5] + sel[2] * [V7,V6,V5,V4] + sel[3] * [V6,V5,V4,V3] + v_1 = sel[4] * [V5,V4,V3,V2] + sel[5] * [V4,V3,V2,V1] + sel[6] * [V3,V2,V1,V0] + sel[7] * [V2,V1,V0,V7] + Notice that it is enough with 8 combinations. 
+*/ + +airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM = 8, const int CHUNK_BITS = 8) { + const int CHUNKS_BY_RC = CHUNK_NUM / RC; + + col witness addr; // CHUNK_NUM-byte address, real address = addr * CHUNK_NUM + col witness offset; // 0..7, position at which the operation starts + col witness width; // 1,2,4,8, width of the operation + col witness wr; // 1 if the operation is a write, 0 otherwise + col witness pc; // line of the program to execute + col witness reset; // 1 at the beginning of the operation (indicating an address reset), 0 otherwise + col witness sel_up_to_down; // 1 if the next value is the current value (e.g. R -> W) + col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R) + col witness reg[CHUNK_NUM]; // Register values, 1 byte each + col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise + col witness step; // Memory step + + // 1] Ensure the MemAlign follows the program + + // Registers should be bytes and be such that: + // - reg' == reg in transitions R -> V, R -> W, W -> V, + // - 'reg == reg in transitions V <- W, W <- R, + // in any case, sel_up_to_down,sel_down_to_up are 0 in [V] steps.
+ for (int i = 0; i < CHUNK_NUM; i++) { + range_check(reg[i], 0, 2**CHUNK_BITS-1); + + (reg[i]' - reg[i]) * sel[i] * sel_up_to_down === 0; + ('reg[i] - reg[i]) * sel[i] * sel_down_to_up === 0; + } + + col fixed L1 = [1,0...]; + L1 * pc === 0; // The program should start at the first line + + // We compress selectors, so we should ensure they are binary + for (int i = 0; i < CHUNK_NUM; i++) { + sel[i] * (1 - sel[i]) === 0; + } + wr * (1 - wr) === 0; + reset * (1 - reset) === 0; + sel_up_to_down * (1 - sel_up_to_down) === 0; + sel_down_to_up * (1 - sel_down_to_up) === 0; + + expr flags = 0; + for (int i = 0; i < CHUNK_NUM; i++) { + flags += sel[i] * 2**i; + } + flags += wr * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); + + // Perform the lookup against the program + expr delta_pc; + col witness delta_addr; // Auxiliary column + delta_pc = pc' - pc; + delta_addr === (addr - 'addr) * (1 - reset); + lookup_assumes(MEM_ALIGN_ROM_ID, [pc, delta_pc, delta_addr, offset, width, flags]); + + // 2] Assume aligned memory accesses against the Memory component + const expr sel_assume = sel_up_to_down + sel_down_to_up; + + // Offset should be 0 in aligned memory accesses, but this is ensured by the rom + // Width should be 8 in aligned memory accesses, but this is ensured by the rom + + // On assume steps, we reconstruct the value from the registers directly + expr assume_val[RC]; + for (int rc_index = 0; rc_index < RC; rc_index++) { + assume_val[rc_index] = 0; + int base = 1; + for (int _offset = 0; _offset < CHUNKS_BY_RC; _offset++) { + assume_val[rc_index] += reg[_offset + rc_index * CHUNKS_BY_RC] * base; + base *= 256; + } + } + + // 3] Prove unaligned memory accesses against the Main component + col witness sel_prove; + + sel_prove * sel_assume === 0; // Disjoint selectors + + // On prove steps, we reconstruct the value in the correct manner chosen by the selectors + expr prove_val[RC]; + for (int 
rc_index = 0; rc_index < RC; rc_index++) { + prove_val[rc_index] = 0; + } + for (int _offset = 0; _offset < CHUNK_NUM; _offset++) { + for (int rc_index = 0; rc_index < RC; rc_index++) { + expr _tmp = 0; + int base = 1; + for (int ichunk = 0; ichunk < CHUNKS_BY_RC; ichunk++) { + _tmp += reg[(_offset + rc_index * CHUNKS_BY_RC + ichunk) % CHUNK_NUM] * base; + base *= 256; + } + prove_val[rc_index] += sel[_offset] * _tmp; + } + } + + // We prove and assume with the same permutation check but with disjoint and different sign selectors + col witness value[RC]; // Auxiliary columns + for (int i = 0; i < RC; i++) { + value[i] === sel_prove * prove_val[i] + sel_assume * assume_val[i]; + } + permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...value], sel: sel_prove - sel_assume); +} \ No newline at end of file diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil new file mode 100644 index 00000000..db0d4440 --- /dev/null +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -0,0 +1,324 @@ +require "std_lookup.pil" +require "constants.pil" + +const int MEM_ALIGN_ROM_ID = 133; +const int MEM_ALIGN_ROM_SIZE = P2_8; + +airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = 8, const int DEFAULT_OFFSET = 0, const int DEFAULT_WIDTH = 8, const int disable_fixed = 0) { + if (N < MEM_ALIGN_ROM_SIZE) { + error(`N must be at least ${MEM_ALIGN_ROM_SIZE}, but N=${N} was provided`); + } + + col witness multiplicity; + + if (disable_fixed) { + col fixed _K = [0...]; + multiplicity * _K === 0; + + println("*** DISABLE_FIXED ***"); + return; + } + + // Define the size of each sub-program: RV, RWV, RVR, RWVWR + const int spsize[4] = [2, 3, 3, 5]; + + // Not all combinations of offset and width are valid for each program: + const int one_word_combinations = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 + const int 
two_word_combinations = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 + + // table_size = combinations * program_size + const int tsize[4] = [one_word_combinations*spsize[0], one_word_combinations*spsize[1], two_word_combinations*spsize[2], two_word_combinations*spsize[3]]; + const int psize = tsize[0] + tsize[1] + tsize[2] + tsize[3]; + + // Offset is set to DEFAULT_OFFSET and width to DEFAULT_WIDTH in aligned memory accesses. + // Offset and width are set to 0 in padding lines. + // size + col fixed OFFSET = [0, // Padding 1 = 1 | 1 + [[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 41 + [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 101 + [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 134 + [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3], // RWVWR 5*4+10+10+15 = 55 | 189 => N = 2^8 + 0...]; // Padding + + col fixed WIDTH = [0, // Padding + [[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV + [[8,8,1,8,8,2,8,8,4]:5, [8,8,1,8,8,2]:2, [8,8,1]], // RWV + [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR + [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]], // RWVWR + 0...]; // Padding + + // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | + // 0 | 0 | 0 | 1 | 0 | 0 | // for padding + // 1 | 0 | 1 | 1 | X1 | 0 | // (RV) + // 2 | 1 | -1 | 0 | X1 | 0 | + // 3 | 0 | 3 | 1 | X2 | 0 | // (RV) + // 4 | 3 | -3 | 0 | X2 | 0 | + // 5 | 0 | 5 | 1 | X3 | 0 | // (RV) + // 6 | 5 | -5 | 0 | X3 | 0 | + // 7 | 0 | 7 | 1 | ⋮ | ⋮ | // (RV) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 41 | 0 | 41 | 1 | X4 | 0 | // (RWV) + // 42 | 41 | 1 | 0 | X4 | 0 | + // 43 | 42 | -42 | 0 | X4 | 0 | + // 44 | 0 | 44 | 1 | X5 | 0 | // (RWV) + // 45 | 44 | 1 | 0 | X5 | 0 | + // 46 | 45 | -45 | 0 | X5 | 0 | + // 47 | 0 | 47 | 1 | X6 | 0 | // 
(RWV) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 101 | 0 | 101 | 1 | X7 | 0 | // (RVR) + // 102 |101 | 1 | 0 | X7 | 0 | + // 103 |102 | -102 | 0 | X7+1 | 1 | + // 104 | 0 | 104 | 1 | X8 | 0 | // (RVR) + // 105 |104 | 1 | 0 | X8 | 0 | + // 106 |105 | -105 | 0 | X8+1 | 1 | + // 107 | 0 | 107 | 1 | X9 | 0 | // (RVR) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 134 | 0 | 134 | 1 | X10 | 0 | // (RWVWR) + // 135 |134 | 1 | 0 | X10 | 0 | + // 136 |135 | 1 | 0 | X10 | 0 | + // 137 |136 | 1 | 0 | X10+1 | 1 | + // 138 |137 | -137 | 0 | X10+1 | 0 | + // 139 | 0 | 139 | 1 | X11 | 0 | // (RWVWR) + // 140 |139 | 1 | 0 | X11 | 0 | + // 141 |140 | 1 | 0 | X11 | 0 | + // 142 |141 | 1 | 0 | X11+1 | 1 | + // 143 |142 | -142 | 0 | X11+1 | 0 | + // 144 | 0 | 144 | 1 | X12 | 0 | // (RWVWR) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 188 |187 | -187 | 0 | X13+1 | 0 | + // 189 | 0 | 0 | 1 | 0 | 0 | // for padding + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + + // Note: The overall program contains "holes", meaning that pc can vary + // from program to program by any constant, as long as it is unique for each program. + // For example, the first program has pc=0,1, while the second has pc=0,3. + + col fixed PC; + col fixed DELTA_PC; + col fixed DELTA_ADDR; + col fixed FLAGS; + for (int i = 0; i < N; i++) { + int pc = 0; + int delta_pc = 0; + int delta_addr = 0; + int is_write = 0; + int reset = 0; + int sel[CHUNK_NUM]; + for (int j = 0; j < CHUNK_NUM; j++) { + sel[j] = 0; + } + int sel_up_to_down = 0; + int sel_down_to_up = 0; + + const int prev_line = i == 0 ? 
0 : i-1; + const int line = i; + if (line == 0 || line > psize) + { + // pc = 0; + // delta_pc = 0; + // delta_addr = 0; + // is_write = 0; + reset = 1; + // sel = [0:CHUNK_NUM] + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + else if (line < 1+tsize[0]) // RV + { + if (line % 2 == 1) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1] && j < OFFSET[i+1] + WIDTH[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + } + else if (line < 1+tsize[0]+tsize[1]) // RWV + { + if (line % 3 == 2) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < OFFSET[i+2] || j >= OFFSET[i+2] + WIDTH[i+2]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 3 == 0) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1] && j < OFFSET[i+1] + WIDTH[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + } + else if (line < 1+tsize[0]+tsize[1]+tsize[2]) // RVR + { + if (line % 3 == 2) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 3 == 0) { + pc = prev_line; + 
delta_pc = 1; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } else { + pc = prev_line; + delta_pc = -pc; + delta_addr = 1; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < (OFFSET[i-1] + WIDTH[i-1]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } + } + else if (line < 1+tsize[0]+tsize[1]+tsize[2]+tsize[3]) // RWVWR + { + if (line % 5 == 4) { + // pc = 0; + delta_pc = line; + // delta_addr = 0; + // is_write = 0; + reset = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < OFFSET[i+2]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 5 == 0) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1]) { + sel[j] = 1; + } + } + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 5 == 1) { + pc = prev_line; + delta_pc = 1; + // delta_addr = 0; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } else if (line % 5 == 2) { + pc = prev_line; + delta_pc = 1; + delta_addr = 1; + is_write = 1; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < (OFFSET[i-1] + WIDTH[i-1]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } else { + pc = prev_line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= (OFFSET[i-2] + WIDTH[i-2]) % CHUNK_NUM) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; + sel_down_to_up = 1; + } + } + PC[i] = pc; + DELTA_PC[i] = delta_pc; + DELTA_ADDR[i] = delta_addr; + int flags = 0; + for (int j = 0; j < CHUNK_NUM; j++) { + flags += sel[j] * 2**j; + } + 
flags += is_write * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); + FLAGS[i] = flags; + } + + // Ensure the program is being followed by the MemAlign + lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); +} \ No newline at end of file diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 67bf225c..6e04d6e9 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,9 +1,17 @@ -mod mem; -mod mem_aligned; -mod mem_traces; -mod mem_unaligned; +mod mem_align_rom_sm; +mod mem_align_sm; +mod mem_constants; +mod mem_helpers; +mod mem_proxy; +mod mem_proxy_engine; +mod mem_sm; +mod mem_unmapped; -pub use mem::*; -pub use mem_aligned::*; -pub use mem_traces::*; -pub use mem_unaligned::*; +pub use mem_align_rom_sm::*; +pub use mem_align_sm::*; +pub use mem_constants::*; +pub use mem_helpers::*; +pub use mem_proxy::*; +pub use mem_proxy_engine::*; +pub use mem_sm::*; +pub use mem_unmapped::*; diff --git a/state-machines/mem/src/mem.rs b/state-machines/mem/src/mem.rs deleted file mode 100644 index 391bca7b..00000000 --- a/state-machines/mem/src/mem.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use crate::{MemAlignedSM, MemUnalignedSM}; -use p3_field::Field; -use rayon::Scope; -use sm_common::{MemOp, MemUnalignedOp, OpResult, Provable}; -use zisk_core::ZiskRequiredMemory; - -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; - -#[allow(dead_code)] -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -#[allow(dead_code)] -pub struct MemSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs_aligned: Mutex>, - inputs_unaligned: Mutex>, - - // Secondary State machines - mem_aligned_sm: Arc, - mem_unaligned_sm: Arc, -} - -impl MemSM { - pub fn new(wcm: Arc>) -> 
Arc { - let mem_aligned_sm = MemAlignedSM::new(wcm.clone()); - let mem_unaligned_sm = MemUnalignedSM::new(wcm.clone()); - - let mem_sm = Self { - registered_predecessors: AtomicU32::new(0), - inputs_aligned: Mutex::new(Vec::new()), - inputs_unaligned: Mutex::new(Vec::new()), - mem_aligned_sm: mem_aligned_sm.clone(), - mem_unaligned_sm: mem_unaligned_sm.clone(), - }; - let mem_sm = Arc::new(mem_sm); - - wcm.register_component(mem_sm.clone(), None, None); - - // For all the secondary state machines, register the main state machine as a predecessor - mem_sm.mem_aligned_sm.register_predecessor(); - mem_sm.mem_unaligned_sm.register_predecessor(); - - mem_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - - self.mem_aligned_sm.unregister_predecessor::(scope); - self.mem_unaligned_sm.unregister_predecessor::(scope); - } - } -} - -impl WitnessComponent for MemSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, - ) { - } -} - -impl Provable for MemSM { - fn calculate( - &self, - _operation: ZiskRequiredMemory, - ) -> Result> { - unimplemented!() - } - - fn prove(&self, _operations: &[ZiskRequiredMemory], _drain: bool, _scope: &Scope) { - // TODO! 
- } - - fn calculate_prove( - &self, - _operation: ZiskRequiredMemory, - _drain: bool, - _scope: &Scope, - ) -> Result> { - unimplemented!() - } -} diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs new file mode 100644 index 00000000..df6081e9 --- /dev/null +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -0,0 +1,220 @@ +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, + }, +}; + +use log::info; +use p3_field::PrimeField; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; + +use sm_common::create_prover_buffer; +use zisk_pil::{MemAlignRomRow, MemAlignRomTrace, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; + +#[derive(Debug, Clone, Copy)] +pub enum MemOp { + OneRead, + OneWrite, + TwoReads, + TwoWrites, +} + +const OP_SIZES: [u64; 4] = [2, 3, 3, 5]; +const ONE_WORD_COMBINATIONS: u64 = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 +const TWO_WORD_COMBINATIONS: u64 = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 + +pub struct MemAlignRomSM { + // Witness computation manager + wcm: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Rom data + num_rows: usize, + multiplicity: Mutex>, // row_num -> multiplicity +} + +#[derive(Debug)] +pub enum ExtensionTableSMErr { + InvalidOpcode, +} + +impl MemAlignRomSM { + const MY_NAME: &'static str = "MemAlignRom"; + + pub fn new(wcm: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + let num_rows = air.num_rows(); + + let mem_align_rom = Self { + wcm: wcm.clone(), + registered_predecessors: AtomicU32::new(0), + num_rows, + multiplicity: Mutex::new(HashMap::with_capacity(num_rows)), + }; + let mem_align_rom = Arc::new(mem_align_rom); + wcm.register_component( + mem_align_rom.clone(), + Some(ZISK_AIRGROUP_ID), + Some(MEM_ALIGN_ROM_AIR_IDS), + ); + + 
mem_align_rom + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.create_air_instance(); + } + } + + pub fn calculate_next_pc(&self, opcode: MemOp, offset: usize, width: usize) -> u64 { + // Get the table offset + let (table_offset, one_word) = match opcode { + MemOp::OneRead => (1, true), + + MemOp::OneWrite => (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true), + + MemOp::TwoReads => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], + false, + ), + + MemOp::TwoWrites => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] + + TWO_WORD_COMBINATIONS * OP_SIZES[2], + false, + ), + }; + + // Get the first row index + let first_row_idx = Self::get_first_row_idx(opcode, offset, width, table_offset, one_word); + + // Based on the program size, return the row indices + let opcode_idx = opcode as usize; + let op_size = OP_SIZES[opcode_idx]; + for i in 0..op_size { + let row_idx = first_row_idx + i; + // Check whether the row index is within the bounds + debug_assert!(row_idx < self.num_rows as u64); + // Update the multiplicity + self.update_multiplicity_by_row_idx(row_idx, 1); + } + + first_row_idx + } + + fn get_first_row_idx( + opcode: MemOp, + offset: usize, + width: usize, + table_offset: u64, + one_word: bool, + ) -> u64 { + let opcode_idx = opcode as usize; + let op_size = OP_SIZES[opcode_idx]; + + // Go to the actual operation + let mut first_row_idx = table_offset; + + // Go to the actual offset + let first_valid_offset = if one_word { 0 } else { 1 }; + for i in first_valid_offset..offset { + let possible_widths = Self::calculate_possible_widths(one_word, i); + first_row_idx += op_size * possible_widths.len() as u64; + } + + // Go to the right width + let width_idx = Self::calculate_possible_widths(one_word, offset) + 
.iter() + .position(|&w| w == width) + .expect("Invalid width"); + first_row_idx += op_size * width_idx as u64; + + first_row_idx + } + + fn calculate_possible_widths(one_word: bool, offset: usize) -> Vec { + // Calculate the ROM rows based on the requested opcode, offset, and width + match one_word { + true => match offset { + x if x <= 4 => vec![1, 2, 4], + x if x <= 6 => vec![1, 2], + x if x == 7 => vec![1], + _ => panic!("Invalid offset={}", offset), + }, + false => match offset { + x if x == 0 => panic!("Invalid offset={}", offset), + x if x <= 4 => vec![8], + x if x <= 6 => vec![4, 8], + x if x == 7 => vec![2, 4, 8], + _ => panic!("Invalid offset={}", offset), + }, + } + } + + pub fn update_padding_row(&self, padding_len: u64) { + // Update entry at the padding row (pos = 0) with the given padding length + self.update_multiplicity_by_row_idx(0, padding_len); + } + + pub fn update_multiplicity_by_row_idx(&self, row_idx: u64, mul: u64) { + let mut multiplicity = self.multiplicity.lock().unwrap(); + *multiplicity.entry(row_idx).or_insert(0) += mul; + } + + pub fn create_air_instance(&self) { + // Get the contexts + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + // Get the Mem Align ROM AIR + let air_mem_align_rom = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + let air_mem_align_rom_rows = air_mem_align_rom.num_rows(); + + // Create a prover buffer + let (mut prover_buffer, offset) = + create_prover_buffer(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + + // Create the Mem Align ROM trace buffer + let mut trace_buffer = MemAlignRomTrace::::map_buffer( + &mut prover_buffer, + air_mem_align_rom_rows, + offset as usize, + ) + .unwrap(); + + // Initialize the trace buffer to zero + for i in 0..air_mem_align_rom_rows { + trace_buffer[i] = MemAlignRomRow { multiplicity: F::zero() }; + } + + // Fill the trace buffer with the multiplicity values + if let 
Ok(multiplicity) = self.multiplicity.lock() { + for (row_idx, multiplicity) in multiplicity.iter() { + trace_buffer[*row_idx as usize] = + MemAlignRomRow { multiplicity: F::from_canonical_u64(*multiplicity) }; + } + } + + info!("{}: ··· Creating Mem Align Rom instance", Self::MY_NAME,); + + let air_instance = + AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0], None, prover_buffer); + pctx.air_instance_repo.add_air_instance(air_instance, None); + } +} + +impl WitnessComponent for MemAlignRomSM {} diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs new file mode 100644 index 00000000..40e1e148 --- /dev/null +++ b/state-machines/mem/src/mem_align_sm.rs @@ -0,0 +1,1034 @@ +use core::panic; +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use log::info; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; + +use sm_common::create_prover_buffer; +use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; + +use crate::{MemAlignInput, MemAlignRomSM, MemOp}; + +const RC: usize = 2; +const CHUNK_NUM: usize = 8; +const CHUNKS_BY_RC: usize = CHUNK_NUM / RC; +const CHUNK_BITS: usize = 8; +const RC_BITS: u64 = (CHUNKS_BY_RC * CHUNK_BITS) as u64; +const RC_MASK: u64 = (1 << RC_BITS) - 1; +const OFFSET_MASK: u32 = 0x07; +const OFFSET_BITS: u32 = 3; +const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; + +const fn generate_allowed_offsets() -> [u8; CHUNK_NUM] { + let mut offsets = [0; CHUNK_NUM]; + let mut i = 0; + while i < CHUNK_NUM { + offsets[i] = i as u8; + i += 1; + } + offsets +} + +const ALLOWED_OFFSETS: [u8; CHUNK_NUM] = generate_allowed_offsets(); +const ALLOWED_WIDTHS: [u8; 4] = [1, 2, 4, 8]; +const DEFAULT_OFFSET: u64 = 0; +const DEFAULT_WIDTH: u64 = 8; + +pub struct MemAlignResponse { + pub more_address: bool, + pub step: u64, 
+ pub value: Option, +} +pub struct MemAlignSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Computed row information + rows: Mutex>>, + #[cfg(feature = "debug_mem_align")] + num_computed_rows: Mutex, + + // Secondary State machines + mem_align_rom_sm: Arc>, +} + +macro_rules! debug_info { + ($prefix:expr, $($arg:tt)*) => { + #[cfg(feature = "debug_mem_align")] + { + info!(concat!("MemAlign: ",$prefix), $($arg)*); + } + }; +} + +impl MemAlignSM { + const MY_NAME: &'static str = "MemAlign"; + + pub fn new( + wcm: Arc>, + std: Arc>, + mem_align_rom_sm: Arc>, + ) -> Arc { + let mem_align_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + registered_predecessors: AtomicU32::new(0), + rows: Mutex::new(Vec::new()), + #[cfg(feature = "debug_mem_align")] + num_computed_rows: Mutex::new(0), + mem_align_rom_sm, + }; + let mem_align_sm = Arc::new(mem_align_sm); + + wcm.register_component( + mem_align_sm.clone(), + Some(ZISK_AIRGROUP_ID), + Some(MEM_ALIGN_AIR_IDS), + ); + + // Register the predecessors + std.register_predecessor(); + mem_align_sm.mem_align_rom_sm.register_predecessor(); + + mem_align_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + + // If there are remaining rows, generate the last instance + if let Ok(mut rows) = self.rows.lock() { + // Get the Mem Align AIR + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + let rows_len = rows.len(); + debug_assert!(rows_len <= air_mem_align.num_rows()); + + let drained_rows = rows.drain(..rows_len).collect::>(); + + self.fill_new_air_instance(&drained_rows); + } + + self.mem_align_rom_sm.unregister_predecessor(); + self.std.unregister_predecessor(pctx, None); + } 
+ } + + #[inline(always)] + pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse { + let addr = input.address; + let width = input.width; + + // Compute the width + debug_assert!( + ALLOWED_WIDTHS.contains(&width), + "Width={} is not allowed. Allowed widths are {:?}", + width, + ALLOWED_WIDTHS + ); + let width = width as usize; + + // Compute the offset + let offset = (addr & OFFSET_MASK) as u8; + debug_assert!( + ALLOWED_OFFSETS.contains(&offset), + "Offset={} is not allowed. Allowed offsets are {:?}", + offset, + ALLOWED_OFFSETS + ); + let offset = offset as usize; + + #[cfg(feature = "debug_mem_align")] + let num_rows = self.num_computed_rows.lock().unwrap(); + match (input.is_write, offset + width > CHUNK_NUM) { + (false, false) => { + /* RV with offset=2, width=4 + +----+----+====+====+====+====+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+====+====+====+====+----+----+ + ⇓ + +----+----+====+====+====+====+----+----+ + | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | + +----+----+====+====+====+====+----+----+ + */ + debug_assert!(phase == 0); + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Get the aligned address + let addr_read = addr >> OFFSET_BITS; + + // Get the aligned value + let value_read = input.mem_values[phase]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::OneRead, offset, width); + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + 
// delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); + if i >= offset && i < offset + width { + read_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_read = value_read; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nOne Word Read\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value Read: {:?}\n\ + Value: {:?}\n\ + Flags Read: {:?}\n\ + Flags Value: {:?}", + [*num_rows, *num_rows + 1], + input, + phase, + value_read.to_le_bytes(), + value.to_le_bytes(), + [ + read_row.sel[0], read_row.sel[1], read_row.sel[2], read_row.sel[3], + read_row.sel[4], read_row.sel[5], read_row.sel[6], read_row.sel[7], + read_row.wr, read_row.reset, read_row.sel_up_to_down, read_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[read_row, value_row]); + + MemAlignResponse { more_address: false, step, value: None } + } + (true, false) => { + /* RWV with offset=3, width=4 + +----+----+----+====+====+====+====+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + 
+----+----+----+====+====+====+====+----+ + ⇓ + +----+----+----+====+====+====+====+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +----+----+----+====+====+====+====+----+ + ⇓ + +----+----+----+====+====+====+====+----+ + | V5 | V6 | V7 | V0 | V1 | V2 | V3 | V4 | + +----+----+----+====+====+====+====+----+ + */ + debug_assert!(phase == 0); + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Get the aligned address + let addr_read = addr >> OFFSET_BITS; + + // Get the aligned value + let value_read = input.mem_values[phase]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::OneWrite, offset, width); + + // Compute the write value + let value_write = { + // with:1 offset:4 + let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_read & !mask) | value_to_write + }; + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() 
+ }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); + if i < offset || i >= offset + width { + read_row.sel[i] = F::from_bool(true); + } + + write_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_write, i, 0)); + if i >= offset && i < offset + width { + write_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = { + if i >= offset && i < offset + width { + write_row.reg[i] + } else { + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)) + } + }; + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_read = value_read; + let mut _value_write = value_write; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + write_row.value[i] = F::from_canonical_u64(_value_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value_write >>= RC_BITS; + _value >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nOne Word Write\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value Read: {:?}\n\ + Value Write: {:?}\n\ + Value: {:?}\n\ + Flags Read: {:?}\n\ + Flags Write: {:?}\n\ + Flags Value: {:?}", + [*num_rows, *num_rows + 2], + input, + phase, + value_read.to_le_bytes(), + value_write.to_le_bytes(), + value.to_le_bytes(), + [ + read_row.sel[0], read_row.sel[1], read_row.sel[2], read_row.sel[3], + read_row.sel[4], read_row.sel[5], read_row.sel[6], read_row.sel[7], + read_row.wr, read_row.reset, read_row.sel_up_to_down, read_row.sel_down_to_up + 
], + [ + write_row.sel[0], write_row.sel[1], write_row.sel[2], write_row.sel[3], + write_row.sel[4], write_row.sel[5], write_row.sel[6], write_row.sel[7], + write_row.wr, write_row.reset, write_row.sel_up_to_down, write_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[read_row, write_row, value_row]); + + MemAlignResponse { more_address: false, step, value: Some(value_write) } + } + (false, true) => { + /* RVR with offset=5, width=8 + +----+----+----+----+----+====+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+====+====+====+ + ⇓ + +====+====+====+====+====+====+====+====+ + | V3 | V4 | V5 | V6 | V7 | V0 | V1 | V2 | + +====+====+====+====+====+====+====+====+ + ⇓ + +====+====+====+====+====+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+====+====+====+----+----+----+ + */ + debug_assert!(phase == 0 || phase == 1); + + match phase { + // If phase == 0, do nothing, just ask for more + 0 => MemAlignResponse { more_address: true, step: input.step, value: None }, + + // Otherwise, do the RVR + 1 => { + debug_assert!(input.mem_values.len() == 2); + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the remaining bytes + let rem_bytes = (offset + width) % CHUNK_NUM; + + // Get the aligned address + let addr_first_read = addr >> OFFSET_BITS; + let addr_second_read = addr_first_read + 1; + + // Get the aligned value + let value_first_read = input.mem_values[0]; + let value_second_read = input.mem_values[1]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::TwoReads, offset, width); + 
+ let mut first_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_second_read), + delta_addr: F::one(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i >= offset { + first_read_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i < rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_first_read = value_first_read; + let mut _value = value; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + 
value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nTwo Words Read\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value First Read: {:?}\n\ + Value: {:?}\n\ + Value Second Read: {:?}\n\ + Flags First Read: {:?}\n\ + Flags Value: {:?}\n\ + Flags Second Read: {:?}", + [*num_rows, *num_rows + 2], + input, + phase, + value_first_read.to_le_bytes(), + value.to_le_bytes(), + value_second_read.to_le_bytes(), + [ + first_read_row.sel[0], first_read_row.sel[1], first_read_row.sel[2], first_read_row.sel[3], + first_read_row.sel[4], first_read_row.sel[5], first_read_row.sel[6], first_read_row.sel[7], + first_read_row.wr, first_read_row.reset, first_read_row.sel_up_to_down, first_read_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ], + [ + second_read_row.sel[0], second_read_row.sel[1], second_read_row.sel[2], second_read_row.sel[3], + second_read_row.sel[4], second_read_row.sel[5], second_read_row.sel[6], second_read_row.sel[7], + second_read_row.wr, second_read_row.reset, second_read_row.sel_up_to_down, second_read_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[first_read_row, value_row, second_read_row]); + + MemAlignResponse { more_address: false, step, value: None } + } + _ => panic!("Invalid phase={}", phase), + } + } + (true, true) => { + /* RWVWR with offset=6, width=4 + +----+----+----+----+----+----+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+----+====+====+ + ⇓ + 
+----+----+----+----+----+----+====+====+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +----+----+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+====+====+ + | V2 | V3 | V4 | V5 | V6 | V7 | V0 | V1 | + +====+====+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+----+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +====+====+----+----+----+----+----+----+ + ⇓ + +====+====+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+----+----+----+----+----+----+ + */ + debug_assert!(phase == 0 || phase == 1); + + match phase { + // If phase == 0, compute the resulting write value and ask for more + 0 => { + debug_assert!(input.mem_values.len() == 1); + + // Unaligned memory op information thrown into the bus + let value = input.value; + let step = input.step; + + // Get the aligned value + let value_first_read = input.mem_values[0]; + + // Compute the write value + let value_first_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; + + MemAlignResponse { + more_address: true, + step, + value: Some(value_first_write), + } + } + // Otherwise, do the RWVRW + 1 => { + debug_assert!(input.mem_values.len() == 2); + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the shift + let rem_bytes = (offset + width) % CHUNK_NUM; + + // Get the aligned address + let addr_first_read_write = addr >> OFFSET_BITS; + let addr_second_read_write = addr_first_read_write + 1; + + // Get the first aligned value + let 
value_first_read = input.mem_values[0]; + + // Recompute the first write value + let value_first_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; + + // Get the second aligned value + let value_second_read = input.mem_values[1]; + + // Compute the second write value + let value_second_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let mask: u64 = (1 << (rem_bytes * CHUNK_BITS)) - 1; + + // Get the first width bytes of the unaligned value + let value_to_write = (value >> width_norm * CHUNK_BITS) & mask; + + // Write zeroes to value_read from 0 to offset + width + // and add the value to write to the value read + (value_second_read & !mask) | value_to_write + }; + + // Get the next pc + let next_pc = self.mem_align_rom_sm.calculate_next_pc( + MemOp::TwoWrites, + offset, + width, + ); + + // RWVWR + let mut first_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut first_write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: 
F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_second_read_write), + delta_addr: F::one(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 2), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_second_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 3), + reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i < offset { + first_read_row.sel[i] = F::from_bool(true); + } + + first_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_write, i, 0)); + if i >= offset { + first_write_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = { + if i < rem_bytes { + second_write_row.reg[i] + } else if i >= offset { + first_write_row.reg[i] + } else { + F::from_canonical_u64(Self::get_byte( + value, + i, + CHUNK_NUM - offset, + )) + } + }; + if i == offset { + 
value_row.sel[i] = F::from_bool(true); + } + + second_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); + if i < rem_bytes { + second_write_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i >= rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_first_read = value_first_read; + let mut _value_first_write = value_first_write; + let mut _value = value; + let mut _value_second_write = value_second_write; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + first_write_row.value[i] = + F::from_canonical_u64(_value_first_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_write_row.value[i] = + F::from_canonical_u64(_value_second_write & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value_first_write >>= RC_BITS; + _value >>= RC_BITS; + _value_second_write >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + + #[rustfmt::skip] + debug_info!( + "\nTwo Words Write\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Phase: {:?}\n\ + Value First Read: {:?}\n\ + Value First Write: {:?}\n\ + Value: {:?}\n\ + Value Second Read: {:?}\n\ + Value Second Write: {:?}\n\ + Flags First Read: {:?}\n\ + Flags First Write: {:?}\n\ + Flags Value: {:?}\n\ + Flags Second Write: {:?}\n\ + Flags Second Read: {:?}", + [*num_rows, *num_rows + 4], + input, + phase, + value_first_read.to_le_bytes(), + value_first_write.to_le_bytes(), + value.to_le_bytes(), + value_second_write.to_le_bytes(), + value_second_read.to_le_bytes(), + [ + first_read_row.sel[0], first_read_row.sel[1], first_read_row.sel[2], first_read_row.sel[3], + first_read_row.sel[4], first_read_row.sel[5], first_read_row.sel[6], first_read_row.sel[7], + 
first_read_row.wr, first_read_row.reset, first_read_row.sel_up_to_down, first_read_row.sel_down_to_up + ], + [ + first_write_row.sel[0], first_write_row.sel[1], first_write_row.sel[2], first_write_row.sel[3], + first_write_row.sel[4], first_write_row.sel[5], first_write_row.sel[6], first_write_row.sel[7], + first_write_row.wr, first_write_row.reset, first_write_row.sel_up_to_down, first_write_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ], + [ + second_write_row.sel[0], second_write_row.sel[1], second_write_row.sel[2], second_write_row.sel[3], + second_write_row.sel[4], second_write_row.sel[5], second_write_row.sel[6], second_write_row.sel[7], + second_write_row.wr, second_write_row.reset, second_write_row.sel_up_to_down, second_write_row.sel_down_to_up + ], + [ + second_read_row.sel[0], second_read_row.sel[1], second_read_row.sel[2], second_read_row.sel[3], + second_read_row.sel[4], second_read_row.sel[5], second_read_row.sel[6], second_read_row.sel[7], + second_read_row.wr, second_read_row.reset, second_read_row.sel_up_to_down, second_read_row.sel_down_to_up + ] + ); + + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + + // Prove the generated rows + self.prove(&[ + first_read_row, + first_write_row, + value_row, + second_write_row, + second_read_row, + ]); + + MemAlignResponse { + more_address: false, + step, + value: Some(value_second_write), + } + } + _ => panic!("Invalid phase={}", phase), + } + } + } + } + + fn get_byte(value: u64, index: usize, offset: usize) -> u64 { + let chunk = (offset + index) % CHUNK_NUM; + (value >> (chunk * CHUNK_BITS)) & CHUNK_BITS_MASK + } + + pub fn prove(&self, computed_rows: &[MemAlignRow]) { + if let Ok(mut rows) = self.rows.lock() { + rows.extend_from_slice(computed_rows); + + #[cfg(feature = 
"debug_mem_align")] + { + let mut num_rows = self.num_computed_rows.lock().unwrap(); + *num_rows += computed_rows.len(); + drop(num_rows); + } + + let pctx = self.wcm.get_pctx(); + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + while rows.len() >= air_mem_align.num_rows() { + let num_drained = std::cmp::min(air_mem_align.num_rows(), rows.len()); + let drained_rows = rows.drain(..num_drained).collect::>(); + + self.fill_new_air_instance(&drained_rows); + } + } + } + + fn fill_new_air_instance(&self, rows: &[MemAlignRow]) { + // Get the proof context + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + + // Get the Mem Align AIR + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + let air_mem_align_rows = air_mem_align.num_rows(); + let rows_len = rows.len(); + + // You cannot feed to the AIR more rows than it has + debug_assert!(rows_len <= air_mem_align_rows); + + // Get the execution and setup context + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + // Create a prover buffer + let (mut prover_buffer, offset) = + create_prover_buffer(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + // Create a Mem Align trace buffer + let mut trace_buffer = + MemAlignTrace::::map_buffer(&mut prover_buffer, air_mem_align_rows, offset as usize) + .unwrap(); + + let mut reg_range_check: Vec = vec![0; 1 << CHUNK_BITS]; + + // Add the input rows to the trace + for (i, &row) in rows.iter().enumerate() { + // Store the entire row + trace_buffer[i] = row; + // Store the value of all reg columns so that they can be range checked + for j in 0..CHUNK_NUM { + let element = + row.reg[j].as_canonical_biguint().to_usize().expect("Cannot convert to usize"); + reg_range_check[element] += 1; + } + } + + // Pad the remaining rows with trivially satisfying rows + let padding_row = MemAlignRow:: { reset: F::from_bool(true), ..Default::default() }; + let padding_size = air_mem_align_rows - rows_len; + + 
// Store the padding rows + for i in rows_len..air_mem_align_rows { + trace_buffer[i] = padding_row; + } + + // Store the value of all padding reg columns so that they can be range checked + for _ in 0..CHUNK_NUM { + reg_range_check[0] += padding_size as u64; + } + + // Perform the range checks + let std = self.std.clone(); + let range_id = std.get_range(BigInt::from(0), BigInt::from(CHUNK_BITS_MASK), None); + for (value, &multiplicity) in reg_range_check.iter().enumerate() { + std.range_check( + F::from_canonical_usize(value), + F::from_canonical_u64(multiplicity), + range_id, + ); + } + + // Compute the program multiplicity + let mem_align_rom_sm = self.mem_align_rom_sm.clone(); + mem_align_rom_sm.update_padding_row(padding_size as u64); + + info!( + "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", + Self::MY_NAME, + rows_len, + air_mem_align_rows, + rows_len as f64 / air_mem_align_rows as f64 * 100.0 + ); + + // Add a new Mem Align instance + let air_instance = + AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0], None, prover_buffer); + pctx.air_instance_repo.add_air_instance(air_instance, None); + } +} + +impl WitnessComponent for MemAlignSM {} diff --git a/state-machines/mem/src/mem_aligned.rs b/state-machines/mem/src/mem_aligned.rs deleted file mode 100644 index 47feebfb..00000000 --- a/state-machines/mem/src/mem_aligned.rs +++ /dev/null @@ -1,112 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{MemOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_ALIGN_AIR_IDS}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct MemAlignedSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -#[allow(unused, unused_variables)] -impl MemAlignedSM { - pub fn new(wcm: 
Arc>) -> Arc { - let mem_aligned_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let mem_aligned_sm = Arc::new(mem_aligned_sm); - - wcm.register_component( - mem_aligned_sm.clone(), - Some(MEM_AIRGROUP_ID), - Some(MEM_ALIGN_AIR_IDS), - ); - - mem_aligned_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - fn read( - &self, - _addr: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } - - fn write( - &self, - _addr: u64, - _val: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } -} - -impl WitnessComponent for MemAlignedSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, - ) { - } -} - -impl Provable for MemAlignedSM { - fn calculate(&self, operation: MemOp) -> Result> { - match operation { - MemOp::Read(addr) => self.read(addr), - MemOp::Write(addr, val) => self.write(addr, val), - } - } - - fn prove(&self, operations: &[MemOp], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! 
Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: MemOp, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs index 4e177ee3..cb113775 100644 --- a/state-machines/mem/src/mem_constants.rs +++ b/state-machines/mem/src/mem_constants.rs @@ -9,4 +9,4 @@ pub const MEM_STEP_MASK: u64 = (1 << MEM_STEP_BITS) - 1; // 256 MB pub const MEM_ADDR_BITS: u64 = 64 - MEM_STEP_BITS; pub const MAX_MEM_STEP: u64 = (1 << MEM_STEP_BITS) - 1; -pub const MAX_MEM_ADDR: u64 = (1 << MEM_ADDR_BITS) - 1; +pub const MAX_MEM_ADDR: u64 = 0xFFFF_FFFF; diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs index ac4ca198..3f4db4d9 100644 --- a/state-machines/mem/src/mem_helpers.rs +++ b/state-machines/mem/src/mem_helpers.rs @@ -1,4 +1,4 @@ -use crate::MemAlignResponse; +use crate::{MemAlignResponse, MEM_BYTES}; use std::fmt; use zisk_core::ZiskRequiredMemory; @@ -12,6 +12,73 @@ fn format_u64_hex(value: u64) -> String { .join("_") } +const MAX_MEM_STEP_OFFSET: u64 = 2; +const MAX_MEM_OPS_PER_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 2; + +#[derive(Debug, Clone)] +pub struct MemAlignInput { + pub address: u32, + pub is_write: bool, + pub width: u8, + pub step: u64, + pub value: u64, + pub mem_values: [u64; 2], +} + +#[derive(Debug, Clone)] +pub struct MemInput { + pub address: u32, + pub is_write: bool, + pub step: u64, + pub value: u64, +} + +impl MemInput { + pub fn new(address: u32, is_write: bool, step: u64, value: u64) -> Self { + MemInput { address, is_write, step, value } + } + pub fn from(mem_op: &ZiskRequiredMemory) -> Self { + // debug_assert_eq!(mem_op.width, MEM_BYTES as u8); + MemInput { + address: mem_op.address, + is_write: mem_op.is_write, + step: 
MemHelpers::main_step_to_address_step(mem_op.step, mem_op.step_offset), + value: mem_op.value, + } + } +} + +impl MemAlignInput { + pub fn new( + address: u32, + is_write: bool, + width: u8, + step: u64, + value: u64, + mem_values: [u64; 2], + ) -> Self { + MemAlignInput { address, is_write, width, step, value, mem_values } + } + pub fn from(mem_op: &MemInput, width: u8, mem_values: &[u64; 2]) -> Self { + MemAlignInput { + address: mem_op.address, + is_write: mem_op.is_write, + step: mem_op.step, + width, + value: mem_op.value, + mem_values: [mem_values[0], mem_values[1]], + } + } +} + +pub struct MemHelpers {} + +impl MemHelpers { + pub fn main_step_to_address_step(step: u64, step_offset: u8) -> u64 { + 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 + } +} + impl fmt::Debug for MemAlignResponse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( @@ -41,8 +108,8 @@ pub fn mem_align_call( more_address: double_address, step: mem_op.step + 1, value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) - | ((mem_op.value & mask) << offset), + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) | + ((mem_op.value & mask) << offset), ), } } else { @@ -50,8 +117,8 @@ pub fn mem_align_call( more_address: false, step: mem_op.step + 1, value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width as u32 - 64))) - | ((mem_op.value & mask) >> (128 - (offset + width as u32))), + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width as u32 - 64))) | + ((mem_op.value & mask) >> (128 - (offset + width as u32))), ), } } diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs new file mode 100644 index 00000000..b54a0d18 --- /dev/null +++ b/state-machines/mem/src/mem_proxy.rs @@ -0,0 +1,67 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, +}; + +use crate::{MemAlignRomSM, MemAlignSM, MemProxyEngine, MemSM}; +use p3_field::PrimeField; +use pil_std_lib::Std; +use 
zisk_core::ZiskRequiredMemory; + +use proofman::{WitnessComponent, WitnessManager}; + +pub struct MemProxy { + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Secondary State machines + mem_sm: Arc>, + mem_align_sm: Arc>, + mem_align_rom_sm: Arc>, +} + +impl MemProxy { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let mem_align_rom_sm = MemAlignRomSM::new(wcm.clone()); + let mem_align_sm = MemAlignSM::new(wcm.clone(), std.clone(), mem_align_rom_sm.clone()); + let mem_sm = MemSM::new(wcm.clone(), std); + + let mem_proxy = Self { + registered_predecessors: AtomicU32::new(0), + mem_align_sm, + mem_align_rom_sm, + mem_sm, + }; + let mem_proxy = Arc::new(mem_proxy); + + wcm.register_component(mem_proxy.clone(), None, None); + + // For all the secondary state machines, register the main state machine as a predecessor + mem_proxy.mem_align_rom_sm.register_predecessor(); + mem_proxy.mem_align_sm.register_predecessor(); + mem_proxy.mem_sm.register_predecessor(); + mem_proxy + } + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.mem_align_rom_sm.unregister_predecessor(); + self.mem_align_sm.unregister_predecessor(); + self.mem_sm.unregister_predecessor(); + } + } + + pub fn prove( + &self, + mem_operations: &mut Vec, + ) -> Result<(), Box> { + let mut engine = MemProxyEngine::::new(); + engine.add_module("mem", self.mem_sm.clone()); + engine.prove(&self.mem_align_sm, mem_operations) + } +} + +impl WitnessComponent for MemProxy {} diff --git a/state-machines/mem/src/mem_proxy_engine.rs b/state-machines/mem/src/mem_proxy_engine.rs new file mode 100644 index 00000000..2b167108 --- /dev/null +++ b/state-machines/mem/src/mem_proxy_engine.rs @@ -0,0 +1,470 @@ +use std::{collections::VecDeque, sync::Arc}; + +use crate::{ + MemAlignInput, MemAlignResponse, MemAlignSM, 
MemHelpers, MemInput, MemUnmapped, MAX_MEM_ADDR, + MAX_MEM_OPS_PER_MAIN_STEP, MEM_ADDR_MASK, MEM_BYTES, +}; +use log::info; +use p3_field::PrimeField; +use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; +use zisk_core::ZiskRequiredMemory; + +#[cfg(feature = "debug_mem_proxy_engine")] +const DEBUG_ADDR: u32 = 0xA0008F10; + +macro_rules! debug_info { + ($prefix:expr, $($arg:tt)*) => { + #[cfg(feature = "debug_mem_proxy_engine")] + { + info!(concat!("MemPE : ",$prefix), $($arg)*); + } + }; +} + +pub trait MemModule: Send + Sync { + fn send_inputs(&self, mem_op: &[MemInput]); + fn get_addr_ranges(&self) -> Vec<(u32, u32)>; + fn get_flush_input_size(&self) -> u32; +} + +trait MemAlignSm { + fn get_mem_op(&self, mem_op: &MemInput, phase: u8) -> MemAlignResponse; +} + +struct MemModuleData { + pub name: String, + pub inputs: Vec, + pub flush_input_size: u32, +} + +struct MemAlignOperation { + addr: u32, + input: MemAlignInput, +} + +#[derive(Debug)] +pub struct AddressRegion { + from_address: u32, + to_address: u32, + module_id: u8, +} +pub struct MemProxyEngine { + modules: Vec>>, + modules_data: Vec, + open_mem_align_ops: VecDeque, + address_map: Vec, + address_map_closed: bool, + last_addr: u32, + last_addr_value: u64, + current_module_id: usize, + current_module: String, + module_end_addr: u32, +} + +impl MemProxyEngine { + pub fn new() -> Self { + Self { + modules: Vec::new(), + modules_data: Vec::new(), + last_addr: 0, + last_addr_value: 0, + current_module_id: 0, + current_module: String::new(), + module_end_addr: 0, + open_mem_align_ops: VecDeque::new(), + address_map: Vec::new(), + address_map_closed: false, + } + } + + pub fn add_module(&mut self, name: &str, module: Arc>) { + if self.modules.is_empty() { + self.current_module = String::from(name); + } + let module_id = self.modules.len() as u8; + self.modules.push(module.clone()); + + let ranges = module.get_addr_ranges(); + let flush_input_size = module.get_flush_input_size(); + + for range in 
ranges.iter() { + debug_info!("adding range 0x{:X} 0x{:X}", range.0, range.1); + self.insert_address_range(range.0, range.1, module_id); + } + self.modules_data.push(MemModuleData { + name: String::from(name), + inputs: Vec::new(), + flush_input_size, + }); + } + /* insert in sort way the address map and verify that */ + fn insert_address_range(&mut self, from_address: u32, to_address: u32, module_id: u8) { + let region = AddressRegion { from_address, to_address, module_id }; + if let Some(index) = self.address_map.iter().position(|x| x.from_address >= from_address) { + self.address_map.insert(index, region); + } else { + self.address_map.push(region); + } + } + + pub fn prove( + &mut self, + mem_align_sm: &MemAlignSM, + mem_operations: &mut Vec, + ) -> Result<(), Box> { + self.init_prove(); + + // Step 1. Sort the aligned memory accesses + // original vector is sorted by step, sort_by_key is stable, no reordering of elements with + // the same key. + timer_start_debug!(MEM_SORT); + mem_operations.sort_by_key(|mem| (mem.address & 0xFFFF_FFF8)); + timer_stop_and_log_debug!(MEM_SORT); + + // Step2. Add a final mark mem_op to force flush of open_mem_align_ops, because always the + // last operation is mem_op. + mem_operations.push(Self::end_of_memory_mark()); + + // Step3. Process each memory operation ordered by address and step. When a non-aligned + // memory access there are two possible situations: + // + // 1) the operation applies only applies to one memory address (read or read+write). In + // this case mem_align helper return the aligned operation for this address, and loop + // continues. + // 2) the operation applies to two consecutive memory addresses, mem_align helper returns + // the aligned operation involved for the current address, and the second part of the + // operation is enqueued to open_mem_align_ops, it will processed when processing next + // address. 
+ // + // Inside loop, first of all, we verify if exists "previous" open mem align operations that + // be processed before current mem_op, in this case process all "previous" and after process + // the current mem_op. + + for mem_extern_op in mem_operations.iter_mut() { + self.log_mem_op(mem_extern_op); + let mem_op = MemInput::from(mem_extern_op); + let aligned_mem_addr = Self::to_aligned_addr(mem_op.address); + let mem_step = mem_op.step; + + // Check if there are open mem align operations to be processed in this moment, with + // address (or step) less than the aligned of current mem_op. + self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step, mem_align_sm); + + // check if we are at end of loop + if self.check_if_end_of_memory_mark(&mem_op) { + break; + } + + // TODO: edge case special memory with free-input memory data as input data + let mem_value = self.get_mem_value(aligned_mem_addr); + + // all open mem align operations are processed, check if new mem operation is aligned + if !Self::is_aligned(&mem_extern_op) { + // In this point found non-aligned memory access, phase-0 + let mem_align_input = + MemAlignInput::from(&mem_op, mem_extern_op.width, &[mem_value, 0]); + let mem_align_response = mem_align_sm.get_mem_op(&mem_align_input, 0); + + #[cfg(feature = "debug_mem_proxy_engine")] + if mem_align_input.address >= DEBUG_ADDR - 8 && + mem_align_input.address <= DEBUG_ADDR + 8 + { + debug_info!( + "mem_align_input_{:X}: phase: 0 {:?}", + mem_align_input.address, + mem_align_input + ); + debug_info!( + "mem_align_response_{:X}: phase: 0 {:?}", + mem_align_input.address, + mem_align_response + ); + } + // if operation applies to two consecutive memory addresses, add the second part + // is enqueued to be processed in future when processing next address on phase-1 + if mem_align_response.more_address { + self.push_open_mem_align_op(aligned_mem_addr, &mem_align_input); + } + self.push_mem_align_response_ops( + aligned_mem_addr, + mem_value, + 
&mem_align_input, + &mem_align_response, + ); + } else { + self.push_mem_op(&mem_op); + } + } + self.finish_prove(); + Ok(()) + } + + fn process_all_previous_open_mem_align_ops( + &mut self, + mem_addr: u32, + mem_step: u64, + mem_align_sm: &MemAlignSM, + ) { + // Two possible situations to process open mem align operations: + // + // 1) the address of open operation is less than the aligned address. + // 2) the address of open operation is equal to the aligned address, but the step of the + // open operation is less than the step of the current operation. + + while self.has_open_mem_align_lt(mem_addr, mem_step) { + let mut open_op = self.open_mem_align_ops.pop_front().unwrap(); + let mem_value = self.get_mem_value(open_op.addr); + + // call to mem_align to get information of the aligned memory access needed + // to prove the unaligned open operation. + open_op.input.mem_values[1] = mem_value; + let mem_align_resp = mem_align_sm.get_mem_op(&open_op.input, 1); + + #[cfg(feature = "debug_mem_proxy_engine")] + if open_op.input.address >= DEBUG_ADDR - 8 && open_op.input.address <= DEBUG_ADDR + 8 { + debug_info!( + "mem_align_input_{:X}: phase:1 {:?}", + open_op.input.address, + open_op.input + ); + debug_info!( + "mem_align_response_{:X}: phase:1 {:?}", + open_op.input.address, + mem_align_resp + ); + } + // push the aligned memory operations for current address (read or read+write) and + // update last_address and last_value. + self.push_mem_align_response_ops( + open_op.addr, + mem_value, + &open_op.input, + &mem_align_resp, + ); + } + } + + pub fn main_step_to_mem_step(step: u64, step_offset: u8) -> u64 { + 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 + } + + /// Static method to decide it the memory operation needs to be processed by + /// memAlign, because it isn't a 8-byte and 8-byte aligned memory access. 
+ fn is_aligned(mem_op: &ZiskRequiredMemory) -> bool { + let aligned_mem_address = (mem_op.address as u64 & MEM_ADDR_MASK) as u32; + aligned_mem_address == mem_op.address && mem_op.width == MEM_BYTES as u8 + } + fn push_mem_op(&mut self, mem_op: &MemInput) { + self.push_aligned_op(mem_op.is_write, mem_op.address, mem_op.value, mem_op.step); + } + + fn push_aligned_op(&mut self, is_write: bool, addr: u32, value: u64, step: u64) { + self.update_last_addr(addr, value); + let mem_op = MemInput { step, is_write, address: addr as u32, value }; + debug_info!( + "route ==> {}[{:X}] {} {} #{}", + self.current_module, + mem_op.address, + if is_write { "W" } else { "R" }, + value, + step, + ); + self.modules_data[self.current_module_id].inputs.push(mem_op); + self.last_addr_value = value; + self.check_flush_inputs(); + } + // method to add aligned read operation + #[inline(always)] + fn push_aligned_read(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(false, addr, value, step); + } + // method to add aligned write operation + #[inline(always)] + fn push_aligned_write(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(true, addr, value, step); + } + /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible + /// situations: + /// 1) read, only on single mem_op is pushed + /// 2) read+write, two mem_op are pushed, one read and one write. + /// + /// This process is used for each aligned memory address, means that the "second part" of non + /// aligned memory operation is processed on addr + MEM_BYTES. 
+ fn push_mem_align_response_ops( + &mut self, + mem_addr: u32, + mem_value: u64, + mem_align_input: &MemAlignInput, + mem_align_resp: &MemAlignResponse, + ) { + self.push_aligned_read(mem_addr, mem_value, mem_align_resp.step); + if mem_align_input.is_write { + #[cfg(feature = "debug_mem_proxy_engine")] + if mem_addr >= DEBUG_ADDR - 8 && mem_addr <= DEBUG_ADDR - 8 { + debug_info!( + "push_mem_align_response_ops_{:X}-A: value:{} {:?}", + mem_addr, + mem_align_resp.value.unwrap(), + mem_align_resp + ); + debug_info!( + "push_mem_align_response_ops_{:X}-B: mem_value:{} {:?}", + mem_addr, + mem_value, + mem_align_input + ); + } + self.push_aligned_write( + mem_addr, + mem_align_resp.value.unwrap(), + mem_align_resp.step + 1, + ); + } + } + fn create_modules_inputs(&self) -> Vec> { + let mut mem_module_inputs: Vec> = Default::default(); + for _module in self.modules.iter() { + mem_module_inputs.push(Vec::new()); + } + mem_module_inputs + } + fn set_active_region(&mut self, region_id: usize) { + self.current_module_id = self.address_map[region_id].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.address_map[region_id].to_address; + } + fn update_mem_module_id(&mut self, addr: u32) { + debug_info!("search module for address 0x{:X}", addr); + if let Some(index) = + self.address_map.iter().position(|x| x.from_address <= addr && x.to_address >= addr) + { + self.set_active_region(index); + } else { + assert!(false, "out-of-memory 0x{:X}", addr); + } + } + fn update_last_addr(&mut self, addr: u32, value: u64) { + self.last_addr = addr; + self.last_addr_value = value; + self.update_mem_module(addr); + } + fn update_mem_module(&mut self, addr: u32) { + // check if need to reevaluate the module id + if addr > self.module_end_addr { + self.update_mem_module_id(addr); + } + } + fn check_flush_inputs(&mut self) { + // check if need to flush the inputs of the module + let mid = self.current_module_id; + 
let inputs = self.modules_data[mid].inputs.len() as u32; + if inputs >= self.modules_data[mid].flush_input_size { + // TODO: optimize passing ownership of inputs to module, and creating a new input + // object + debug_info!("flush {} inputs => {}", inputs, self.current_module); + self.modules[mid].send_inputs(&self.modules_data[mid].inputs); + self.modules_data[mid].inputs.clear(); + } + } + + fn has_open_mem_align_lt(&self, addr: u32, step: u64) -> bool { + self.open_mem_align_ops.len() > 0 && + (self.open_mem_align_ops[0].addr < addr || + (self.open_mem_align_ops[0].addr == addr && + self.open_mem_align_ops[0].input.step < step)) + } + // method to process open mem align operations, second part of non aligned memory operations + // applies to two consecutive memory addresses. + + fn end_of_memory_mark() -> ZiskRequiredMemory { + ZiskRequiredMemory { + step: 0, + step_offset: 0, + is_write: false, + address: MAX_MEM_ADDR as u32, + width: MEM_BYTES as u8, + value: 0, + } + } + #[inline(always)] + fn check_if_end_of_memory_mark(&self, mem_op: &MemInput) -> bool { + // TODO: 0xFFFF_FFFF not valid address + if mem_op.address == MAX_MEM_ADDR as u32 { + assert!( + self.open_mem_align_ops.len() == 0, + "open_mem_align_ops not empty, has {} elements", + self.open_mem_align_ops.len() + ); + true + } else { + false + } + } + fn init_prove(&mut self) { + if !self.address_map_closed { + self.close_address_map(); + } + self.current_module_id = self.address_map[0].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.address_map[0].to_address; + } + fn finish_prove(&self) { + for (module_id, module) in self.modules.iter().enumerate() { + debug_info!( + "{}: flush all({}) inputs", + self.modules_data[module_id].name, + self.modules_data[module_id].inputs.len() + ); + module.send_inputs(&self.modules_data[module_id].inputs); + } + } + fn get_mem_value(&self, addr: u32) -> u64 { + if addr == self.last_addr 
{ + self.last_addr_value + } else { + 0 + } + } + fn close_address_map(&mut self) { + let mut next_address = 0; + let mut unmapped_regions: Vec<(u32, u32)> = Vec::new(); + for address_region in self.address_map.iter() { + if next_address < address_region.from_address { + unmapped_regions.push((next_address, address_region.from_address - 1)); + } + next_address = address_region.to_address + 1; + } + if !unmapped_regions.is_empty() { + let mut unmapped_module = MemUnmapped::::new(); + for unmapped_region in unmapped_regions.iter() { + unmapped_module.add_range(unmapped_region.0, unmapped_region.1); + } + self.add_module("unmapped", Arc::new(unmapped_module)); + } + self.address_map_closed = true; + } + + #[inline(always)] + fn push_open_mem_align_op(&mut self, aligned_mem_addr: u32, input: &MemAlignInput) { + self.open_mem_align_ops.push_back(MemAlignOperation { + addr: aligned_mem_addr + MEM_BYTES as u32, + input: input.clone(), + }); + } + fn log_mem_op(&self, mem_op: &ZiskRequiredMemory) { + debug_info!( + "next input [0x{:x}] {} {} {}b #{} [0x{:x},{}]", + mem_op.address, + if mem_op.is_write { "W" } else { "R" }, + mem_op.value, + mem_op.width, + mem_op.step, + self.last_addr, + self.last_addr_value + ); + } + #[inline(always)] + fn to_aligned_addr(addr: u32) -> u32 { + (addr as u64 & MEM_ADDR_MASK) as u32 + } +} diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs new file mode 100644 index 00000000..66eb8df9 --- /dev/null +++ b/state-machines/mem/src/mem_sm.rs @@ -0,0 +1,323 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::{MemInput, MemModule}; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; +use p3_field::PrimeField; +use pil_std_lib::Std; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; +use rayon::prelude::*; + +use sm_common::create_prover_buffer; +use zisk_pil::{MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; + +const MEM_INITIAL_ADDRESS: u32 = 
0xA0000000; +const MEM_FINAL_ADDRESS: u32 = MEM_INITIAL_ADDRESS + 128 * 1024 * 1024; +const MEMORY_MAX_DIFF: u32 = 0x1000000; + +pub struct MemSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, + + num_rows: usize, + // Count of registered predecessors + registered_predecessors: AtomicU32, +} + +#[allow(unused, unused_variables)] +impl MemSM { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let mem_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + num_rows: air.num_rows(), + registered_predecessors: AtomicU32::new(0), + }; + let mem_sm = Arc::new(mem_sm); + + wcm.register_component(mem_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MEM_AIR_IDS)); + + std.register_predecessor(); + + mem_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + self.std.unregister_predecessor(pctx, None); + } + } + + pub fn prove(&self, inputs: &[MemInput]) { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + + let air_mem = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let air_mem_rows = air_mem.num_rows(); + + let inputs_len = inputs.len(); + let num_chunks = (inputs_len as f64 / air_mem_rows as f64).ceil() as usize; + + let mut prover_buffers = Mutex::new(vec![Vec::new(); num_chunks]); + let mut offsets = vec![0; num_chunks]; + let mut global_idxs = vec![0; num_chunks]; + + for i in 0..num_chunks { + if let (true, global_idx) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0], 1) + { + let (buffer, offset) = + create_prover_buffer::(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + + prover_buffers.lock().unwrap()[i] = buffer; + offsets[i] = offset; + 
global_idxs[i] = global_idx; + } + } + + for (segment_id, mem_ops) in inputs.chunks(air_mem_rows).enumerate() { + let is_last_segment = segment_id == num_chunks - 1; + + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + self.prove_instance( + mem_ops, + segment_id, + is_last_segment, + prover_buffer, + offsets[segment_id], + global_idxs[segment_id], + ); + } + + // TODO: Uncomment when sequential works + // inputs.par_chunks(air_mem_rows - 1).enumerate().for_each(|(segment_id, mem_ops)| { + // let mem_first_row = if segment_id == 0 { + // inputs.last().unwrap().clone() + // } else { + // inputs[segment_id * ((air_mem_rows - 1) - 1)].clone() + // }; + + // let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + // self.prove_instance( + // mem_ops, + // mem_first_row, + // segment_id, + // segment_id == inputs.len() - 1, + // prover_buffer, + // offsets[segment_id], + // global_idxs[segment_id], + // ); + // }); + } + + /// Finalizes the witness accumulation process and triggers the proof generation. + /// + /// This method is invoked by the executor when no further witness data remains to be added. 
+ /// + /// # Parameters + /// + /// - `mem_inputs`: A slice of all `MemoryInput` inputs + pub fn prove_instance( + &self, + mem_ops: &[MemInput], + segment_id: usize, + is_last_segment: bool, + mut prover_buffer: Vec, + offset: u64, + global_idx: usize, + ) -> Result<(), Box> { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); + + // STEP2: Process the memory inputs and convert them to AIR instances + let air_mem = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let air_mem_rows = air_mem.num_rows(); + + let max_rows_per_segment = air_mem_rows - 1; + + assert!(mem_ops.len() > 0 && mem_ops.len() <= max_rows_per_segment); + + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR + // segments In a Memory AIR instance, the first row is reserved as a dummy row. + // This dummy row is used to facilitate the continuation state between different AIR + // segments. It ensures seamless transitions when multiple AIR segments are + // processed consecutively. 
This design avoids discontinuities in memory access + // patterns and ensures that the memory trace is continuous, For this reason we use + // AIR num_rows - 1 as the number of rows in each memory AIR instance + + // Create a vector of Mem0Row instances, one for each memory operation + // Recall that first row is a dummy row used for the continuations between AIR segments + // The length of the vector is the number of input memory operations plus one because + // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + let mut trace = + MemTrace::::map_buffer(&mut prover_buffer, air_mem_rows, offset as usize).unwrap(); + + let mut range_check_data: Vec = vec![0; MEMORY_MAX_DIFF as usize]; + + // Fill the first row + let first_mem_op = mem_ops.first().unwrap(); + let addr = first_mem_op.address >> 3; + debug_assert!(addr >= MEM_INITIAL_ADDRESS); + + trace[0].addr = F::from_canonical_u32(addr); + trace[0].step = F::from_canonical_u64(first_mem_op.step); + trace[0].sel = F::zero(); + trace[0].wr = F::zero(); + + let value = first_mem_op.value; + let (low_val, high_val) = self.get_u32_values(value); + trace[0].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + trace[0].addr_changes = F::zero(); + + trace[0].same_value = F::zero(); + trace[0].first_addr_access_is_read = F::zero(); + + let increment = addr - 1 - MEM_INITIAL_ADDRESS; + trace[0].increment = F::from_canonical_u32(increment); + + // Store the value of incremenet so it can be range checked + println!("addr: {:#X}, initial: {:#X}, increment: {:#X}", addr, MEM_INITIAL_ADDRESS, increment); + range_check_data[increment as usize] += 1; // TODO + + // Fill the remaining rows + for (idx, mem_op) in mem_ops.iter().enumerate() { + let i = idx + 1; + + let mem_addr = mem_op.address >> 3; + trace[i].addr = F::from_canonical_u32(mem_addr); // n-byte address, real address = addr * MEM_BYTES + trace[i].step = F::from_canonical_u64(mem_op.step); + trace[i].sel = 
F::one(); + trace[i].wr = F::from_bool(mem_op.is_write); + + let (low_val, high_val) = self.get_u32_values(mem_op.value); + trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + + let addr_changes = trace[i - 1].addr != trace[i].addr; + trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; + + let same_value = trace[i - 1].value[0] == trace[i].value[0] && + trace[i - 1].value[1] == trace[i].value[1]; + trace[i].same_value = if same_value { F::one() } else { F::zero() }; + + let first_addr_access_is_read = addr_changes && !mem_op.is_write; + trace[i].first_addr_access_is_read = + if first_addr_access_is_read { F::one() } else { F::zero() }; + assert!(trace[i].sel.is_zero() || trace[i].sel.is_one()); + + let increment = if addr_changes { + trace[i].addr - trace[i - 1].addr + } else { + trace[i].step - trace[i - 1].step + }; + trace[i].increment = increment; + + // Store the value of incremenet so it can be range checked + let element = + increment.as_canonical_biguint().to_usize().expect("Cannot convert to usize"); + // range_check_data[element] += 1; // TODO: + } + + // STEP3. 
Add dummy rows to the output vector to fill the remaining rows + // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd + // = 1, wr = 0 + let last_row_idx = mem_ops.len(); + let addr = trace[last_row_idx].addr; + let mut step = trace[last_row_idx].step; + let value = trace[last_row_idx].value; + + let padding_size = air_mem_rows - (mem_ops.len() + 1); + + for i in (mem_ops.len() + 1)..air_mem_rows { + step += F::one(); + + // TODO CHECK + // trace[i].mem_segment = segment_id_field; + // trace[i].mem_last_segment = is_last_segment_field; + + trace[i].addr = addr; + trace[i].step = step; + trace[i].sel = F::zero(); + trace[i].wr = F::zero(); + + trace[i].value = value; + + trace[i].addr_changes = F::zero(); + trace[i].same_value = F::one(); + trace[i].first_addr_access_is_read = F::zero(); + + // Set increment to the minimum value so the range check passes + trace[i].increment = F::one(); + } + + // Store the value of trivial increment so that they can be range checked + range_check_data[1] += padding_size as u64; + + // TODO: Perform the range checks + // let std = self.std.clone(); + // let range_id = std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + // for (value, &multiplicity) in range_check_data.iter().enumerate() { + // std.range_check( + // F::from_canonical_usize(value), + // F::from_canonical_u64(multiplicity), + // range_id, + // ); + // } + + let mut air_instance = AirInstance::new( + sctx.clone(), + ZISK_AIRGROUP_ID, + MEM_AIR_IDS[0], + Some(segment_id), + prover_buffer, + ); + + air_instance.set_airvalue( + &sctx, + "Mem.mem_segment", + F::from_canonical_u64(segment_id as u64), + ); + air_instance.set_airvalue(&sctx, "Mem.mem_last_segment", F::from_bool(is_last_segment)); + + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); + + Ok(()) + } + + fn get_u32_values(&self, value: u64) -> (u32, u32) { + (value as u32, (value >> 32) as u32) + } +} + +impl MemModule for MemSM { + 
fn send_inputs(&self, mem_op: &[MemInput]) { + self.prove(&mem_op); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + vec![(MEM_INITIAL_ADDRESS, MEM_FINAL_ADDRESS)] + } + fn get_flush_input_size(&self) -> u32 { + self.num_rows as u32 + } +} + +impl WitnessComponent for MemSM {} diff --git a/state-machines/mem/src/mem_traces.rs b/state-machines/mem/src/mem_traces.rs deleted file mode 100644 index c80a8c74..00000000 --- a/state-machines/mem/src/mem_traces.rs +++ /dev/null @@ -1,5 +0,0 @@ -use proofman_common as common; -pub use proofman_macros::trace; - -trace!(MemALignedRow, MemALignedTrace { fake: F }); -trace!(MemUnaLignedRow, MemUnaLignedTrace { fake: F}); diff --git a/state-machines/mem/src/mem_unaligned.rs b/state-machines/mem/src/mem_unaligned.rs deleted file mode 100644 index 9d47a135..00000000 --- a/state-machines/mem/src/mem_unaligned.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{MemUnalignedOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct MemUnalignedSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -#[allow(unused, unused_variables)] -impl MemUnalignedSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let mem_aligned_sm = Arc::new(mem_aligned_sm); - - wcm.register_component( - mem_aligned_sm.clone(), - Some(MEM_AIRGROUP_ID), - Some(MEM_UNALIGNED_AIR_IDS), - ); - - mem_aligned_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if 
self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - fn read( - &self, - _addr: u64, - _width: usize, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } - - fn write( - &self, - _addr: u64, - _width: usize, - _val: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } -} - -impl WitnessComponent for MemUnalignedSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc>, - _sctx: Arc>, - ) { - } -} - -impl Provable for MemUnalignedSM { - fn calculate(&self, operation: MemUnalignedOp) -> Result> { - match operation { - MemUnalignedOp::Read(addr, width) => self.read(addr, width), - MemUnalignedOp::Write(addr, width, val) => self.write(addr, width, val), - } - } - - fn prove(&self, operations: &[MemUnalignedOp], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! 
Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: MemUnalignedOp, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/mem/src/mem_unmapped.rs b/state-machines/mem/src/mem_unmapped.rs new file mode 100644 index 00000000..f647750a --- /dev/null +++ b/state-machines/mem/src/mem_unmapped.rs @@ -0,0 +1,29 @@ +use std::marker::PhantomData; + +use crate::{MemInput, MemModule}; +use p3_field::PrimeField; + +pub struct MemUnmapped { + ranges: Vec<(u32, u32)>, + __data: PhantomData, +} + +impl MemUnmapped { + pub fn new() -> Self { + Self { ranges: Vec::new(), __data: PhantomData } + } + pub fn add_range(&mut self, _start: u32, _end: u32) { + self.ranges.push((_start, _end)); + } +} +impl MemModule for MemUnmapped { + fn send_inputs(&self, _mem_op: &[MemInput]) { + // panic!("[MemUnmapped] invalid access to addr {:x}", _mem_op[0].addr); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + self.ranges.to_vec() + } + fn get_flush_input_size(&self) -> u32 { + 1 + } +} diff --git a/state-machines/rom/src/rom.rs b/state-machines/rom/src/rom.rs index 3fc4f9b1..234aedac 100644 --- a/state-machines/rom/src/rom.rs +++ b/state-machines/rom/src/rom.rs @@ -30,30 +30,30 @@ impl RomSM { &self, rom: &ZiskRom, pc_histogram: ZiskPcHistogram, + instance_gid: usize, ) -> Result<(), Box> { - let buffer_allocator = self.wcm.get_ectx().buffer_allocator.clone(); - let sctx = self.wcm.get_sctx(); - if pc_histogram.end_pc == 0 { panic!("RomSM::prove() detected pc_histogram.end_pc == 0"); // TODO: return an error } + let buffer_allocator = self.wcm.get_ectx().buffer_allocator.clone(); + let sctx = self.wcm.get_sctx(); + let main_trace_len = - self.wcm.get_pctx().pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows() as u64; + self.wcm.get_pctx().pilout.get_air(ZISK_AIRGROUP_ID, 
MAIN_AIR_IDS[0]).num_rows(); - let (prover_buffer, _, air_id) = - Self::compute_trace_rom(rom, buffer_allocator, &sctx, pc_histogram, main_trace_len)?; + let prover_buffer = Self::compute_trace_rom( + rom, + buffer_allocator, + &sctx, + pc_histogram, + main_trace_len as u64, + )?; let air_instance = - AirInstance::new(sctx.clone(), ZISK_AIRGROUP_ID, air_id, None, prover_buffer); - let (is_mine, instance_gid) = - self.wcm.get_ectx().dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, air_id, 1); - if is_mine { - self.wcm - .get_pctx() - .air_instance_repo - .add_air_instance(air_instance, Some(instance_gid)); - } + AirInstance::new(sctx.clone(), ZISK_AIRGROUP_ID, ROM_AIR_IDS[0], None, prover_buffer); + + self.wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, Some(instance_gid)); Ok(()) } @@ -61,7 +61,7 @@ impl RomSM { rom_path: PathBuf, buffer_allocator: Arc>, sctx: &SetupCtx, - ) -> Result<(Vec, u64, usize), Box> { + ) -> Result, Box> { // Get the ELF file path as a string let elf_filename: String = rom_path.to_str().unwrap().into(); println!("Proving ROM for ELF file={}", elf_filename); @@ -90,69 +90,12 @@ impl RomSM { sctx: &SetupCtx, pc_histogram: ZiskPcHistogram, main_trace_len: u64, - ) -> Result<(Vec, u64, usize), Box> { + ) -> Result, Box> { let pilout = Pilout::pilout(); - let sizes = ( - pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(), - // pilout.get_air(ZISK_AIRGROUP_ID, ROM_M_AIR_IDS[0]).num_rows(), - // pilout.get_air(ZISK_AIRGROUP_ID, ROM_L_AIR_IDS[0]).num_rows(), - ); + let num_rows = pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(); let number_of_instructions = rom.insts.len(); - Self::create_rom_s( - sizes.0, - rom, - number_of_instructions, - buffer_allocator, - sctx, - pc_histogram, - main_trace_len, - ) - // match number_of_instructions { - // n if n <= sizes.0 => Self::create_rom_s( - // sizes.0, - // rom, - // n, - // buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // n if n 
<= sizes.1 => Self::create_rom_m( - // sizes.1, - // rom, - // n, - // buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // n if n < sizes.2 => Self::create_rom_l( - // sizes.2, - // rom, - // n, - // buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // _ => panic!("RomSM::compute_trace() found rom too big size={}", - // number_of_instructions), } - } - - fn create_rom_s( - rom_s_size: usize, - rom: &zisk_core::ZiskRom, - number_of_instructions: usize, - buffer_allocator: Arc>, - sctx: &SetupCtx, - pc_histogram: ZiskPcHistogram, - main_trace_len: u64, - ) -> Result<(Vec, u64, usize), Box> { - // Set trace size - let trace_size = rom_s_size; - // Allocate a prover buffer let (buffer_size, offsets) = buffer_allocator .get_buffer_info(sctx, ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]) @@ -161,8 +104,8 @@ impl RomSM { // Create an empty ROM trace let mut rom_trace = - RomTrace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) - .expect("RomSM::compute_trace() failed mapping buffer to ROMSRow"); + RomTrace::::map_buffer(&mut prover_buffer, num_rows, offsets[0] as usize) + .expect("RomSM::compute_trace() failed mapping buffer to ROMS0Trace"); // For every instruction in the rom, fill its corresponding ROM trace //for (i, inst_builder) in rom.insts.clone().into_iter().enumerate() { @@ -235,230 +178,15 @@ impl RomSM { rom_trace[i].jmp_offset2 = jmp_offset2; rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - /*println!( - "ROM SM [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], {}", - inst.paddr, - inst.a_offset_imm0, - if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 }, - inst.b_offset_imm0, - if inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 }, - if inst.b_src == SRC_IND { 1 } else { 0 }, - inst.ind_width, - inst.op, - inst.store_offset as u64, - inst.jmp_offset1 as u64, - inst.jmp_offset2 as u64, - inst.get_flags(), 
- multiplicity, - );*/ - i += 1; } // Padd with zeroes - for i in number_of_instructions..trace_size { + for i in number_of_instructions..num_rows { rom_trace[i] = RomRow::default(); } - Ok((prover_buffer, offsets[0], ROM_AIR_IDS[0])) + Ok(prover_buffer) } - - // fn create_rom_m( - // rom_m_size: usize, - // rom: &zisk_core::ZiskRom, - // number_of_instructions: usize, - // buffer_allocator: Arc, - // sctx: &SetupCtx, - // pc_histogram: ZiskPcHistogram, - // main_trace_len: u64, - // ) -> Result<(Vec, u64, usize), Box> { - // // Set trace size - // let trace_size = rom_m_size; - - // // Allocate a prover buffer - // let (buffer_size, offsets) = buffer_allocator - // .get_buffer_info(sctx, ZISK_AIRGROUP_ID, ROM_M_AIR_IDS[0]) - // .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); - // let mut prover_buffer = create_buffer_fast(buffer_size as usize); - - // // Create an empty ROM trace - // let mut rom_trace = - // RomM1Trace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) - // .expect("RomSM::compute_trace() failed mapping buffer to ROMMRow"); - - // // For every instruction in the rom, fill its corresponding ROM trace - // for (i, inst_builder) in rom.insts.clone().into_iter().enumerate() { - // // Get the Zisk instruction - // let inst = inst_builder.1.i; - - // // Calculate the multiplicity, i.e. 
the number of times this pc is used in this - // // execution - // let mut multiplicity: u64; - // if pc_histogram.map.is_empty() { - // multiplicity = 1; // If the histogram is empty, we use 1 for all pc's - // } else { - // let counter = pc_histogram.map.get(&inst.paddr); - // if counter.is_some() { - // multiplicity = *counter.unwrap(); - // if inst.paddr == pc_histogram.end_pc { - // multiplicity += main_trace_len - 1 - (pc_histogram.steps % - // main_trace_len); } - // } else { - // continue; // We skip those pc's that are not used in this execution - // } - // } - - // // Convert the i64 offsets to F - // let jmp_offset1 = if inst.jmp_offset1 >= 0 { - // F::from_canonical_u64(inst.jmp_offset1 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset1) as u64)) - // }; - // let jmp_offset2 = if inst.jmp_offset2 >= 0 { - // F::from_canonical_u64(inst.jmp_offset2 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset2) as u64)) - // }; - // let store_offset = if inst.store_offset >= 0 { - // F::from_canonical_u64(inst.store_offset as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.store_offset) as u64)) - // }; - // let a_offset_imm0 = if inst.a_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.a_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.a_offset_imm0 as i64)) as u64)) - // }; - // let b_offset_imm0 = if inst.b_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.b_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.b_offset_imm0 as i64)) as u64)) - // }; - - // // Fill the rom trace row fields - // rom_trace[i].line = F::from_canonical_u64(inst.paddr); // TODO: unify names: pc, - // paddr, line rom_trace[i].a_offset_imm0 = a_offset_imm0; - // rom_trace[i].a_imm1 = - // F::from_canonical_u64(if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 - // }); rom_trace[i].b_offset_imm0 = b_offset_imm0; - // rom_trace[i].b_imm1 = - // F::from_canonical_u64(if 
inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 - // }); //rom_trace[i].b_src_ind = - // // F::from_canonical_u64(if inst.b_src == SRC_IND { 1 } else { 0 }); - // rom_trace[i].ind_width = F::from_canonical_u64(inst.ind_width); - // rom_trace[i].op = F::from_canonical_u8(inst.op); - // rom_trace[i].store_offset = store_offset; - // rom_trace[i].jmp_offset1 = jmp_offset1; - // rom_trace[i].jmp_offset2 = jmp_offset2; - // rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); - // rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - // } - - // // Padd with zeroes - // for i in number_of_instructions..trace_size { - // rom_trace[i] = RomM1Row::default(); - // } - - // Ok((prover_buffer, offsets[0], ROM_M_AIR_IDS[0])) - // } - - // fn create_rom_l( - // rom_l_size: usize, - // rom: &zisk_core::ZiskRom, - // number_of_instructions: usize, - // buffer_allocator: Arc, - // sctx: &SetupCtx, - // pc_histogram: ZiskPcHistogram, - // main_trace_len: u64, - // ) -> Result<(Vec, u64, usize), Box> { - // // Set trace size - // let trace_size = rom_l_size; - - // // Allocate a prover buffer - // let (buffer_size, offsets) = buffer_allocator - // .get_buffer_info(sctx, ZISK_AIRGROUP_ID, ROM_L_AIR_IDS[0]) - // .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); - // let mut prover_buffer = create_buffer_fast(buffer_size as usize); - - // // Create an empty ROM trace - // let mut rom_trace = - // RomL2Trace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) - // .expect("RomSM::compute_trace() failed mapping buffer to ROMLRow"); - - // // For every instruction in the rom, fill its corresponding ROM trace - // for (i, inst_builder) in rom.insts.clone().into_iter().enumerate() { - // // Get the Zisk instruction - // let inst = inst_builder.1.i; - - // // Calculate the multiplicity, i.e. 
the number of times this pc is used in this - // // execution - // let mut multiplicity: u64; - // if pc_histogram.map.is_empty() { - // multiplicity = 1; // If the histogram is empty, we use 1 for all pc's - // } else { - // let counter = pc_histogram.map.get(&inst.paddr); - // if counter.is_some() { - // multiplicity = *counter.unwrap(); - // if inst.paddr == pc_histogram.end_pc { - // multiplicity += main_trace_len - 1 - (pc_histogram.steps % - // main_trace_len); } - // } else { - // continue; // We skip those pc's that are not used in this execution - // } - // } - - // // Convert the i64 offsets to F - // let jmp_offset1 = if inst.jmp_offset1 >= 0 { - // F::from_canonical_u64(inst.jmp_offset1 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset1) as u64)) - // }; - // let jmp_offset2 = if inst.jmp_offset2 >= 0 { - // F::from_canonical_u64(inst.jmp_offset2 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset2) as u64)) - // }; - // let store_offset = if inst.store_offset >= 0 { - // F::from_canonical_u64(inst.store_offset as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.store_offset) as u64)) - // }; - // let a_offset_imm0 = if inst.a_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.a_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.a_offset_imm0 as i64)) as u64)) - // }; - // let b_offset_imm0 = if inst.b_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.b_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.b_offset_imm0 as i64)) as u64)) - // }; - - // // Fill the rom trace row fields - // rom_trace[i].line = F::from_canonical_u64(inst.paddr); // TODO: unify names: pc, - // paddr, line rom_trace[i].a_offset_imm0 = a_offset_imm0; - // rom_trace[i].a_imm1 = - // F::from_canonical_u64(if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 - // }); rom_trace[i].b_offset_imm0 = b_offset_imm0; - // rom_trace[i].b_imm1 = - // F::from_canonical_u64(if 
inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 - // }); //rom_trace[i].b_src_ind = - // // F::from_canonical_u64(if inst.b_src == SRC_IND { 1 } else { 0 }); - // rom_trace[i].ind_width = F::from_canonical_u64(inst.ind_width); - // rom_trace[i].op = F::from_canonical_u8(inst.op); - // rom_trace[i].store_offset = store_offset; - // rom_trace[i].jmp_offset1 = jmp_offset1; - // rom_trace[i].jmp_offset2 = jmp_offset2; - // rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); - // rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - // } - - // // Padd with zeroes - // for i in number_of_instructions..trace_size { - // rom_trace[i] = RomL2Row::default(); - // } - - // Ok((prover_buffer, offsets[0], ROM_L_AIR_IDS[0])) - // } } impl WitnessComponent for RomSM {} diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index b77f6bfe..8efd1fc8 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -10,16 +10,18 @@ use sm_arith::ArithSM; use sm_binary::BinarySM; use sm_common::create_prover_buffer; use sm_main::{InstanceExtensionCtx, MainSM}; -use sm_mem::MemSM; +use sm_mem::MemProxy; use sm_rom::RomSM; use std::{ fs, path::{Path, PathBuf}, sync::Arc, + thread, }; use zisk_core::{Riscv2zisk, ZiskOperationType, ZiskRom, ZISK_OPERATION_TYPE_VARIANTS}; use zisk_pil::{ - ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, + ARITH_AIR_IDS, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ROM_AIR_IDS, + ZISK_AIRGROUP_ID, }; use ziskemu::{EmuOptions, ZiskEmulator}; @@ -34,7 +36,7 @@ pub struct ZiskExecutor { pub rom_sm: Arc>, /// Memory State Machine - pub mem_sm: Arc, + pub mem_proxy_sm: Arc>, /// Binary State Machine pub binary_sm: Arc>, @@ -50,7 +52,7 @@ impl ZiskExecutor { let std = Std::new(wcm.clone()); let rom_sm = RomSM::new(wcm.clone()); - let mem_sm = MemSM::new(wcm.clone()); + let mem_proxy_sm = MemProxy::new(wcm.clone(), 
std.clone()); let binary_sm = BinarySM::new(wcm.clone(), std.clone()); let arith_sm = ArithSM::new(wcm.clone()); @@ -82,9 +84,10 @@ impl ZiskExecutor { // TODO - If there is more than one Main AIR available, the MAX_ACCUMULATED will be the one // with the highest num_rows. It has to be a power of 2. - let main_sm = MainSM::new(wcm.clone(), arith_sm.clone(), binary_sm.clone(), mem_sm.clone()); + let main_sm = + MainSM::new(wcm.clone(), mem_proxy_sm.clone(), arith_sm.clone(), binary_sm.clone()); - Self { zisk_rom, main_sm, rom_sm, mem_sm, binary_sm, arith_sm } + Self { zisk_rom, main_sm, rom_sm, mem_proxy_sm, binary_sm, arith_sm } } /// Executes the MainSM state machine and processes the inputs in batches when the maximum @@ -119,6 +122,7 @@ impl ZiskExecutor { let path = PathBuf::from(public_inputs_path.display().to_string()); fs::read(path).expect("Could not read inputs file") }; + let public_inputs = Arc::new(public_inputs); // During ROM processing, we gather execution data necessary for creating the AIR instances. // This data is collected by the emulator and includes the minimal execution trace, @@ -138,17 +142,36 @@ impl ZiskExecutor { op_sizes[ZiskOperationType::Binary as usize] = air_binary.num_rows() as u64; op_sizes[ZiskOperationType::BinaryE as usize] = air_binary_e.num_rows() as u64; + // STEP 1. 
Generate all inputs + // ============================================== + + // Memory State Machine + // ---------------------------------------------- + let mem_thread = thread::spawn({ + let zisk_rom = self.zisk_rom.clone(); + let public_inputs = public_inputs.clone(); + move || { + ZiskEmulator::par_process_rom_memory::(&zisk_rom, &public_inputs) + .expect("Failed in ZiskEmulator::par_process_rom_memory") + } + }); + // ROM State Machine // ---------------------------------------------- // Run the ROM to compute the ROM witness - let rom_sm = self.rom_sm.clone(); - let zisk_rom = self.zisk_rom.clone(); - let pc_histogram = - ZiskEmulator::process_rom_pc_histogram(&self.zisk_rom, &public_inputs, &emu_options) - .expect( - "MainSM::execute() failed calling ZiskEmulator::process_rom_pc_histogram()", - ); - let handle_rom = std::thread::spawn(move || rom_sm.prove(&zisk_rom, pc_histogram)); + let rom_thread = thread::spawn({ + let zisk_rom = self.zisk_rom.clone(); + let public_inputs = public_inputs.clone(); + let emu_options_cloned = emu_options.clone(); + move || { + ZiskEmulator::process_rom_pc_histogram( + &zisk_rom, + &public_inputs, + &emu_options_cloned, + ) + .expect("MainSM::execute() failed calling ZiskEmulator::process_rom_pc_histogram()") + } + }); // Main, Binary and Arith State Machines // ---------------------------------------------- @@ -165,10 +188,43 @@ impl ZiskExecutor { .expect("Error during emulator execution"); timer_stop_and_log_debug!(PAR_PROCESS_ROM); - emu_slices.points.sort_by(|a, b| a.op_type.partial_cmp(&b.op_type).unwrap()); + // STEP 2. Wait until all inputs are generated + // ============================================== + // Join all the threads to synchronize the execution + let mut mem_required = mem_thread.join().expect("Error during Memory witness computation"); + let rom_required = rom_thread.join().expect("Error during ROM witness computation"); + + // STEP 3. 
Generate AIRs and Prove + // ============================================== - // Join threads to synchronize the execution - handle_rom.join().unwrap().expect("Error during ROM witness computation"); + // Memory State Machine + // ---------------------------------------------- + let mem_thread = thread::spawn({ + let mem_proxy_sm = self.mem_proxy_sm.clone(); + move || { + mem_proxy_sm + .prove(&mut mem_required) + .expect("Error during Memory witness computation") + } + }); + + // ROM State Machine + // ---------------------------------------------- + let (rom_is_mine, rom_instance_gid) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0], 1); + + let rom_thread = if rom_is_mine { + let rom_sm = self.rom_sm.clone(); + let zisk_rom = self.zisk_rom.clone(); + + Some(thread::spawn(move || rom_sm.prove(&zisk_rom, rom_required, rom_instance_gid))) + } else { + None + }; + + // Main, Binary and Arith State Machines + // ---------------------------------------------- + emu_slices.points.sort_by(|a, b| a.op_type.partial_cmp(&b.op_type).unwrap()); // FIXME: Move InstanceExtensionCtx form main SM to another place let mut instances_extension_ctx: Vec> = @@ -236,7 +292,28 @@ impl ZiskExecutor { } timer_stop_and_log_debug!(ADD_INSTANCES_TO_THE_REPO); - // self.mem_sm.unregister_predecessor(scope); + mem_thread.join().expect("Error during Memory witness computation"); + + // match mem_thread.join() { + // Ok(_) => println!("El thread ha finalitzat correctament."), + // Err(e) => { + // println!("El thread ha fet panic!"); + // + // // Converteix l'error en una cadena llegible (opcional) + // if let Some(missatge) = e.downcast_ref::<&str>() { + // println!("Missatge d'error: {}", missatge); + // } else if let Some(missatge) = e.downcast_ref::() { + // println!("Missatge d'error: {}", missatge); + // } else { + // println!("No es pot determinar el tipus d'error."); + // } + // } + // } + if let Some(thread) = rom_thread { + let _ = 
thread.join().expect("Error during ROM witness computation"); + } + + self.mem_proxy_sm.unregister_predecessor(); self.binary_sm.unregister_predecessor(); self.arith_sm.unregister_predecessor(); }