From 16ac6735723dba52a308f2b8296c5e3d7fac212d Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Tue, 5 Nov 2024 05:26:46 +0000 Subject: [PATCH 01/44] first commit on memory_checkpoint3 --- Cargo.lock | 320 +++++++++++++++--- Cargo.toml | 24 +- emulator/src/emu.rs | 259 +++++++++++++- emulator/src/emulator.rs | 27 +- pil/src/pil_helpers/pilout.rs | 15 +- pil/src/pil_helpers/traces.rs | 8 +- state-machines/binary/src/binary.rs | 13 +- state-machines/main/pil/main.pil | 2 +- state-machines/main/src/main_sm.rs | 7 +- state-machines/mem/Cargo.toml | 1 + state-machines/mem/src/lib.rs | 14 +- state-machines/mem/src/mem.rs | 101 ------ .../src/{mem_aligned.rs => mem_align_sm.rs} | 10 +- state-machines/mem/src/mem_proxy.rs | 115 +++++++ state-machines/mem/src/mem_sm.rs | 284 ++++++++++++++++ state-machines/mem/src/mem_traces.rs | 5 - state-machines/mem/src/mem_unaligned.rs | 114 ------- state-machines/rom/src/rom.rs | 310 +---------------- witness-computation/src/executor.rs | 93 ++++- 19 files changed, 1076 insertions(+), 646 deletions(-) delete mode 100644 state-machines/mem/src/mem.rs rename state-machines/mem/src/{mem_aligned.rs => mem_align_sm.rs} (92%) create mode 100644 state-machines/mem/src/mem_proxy.rs create mode 100644 state-machines/mem/src/mem_sm.rs delete mode 100644 state-machines/mem/src/mem_traces.rs delete mode 100644 state-machines/mem/src/mem_unaligned.rs diff --git a/Cargo.lock b/Cargo.lock index dd7c1966..c9705649 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,9 +62,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8365de52b16c035ff4fcafe0092ba9390540e3e352870ac09933bebcaa2c8c56" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anstyle-parse" @@ -96,9 +96,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.91" +version = 
"1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" +checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13" dependencies = [ "backtrace", ] @@ -198,9 +198,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.31" +version = "1.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2e7962b54006dcfcc61cb72735f4d89bb97061dd6a7ed882ec6b8ee53714c6f" +checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9" dependencies = [ "jobserver", "libc", @@ -436,6 +436,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "either" version = "1.13.0" @@ -793,14 +804,143 @@ dependencies = [ "tracing", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + 
"zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + 
"proc-macro2", + "quote", + "syn", +] + [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "icu_normalizer", + "icu_properties", ] [[package]] @@ -996,6 +1136,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" + [[package]] name = "lock_api" version = "0.4.12" @@ -1328,7 +1474,6 @@ dependencies = [ [[package]] name = "pil-std-lib" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "log", "num-bigint", @@ -1346,7 +1491,6 @@ dependencies = [ [[package]] name = "pilout" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "bytes", "log", @@ -1466,7 +1610,6 @@ dependencies = [ [[package]] name = "proofman" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "colored", "env_logger", @@ -1487,7 +1630,6 @@ dependencies = [ [[package]] name = "proofman-common" version = "0.1.0" -source = 
"git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "env_logger", "log", @@ -1505,7 +1647,6 @@ dependencies = [ [[package]] name = "proofman-hints" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "p3-field", "proofman-common", @@ -1515,7 +1656,6 @@ dependencies = [ [[package]] name = "proofman-macros" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "proc-macro2", "quote", @@ -1525,7 +1665,6 @@ dependencies = [ [[package]] name = "proofman-starks-lib-c" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "log", ] @@ -1533,7 +1672,6 @@ dependencies = [ [[package]] name = "proofman-util" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "colored", "sysinfo 0.31.4", @@ -2098,6 +2236,7 @@ dependencies = [ "proofman", "proofman-common", "proofman-macros", + "proofman-util", "rayon", "sm-common", "zisk-core", @@ -2167,7 +2306,6 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stark" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "log", "p3-field", @@ -2231,9 +2369,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.86" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89275301d38033efb81a6e60e3497e734dfcc62571f2854bf4b16690398824c" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ 
"proc-macro2", "quote", @@ -2249,6 +2387,17 @@ dependencies = [ "futures-core", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sysinfo" version = "0.31.4" @@ -2298,18 +2447,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d171f59dbaa811dbbb1aee1e73db92ec2b122911a48e1390dfe327a821ddede" +checksum = "3b3c6efbfc763e64eb85c11c25320f0737cb7364c4b6336db90aa9ebe27a0bbd" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b08be0f17bd307950653ce45db00cd31200d82b624b36e181337d9c7d92765b5" +checksum = "b607164372e89797d78b8e23a6d67d5d1038c1c65efd52e1389ef8b77caba2a6" dependencies = [ "proc-macro2", "quote", @@ -2358,6 +2507,16 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -2476,7 +2635,6 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#2645c3a1695bad2007830f67a527cccb486815ce" dependencies = [ "proofman-starks-lib-c", ] @@ -2497,27 +2655,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "unicode-bidi" -version = "0.3.17" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" - [[package]] name = "unicode-ident" version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" -[[package]] -name = "unicode-normalization" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" -dependencies = [ - "tinyvec", -] - [[package]] name = "unicode-width" version = "0.1.14" @@ -2532,15 +2675,27 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.2" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -2964,12 +3119,48 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "yansi" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -2991,12 +3182,55 @@ dependencies = [ "syn", ] +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" 
+version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zisk-core" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 8a5159f5..da96aae8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,19 +26,19 @@ opt-level = 3 opt-level = 3 [workspace.dependencies] -proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +# proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +# proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +# proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +# proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +# pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +# stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } #Local development -#proofman-common = { path = "../pil2-proofman/common" } -#proofman-macros = { path = "../pil2-proofman/macros" } -#proofman-util = { path = "../pil2-proofman/util" } -#proofman = { path = "../pil2-proofman/proofman" } -#pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } -#stark = { path = 
"../pil2-proofman/provers/stark" } +proofman-common = { path = "../pil2-proofman/common" } +proofman-macros = { path = "../pil2-proofman/macros" } +proofman-util = { path = "../pil2-proofman/util" } +proofman = { path = "../pil2-proofman/proofman" } +pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } +stark = { path = "../pil2-proofman/provers/stark" } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "c3d754ef77b9fce585b46b972af751fe6e7a9803" } log = "0.4" diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index a56c35d0..b5b8da69 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -9,9 +9,9 @@ use riscv::RiscVRegisters; // #[cfg(feature = "sp")] // use zisk_core::SRC_SP; use zisk_core::{ - InstContext, ZiskInst, ZiskOperationType, ZiskPcHistogram, ZiskRequiredOperation, ZiskRom, - OUTPUT_ADDR, ROM_ENTRY, SRC_C, SRC_IMM, SRC_IND, SRC_MEM, SRC_STEP, STORE_IND, STORE_MEM, - STORE_NONE, SYS_ADDR, ZISK_OPERATION_TYPE_VARIANTS, + InstContext, ZiskInst, ZiskOperationType, ZiskPcHistogram, ZiskRequiredMemory, + ZiskRequiredOperation, ZiskRom, OUTPUT_ADDR, ROM_ENTRY, SRC_C, SRC_IMM, SRC_IND, SRC_MEM, + SRC_STEP, STORE_IND, STORE_MEM, STORE_NONE, SYS_ADDR, ZISK_OPERATION_TYPE_VARIANTS, }; /// ZisK emulator structure, containing the ZisK rom, the list of ZisK operations, and the @@ -92,6 +92,47 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'a' register value based on the source specified by the current instruction + #[inline(always)] + pub fn source_a_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + is_aligned: bool, + ) { + match instruction.a_src { + SRC_C => self.ctx.inst_ctx.a = self.ctx.inst_ctx.c, + SRC_MEM => { + let mut addr = instruction.a_offset_imm0; + if instruction.a_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + self.ctx.inst_ctx.a = self.ctx.inst_ctx.mem.read(addr, 8); + + if is_aligned == Self::is_8_aligned(addr, 8) { + let required_memory = ZiskRequiredMemory { + step: 
self.ctx.inst_ctx.step, + is_write: false, + address: addr, + width: 8, + value: self.ctx.inst_ctx.a, + }; + emu_mem.push(required_memory); + } + } + SRC_IMM => { + self.ctx.inst_ctx.a = instruction.a_offset_imm0 | (instruction.a_use_sp_imm1 << 32) + } + SRC_STEP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.step, + // #[cfg(feature = "sp")] + // SRC_SP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.sp, + _ => panic!( + "Emu::source_a() Invalid a_src={} pc={}", + instruction.a_src, self.ctx.inst_ctx.pc + ), + } + } + /// Calculate the 'b' register value based on the source specified by the current instruction #[inline(always)] pub fn source_b(&mut self, instruction: &ZiskInst) { @@ -128,6 +169,62 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'b' register value based on the source specified by the current instruction + #[inline(always)] + pub fn source_b_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + is_aligned: bool, + ) { + match instruction.b_src { + SRC_C => self.ctx.inst_ctx.b = self.ctx.inst_ctx.c, + SRC_MEM => { + let mut addr = instruction.b_offset_imm0; + if instruction.b_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + self.ctx.inst_ctx.b = self.ctx.inst_ctx.mem.read(addr, 8); + + if is_aligned == Self::is_8_aligned(addr, 8) { + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + is_write: false, + address: addr, + width: 8, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); + } + } + SRC_IMM => { + self.ctx.inst_ctx.b = instruction.b_offset_imm0 | (instruction.b_use_sp_imm1 << 32) + } + SRC_IND => { + let mut addr = + (self.ctx.inst_ctx.a as i64 + instruction.b_offset_imm0 as i64) as u64; + if instruction.b_use_sp_imm1 != 0 { + addr += self.ctx.inst_ctx.sp; + } + self.ctx.inst_ctx.b = self.ctx.inst_ctx.mem.read(addr, instruction.ind_width); + if is_aligned == Self::is_8_aligned(addr, instruction.ind_width) { + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + 
is_write: false, + address: addr, + width: instruction.ind_width, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); + } + } + _ => panic!( + "Emu::source_b() Invalid b_src={} pc={}", + instruction.b_src, self.ctx.inst_ctx.pc + ), + } + } + /// Store the 'c' register value based on the storage specified by the current instruction #[inline(always)] pub fn store_c(&mut self, instruction: &ZiskInst) { @@ -171,6 +268,75 @@ impl<'a> Emu<'a> { } } + /// Store the 'c' register value based on the storage specified by the current instruction + #[inline(always)] + pub fn store_c_memory( + &mut self, + instruction: &ZiskInst, + emu_mem: &mut Vec, + is_aligned: bool, + ) { + match instruction.store { + STORE_NONE => {} + STORE_MEM => { + let val: i64 = if instruction.store_ra { + self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 + } else { + self.ctx.inst_ctx.c as i64 + }; + let mut addr: i64 = instruction.store_offset; + if instruction.store_use_sp { + addr += self.ctx.inst_ctx.sp as i64; + } + self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, 8); + + if is_aligned == Self::is_8_aligned(addr as u64, 8) { + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + is_write: true, + address: addr as u64, + width: 8, + value: val as u64, + }; + emu_mem.push(required_memory); + } + } + STORE_IND => { + let val: i64 = if instruction.store_ra { + self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 + } else { + self.ctx.inst_ctx.c as i64 + }; + let mut addr = instruction.store_offset; + if instruction.store_use_sp { + addr += self.ctx.inst_ctx.sp as i64; + } + addr += self.ctx.inst_ctx.a as i64; + self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, instruction.ind_width); + + if is_aligned == Self::is_8_aligned(addr as u64, instruction.ind_width) { + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + is_write: true, + address: addr as u64, + width: instruction.ind_width, + value: val as u64, + }; + 
emu_mem.push(required_memory); + } + } + _ => panic!( + "Emu::store_c() Invalid store={} pc={}", + instruction.store, self.ctx.inst_ctx.pc + ), + } + } + + #[inline(always)] + fn is_8_aligned(address: u64, width: u64) -> bool { + address & 7 == 0 && width == 8 + } + /// Store the 'c' register value based on the storage specified by the current instruction and /// log memory access if required #[inline(always)] @@ -335,9 +501,9 @@ impl<'a> Emu<'a> { } // Log emulation step, if requested - if options.print_step.is_some() && - (options.print_step.unwrap() != 0) && - ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) + if options.print_step.is_some() + && (options.print_step.unwrap() != 0) + && ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) { println!("step={}", self.ctx.inst_ctx.step); } @@ -449,6 +615,26 @@ impl<'a> Emu<'a> { (emu_traces, emu_segments) } + pub fn par_run_memory( + &mut self, + inputs: Vec, + is_aligned: bool, + ) -> Vec { + // Context, where the state of the execution is stored and modified at every execution step + self.ctx = self.create_emu_context(inputs); + + // Init pc to the rom entry address + self.ctx.trace.start_state.pc = ROM_ENTRY; + + let mut emu_mem = Vec::new(); + + while !self.ctx.inst_ctx.end { + self.par_step_memory::(&mut emu_mem, is_aligned); + } + + emu_mem + } + /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] @@ -520,9 +706,9 @@ impl<'a> Emu<'a> { // Increment step counter self.ctx.inst_ctx.step += 1; - if self.ctx.inst_ctx.end || - ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) == - self.ctx.callback_steps) + if self.ctx.inst_ctx.end + || ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) + == self.ctx.callback_steps) { // In run() we have checked the callback consistency with ctx.do_callback let callback = callback.as_ref().unwrap(); @@ -622,6 +808,45 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.step += 1; } + /// Performs one single step of the 
emulation + #[inline(always)] + #[allow(unused_variables)] + pub fn par_step_memory( + &mut self, + emu_mem: &mut Vec, + is_aligned: bool, + ) { + let last_pc = self.ctx.inst_ctx.pc; + let last_c = self.ctx.inst_ctx.c; + + let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); + + // Build the 'a' register value based on the source specified by the current instruction + self.source_a_memory(instruction, emu_mem, is_aligned); + + // Build the 'b' register value based on the source specified by the current instruction + self.source_b_memory(instruction, emu_mem, is_aligned); + + // Call the operation + (instruction.func)(&mut self.ctx.inst_ctx); + + // Store the 'c' register value based on the storage specified by the current instruction + self.store_c_memory(instruction, emu_mem, is_aligned); + + // Set SP, if specified by the current instruction + // #[cfg(feature = "sp")] + // self.set_sp(instruction); + + // Set PC, based on current PC, current flag and current instruction + self.set_pc(instruction); + + // If this is the last instruction, stop executing + self.ctx.inst_ctx.end = instruction.end; + + // Increment step counter + self.ctx.inst_ctx.step += 1; + } + /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] @@ -699,11 +924,11 @@ impl<'a> Emu<'a> { let mut current_box_id = 0; let mut current_step_idx = loop { - if current_box_id == vec_traces.len() - 1 || - vec_traces[current_box_id + 1].start_state.step >= emu_trace_start.step + if current_box_id == vec_traces.len() - 1 + || vec_traces[current_box_id + 1].start_state.step >= emu_trace_start.step { - break emu_trace_start.step as usize - - vec_traces[current_box_id].start_state.step as usize; + break emu_trace_start.step as usize + - vec_traces[current_box_id].start_state.step as usize; } current_box_id += 1; }; @@ -814,8 +1039,8 @@ impl<'a> Emu<'a> { let b = [inst_ctx.b & 0xFFFFFFFF, (inst_ctx.b >> 32) & 0xFFFFFFFF]; let c = [inst_ctx.c & 0xFFFFFFFF, 
(inst_ctx.c >> 32) & 0xFFFFFFFF]; - let addr1 = (inst.b_offset_imm0 as i64 + - if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; + let addr1 = (inst.b_offset_imm0 as i64 + + if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; let jmp_offset1 = if inst.jmp_offset1 >= 0 { F::from_canonical_u64(inst.jmp_offset1 as u64) @@ -893,8 +1118,8 @@ impl<'a> Emu<'a> { m32: F::from_bool(inst.m32), addr1: F::from_canonical_u64(addr1), __debug_operation_bus_enabled: F::from_bool( - inst.op_type == ZiskOperationType::Binary || - inst.op_type == ZiskOperationType::BinaryE, + inst.op_type == ZiskOperationType::Binary + || inst.op_type == ZiskOperationType::BinaryE, ), } } diff --git a/emulator/src/emulator.rs b/emulator/src/emulator.rs index 6f7078f8..eb862fe0 100644 --- a/emulator/src/emulator.rs +++ b/emulator/src/emulator.rs @@ -11,8 +11,8 @@ use std::{ }; use sysinfo::System; use zisk_core::{ - Riscv2zisk, ZiskOperationType, ZiskPcHistogram, ZiskRequiredOperation, ZiskRom, - ZISK_OPERATION_TYPE_VARIANTS, + Riscv2zisk, ZiskOperationType, ZiskPcHistogram, ZiskRequiredMemory, ZiskRequiredOperation, + ZiskRom, ZISK_OPERATION_TYPE_VARIANTS, }; pub trait Emulator { @@ -243,6 +243,29 @@ impl ZiskEmulator { Ok((vec_traces, emu_slices)) } + pub fn par_process_rom_memory( + rom: &ZiskRom, + inputs: &[u8], + ) -> Result<[Vec; 2], ZiskEmulatorErr> { + let mut result: [Vec; 2] = [Vec::new(), Vec::new()]; + + result.par_iter_mut().enumerate().for_each(|(is_aligned, result)| { + let is_aligned = is_aligned == 0; + let mut emu = Emu::new(rom); + let required = emu.par_run_memory::(inputs.to_owned(), is_aligned); + + if !emu.terminated() { + panic!("Emulation did not complete"); + // TODO! 
+ // return Err(ZiskEmulatorErr::EmulationNoCompleted); + } + + *result = required; + }); + + Ok(result) + } + #[inline] pub fn process_slice_required( rom: &ZiskRom, diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index e16e3936..449b8417 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -2,7 +2,7 @@ // Manual modifications are not recommended and may be overwritten. use proofman_common::WitnessPilout; -pub const PILOUT_HASH: &[u8] = b"ZiskContinuations1-hash"; +pub const PILOUT_HASH: &[u8] = b"Zisk-hash"; //AIRGROUP CONSTANTS @@ -18,8 +18,6 @@ pub const BINARY_EXTENSION_AIRGROUP_ID: usize = 4; pub const BINARY_EXTENSION_TABLE_AIRGROUP_ID: usize = 5; -pub const SPECIFIED_RANGES_AIRGROUP_ID: usize = 6; - //AIR CONSTANTS pub const MAIN_AIR_IDS: &[usize] = &[0]; @@ -32,15 +30,15 @@ pub const BINARY_TABLE_AIR_IDS: &[usize] = &[0]; pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[0]; -pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[0]; +pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[1]; -pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[0]; +pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[0]; pub struct Pilout; impl Pilout { pub fn pilout() -> WitnessPilout { - let mut pilout = WitnessPilout::new("ZiskContinuations1", 2, PILOUT_HASH.to_vec()); + let mut pilout = WitnessPilout::new("Zisk", 2, PILOUT_HASH.to_vec()); let air_group = pilout.add_air_group(Some("Main")); @@ -61,15 +59,12 @@ impl Pilout { let air_group = pilout.add_air_group(Some("BinaryExtension")); air_group.add_air(Some("BinaryExtension"), 2097152); + air_group.add_air(Some("SpecifiedRanges"), 16777216); let air_group = pilout.add_air_group(Some("BinaryExtensionTable")); air_group.add_air(Some("BinaryExtensionTable"), 4194304); - let air_group = pilout.add_air_group(Some("SpecifiedRanges")); - - air_group.add_air(Some("SpecifiedRanges"), 16777216); - pilout } } diff --git a/pil/src/pil_helpers/traces.rs 
b/pil/src/pil_helpers/traces.rs index 67d7db98..5fb91f20 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -23,10 +23,10 @@ trace!(BinaryExtension0Row, BinaryExtension0Trace { op: F, in1: [F; 8], in2_low: F, out: [[F; 2]; 8], op_is_shift: F, in2: [F; 2], main_step: F, multiplicity: F, }); -trace!(BinaryExtensionTable0Row, BinaryExtensionTable0Trace { - multiplicity: F, +trace!(SpecifiedRanges1Row, SpecifiedRanges1Trace { + mul: [F; 1], }); -trace!(SpecifiedRanges0Row, SpecifiedRanges0Trace { - mul: [F; 1], +trace!(BinaryExtensionTable0Row, BinaryExtensionTable0Trace { + multiplicity: F, }); diff --git a/state-machines/binary/src/binary.rs b/state-machines/binary/src/binary.rs index 0acc615d..c8b128f9 100644 --- a/state-machines/binary/src/binary.rs +++ b/state-machines/binary/src/binary.rs @@ -79,13 +79,12 @@ impl BinarySM { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - /* as Provable>::prove( - self, - &[], - true, - scope, - );*/ - //self.threads_controller.wait_for_threads(); + // as Provable>::prove( + // self, + // &[], + // true, + // scope, + // ); self.binary_basic_sm.unregister_predecessor(); self.binary_extension_sm.unregister_predecessor(); diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index e4d406f9..027bee94 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -275,5 +275,5 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope op, store_offset, jmp_offset1, jmp_offset2, rom_flags], sel: 1 - SEGMENT_L1); direct_update(MAIN_CONTINUATION_ID, cols: [0, 0, 4096, 0, 0], bus_type: PIOP_BUS_SUM, proves: 1); - direct_update(MAIN_CONTINUATION_ID, cols: [0, 1, 4312, 0, 0], bus_type: PIOP_BUS_SUM, proves: 0); + direct_update(MAIN_CONTINUATION_ID, cols: [0, 1, 0x10000000, 0, 0], bus_type: PIOP_BUS_SUM, proves: 0); } \ No newline at end of file diff --git 
a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index 0e2fe78f..c7d15619 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -33,9 +33,6 @@ pub struct MainSM { /// Binary state machine binary_sm: Arc>, - - /// Memory state machine - mem_sm: Arc, } impl MainSM { @@ -56,14 +53,12 @@ impl MainSM { wcm: Arc>, arith_sm: Arc, binary_sm: Arc>, - mem_sm: Arc, ) -> Arc { - let main_sm = Arc::new(Self { wcm: wcm.clone(), arith_sm, binary_sm, mem_sm }); + let main_sm = Arc::new(Self { wcm: wcm.clone(), arith_sm, binary_sm }); wcm.register_component(main_sm.clone(), Some(MAIN_AIRGROUP_ID), Some(MAIN_AIR_IDS)); // For all the secondary state machines, register the main state machine as a predecessor - main_sm.mem_sm.register_predecessor(); main_sm.binary_sm.register_predecessor(); main_sm.arith_sm.register_predecessor(); diff --git a/state-machines/mem/Cargo.toml b/state-machines/mem/Cargo.toml index 3f8ee914..39264a00 100644 --- a/state-machines/mem/Cargo.toml +++ b/state-machines/mem/Cargo.toml @@ -11,6 +11,7 @@ zisk-pil = { path = "../../pil" } p3-field = { workspace=true } proofman-common = { workspace = true } proofman-macros = { workspace = true } +proofman-util = { workspace = true } proofman = { workspace = true } log = { workspace = true } rayon = { workspace = true } diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 67bf225c..47dd31fd 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,9 +1,7 @@ -mod mem; -mod mem_aligned; -mod mem_traces; -mod mem_unaligned; +mod mem_align_sm; +mod mem_sm; +mod mem_proxy; -pub use mem::*; -pub use mem_aligned::*; -pub use mem_traces::*; -pub use mem_unaligned::*; +pub use mem_align_sm::*; +pub use mem_sm::*; +pub use mem_proxy::*; diff --git a/state-machines/mem/src/mem.rs b/state-machines/mem/src/mem.rs deleted file mode 100644 index 065b1841..00000000 --- a/state-machines/mem/src/mem.rs +++ /dev/null @@ 
-1,101 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use crate::{MemAlignedSM, MemUnalignedSM}; -use p3_field::Field; -use rayon::Scope; -use sm_common::{MemOp, MemUnalignedOp, OpResult, Provable}; -use zisk_core::ZiskRequiredMemory; - -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; - -#[allow(dead_code)] -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -#[allow(dead_code)] -pub struct MemSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs_aligned: Mutex>, - inputs_unaligned: Mutex>, - - // Secondary State machines - mem_aligned_sm: Arc, - mem_unaligned_sm: Arc, -} - -impl MemSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = MemAlignedSM::new(wcm.clone()); - let mem_unaligned_sm = MemUnalignedSM::new(wcm.clone()); - - let mem_sm = Self { - registered_predecessors: AtomicU32::new(0), - inputs_aligned: Mutex::new(Vec::new()), - inputs_unaligned: Mutex::new(Vec::new()), - mem_aligned_sm: mem_aligned_sm.clone(), - mem_unaligned_sm: mem_unaligned_sm.clone(), - }; - let mem_sm = Arc::new(mem_sm); - - wcm.register_component(mem_sm.clone(), None, None); - - // For all the secondary state machines, register the main state machine as a predecessor - mem_sm.mem_aligned_sm.register_predecessor(); - mem_sm.mem_unaligned_sm.register_predecessor(); - - mem_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - - self.mem_aligned_sm.unregister_predecessor::(scope); - self.mem_unaligned_sm.unregister_predecessor::(scope); - } - } -} - -impl WitnessComponent for MemSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl 
Provable for MemSM { - fn calculate( - &self, - _operation: ZiskRequiredMemory, - ) -> Result> { - unimplemented!() - } - - fn prove(&self, _operations: &[ZiskRequiredMemory], _drain: bool, _scope: &Scope) { - // TODO! - } - - fn calculate_prove( - &self, - _operation: ZiskRequiredMemory, - _drain: bool, - _scope: &Scope, - ) -> Result> { - unimplemented!() - } -} diff --git a/state-machines/mem/src/mem_aligned.rs b/state-machines/mem/src/mem_align_sm.rs similarity index 92% rename from state-machines/mem/src/mem_aligned.rs rename to state-machines/mem/src/mem_align_sm.rs index 1a126e3c..0eeb4c38 100644 --- a/state-machines/mem/src/mem_aligned.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -12,7 +12,7 @@ use zisk_pil::{MEM_AIRGROUP_ID, MEM_ALIGN_AIR_IDS}; const PROVE_CHUNK_SIZE: usize = 1 << 12; -pub struct MemAlignedSM { +pub struct MemAlignSM { // Count of registered predecessors registered_predecessors: AtomicU32, @@ -21,7 +21,7 @@ pub struct MemAlignedSM { } #[allow(unused, unused_variables)] -impl MemAlignedSM { +impl MemAlignSM { pub fn new(wcm: Arc>) -> Arc { let mem_aligned_sm = Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; @@ -42,7 +42,7 @@ impl MemAlignedSM { pub fn unregister_predecessor(&self, scope: &Scope) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); + >::prove(self, &[], true, scope); } } @@ -62,7 +62,7 @@ impl MemAlignedSM { } } -impl WitnessComponent for MemAlignedSM { +impl WitnessComponent for MemAlignSM { fn calculate_witness( &self, _stage: u32, @@ -74,7 +74,7 @@ impl WitnessComponent for MemAlignedSM { } } -impl Provable for MemAlignedSM { +impl Provable for MemAlignSM { fn calculate(&self, operation: MemOp) -> Result> { match operation { MemOp::Read(addr) => self.read(addr), diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs new file mode 100644 index 00000000..c7a4bce9 --- /dev/null +++ 
b/state-machines/mem/src/mem_proxy.rs @@ -0,0 +1,115 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use crate::{MemAlignSM, MemSM}; +use p3_field::{Field, PrimeField}; +use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; +use sm_common::{MemOp, MemUnalignedOp}; +use zisk_core::ZiskRequiredMemory; + +use proofman::{WitnessComponent, WitnessManager}; + +#[allow(dead_code)] +const PROVE_CHUNK_SIZE: usize = 1 << 12; + +#[allow(dead_code)] +pub struct MemProxy { + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Inputs + inputs_aligned: Mutex>, + inputs_unaligned: Mutex>, + + // Secondary State machines + mem_sm: Arc>, + mem_align_sm: Arc, +} + +impl MemProxy { + pub fn new(wcm: Arc>) -> Arc { + let mem_sm = MemSM::new(wcm.clone()); + let mem_align_sm = MemAlignSM::new(wcm.clone()); + + let mem_proxy = Self { + registered_predecessors: AtomicU32::new(0), + inputs_aligned: Mutex::new(Vec::new()), + inputs_unaligned: Mutex::new(Vec::new()), + mem_sm: mem_sm.clone(), + mem_align_sm: mem_align_sm.clone(), + }; + let mem_proxy = Arc::new(mem_proxy); + + wcm.register_component(mem_proxy.clone(), None, None); + + // For all the secondary state machines, register the main state machine as a predecessor + mem_sm.register_predecessor(); + mem_align_sm.register_predecessor(); + + mem_proxy + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + // self.mem_sm.unregister_predecessor(); + // self.mem_align_sm.unregister_predecessor::(); + } + } + + pub fn prove( + &self, + mut operations: [Vec; 2], + ) -> Result<(), Box> { + let mut aligned = std::mem::take(&mut operations[0]); + let non_aligned = std::mem::take(&mut operations[1]); + let new_aligned = Vec::new(); + + // Step 1. 
Sort the aligned memory accesses + timer_start_debug!(MEM_SORT); + aligned.sort_by_key(|mem| mem.address); + timer_stop_and_log_debug!(MEM_SORT); + + // Step 2. For each non-aligned memory access + non_aligned.iter().for_each(|mem| { + // Step 2.1 Find the possible aligned memory access + let potential_aligned_mem = self.get_potential_aligned_mem(&aligned, &mem); + + // Step 2.2 Align memory access using mem_align state machine + // self.mem_aligned_sm.align_mem_accesses(potential_aligned_mem, mem, &mut new_aligned); + + // Step 2.3 Store the new aligned memory access(es) + }); + + // Step 3. Concatenate the new aligned memory accesses with the original aligned memory accesses + aligned.extend(new_aligned); + + // Step 4. Sort the (full) aligned memory accesses + timer_start_debug!(MEM_SORT_2); + aligned.sort_by_key(|mem| mem.address); + timer_stop_and_log_debug!(MEM_SORT_2); + + // Step 5. Prove the aligned memory accesses using mem state machine + + println!("Proving MemSM"); + println!("Aligned: {:?}", operations[0].len()); + println!("Non aligned: {:?}", operations[1].len()); + Ok(()) + } + + fn get_potential_aligned_mem( + &self, + aligned_accesses: &[ZiskRequiredMemory], + unaligned_access: &ZiskRequiredMemory, + ) -> Vec { + let mut aligned_mem = Vec::new(); + aligned_mem + } +} + +impl WitnessComponent for MemProxy {} diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs new file mode 100644 index 00000000..b0febe2f --- /dev/null +++ b/state-machines/mem/src/mem_sm.rs @@ -0,0 +1,284 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use p3_field::PrimeField; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx}; +use rayon::Scope; +use sm_common::{MemOp, OpResult, Provable}; +use zisk_core::ZiskRequiredMemory; +// use zisk_pil::{Mem0Trace, MEM_AIRGROUP_ID, MEM_AIR_IDS}; + +const PROVE_CHUNK_SIZE: usize = 1 << 12; + +pub struct MemSM { + // 
Witness computation manager + wcm: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Inputs + inputs: Mutex>, + + _phantom: std::marker::PhantomData, +} + +#[allow(unused, unused_variables)] +impl MemSM { + pub fn new(wcm: Arc>) -> Arc { + let mem_sm = Self { + wcm: wcm.clone(), + registered_predecessors: AtomicU32::new(0), + inputs: Mutex::new(Vec::new()), + _phantom: std::marker::PhantomData, + }; + let mem_sm = Arc::new(mem_sm); + + // wcm.register_component(mem_sm.clone(), Some(MEM_AIRGROUP_ID), Some(MEM_AIR_IDS)); + + mem_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + // as Provable>::prove(self, &[], true, scope); + } + } + + /// Finalizes the witness accumulation process and triggers the proof generation. + /// + /// This method is invoked by the executor when no further witness data remains to be added. + /// + /// # Parameters + /// + /// - `mem_inputs`: A slice of all `ZiskRequiredMemory` inputs + pub fn prove_instance( + &self, + mem_ops: &[ZiskRequiredMemory], + mem_first_row: ZiskRequiredMemory, + segment_id: usize, + is_last_segment: bool, + mut prover_buffer: Vec, + offset: u64, + pctx: Arc>, + ectx: Arc, + sctx: Arc, + ) -> Result<(), Box> { + // STEP2: Process the memory inputs and convert them to AIR instances + // let air = pctx.pilout.get_air(MEM_AIRGROUP_ID, MEM_AIR_IDS[0]); + + // let max_rows_per_segment = air.num_rows() - 1; + + // assert!(mem_ops.len() > 0 && mem_ops.len() <= max_rows_per_segment); + + // // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR segments + // // In a Memory AIR instance, the first row is reserved as a dummy row. + // // This dummy row is used to facilitate the continuation state between different AIR segments. 
+ // // It ensures seamless transitions when multiple AIR segments are processed consecutively. + // // This design avoids discontinuities in memory access patterns and ensures that the memory trace is continuous, + // // For this reason we use AIR num_rows - 1 as the number of rows in each memory AIR instance + + // // Create a vector of Mem0Row instances, one for each memory operation + // // Recall that first row is a dummy row used for the continuations between AIR segments + // // The length of the vector is the number of input memory operations plus one because + // // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + // let mut trace = + // Mem0Trace::::map_buffer(&mut prover_buffer, air.num_rows(), offset as usize) + // .unwrap(); + + // let segment_id_field = F::from_canonical_u64(segment_id as u64); + // let is_last_segment_field = F::from_bool(is_last_segment); + + // // STEP1. Add the first row to the output vector as equal to the last row of the previous segment + // // CASE: last row of segment is read + // // + // // S[n-1] wr = 0, sel = 1, addr, step, value + // // S+1[0] wr = 0, sel = 0, addr, step, value + // // + // // CASE: last row of segment is write + // // + // // S[n-1] wr = 1, sel = 1, addr, step, value + // // S+1[0] wr = 0, sel = 0, addr, step, value + + // trace[0].mem_segment = segment_id_field; + // trace[0].mem_last_segment = is_last_segment_field; + + // trace[0].addr = F::from_canonical_u64(mem_first_row.address); + // trace[0].step = F::from_canonical_u64(mem_first_row.step); + // trace[0].sel = F::zero(); + // trace[0].wr = F::zero(); + + // let value = match mem_first_row.width { + // 1 => mem_first_row.value as u8 as u64, + // 2 => mem_first_row.value as u16 as u64, + // 4 => mem_first_row.value as u32 as u64, + // 8 => mem_first_row.value, + // _ => panic!("Invalid width"), + // }; + // let (low_val, high_val) = self.get_u32_values(value); + // trace[0].value = 
[F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + // trace[0].addr_changes = F::zero(); + + // trace[0].same_value = F::zero(); + // trace[0].first_addr_access_is_read = F::zero(); + + // // STEP2. Add all the memory operations to the buffer + // for (idx, mem_op) in mem_ops.iter().enumerate() { + // let i = idx + 1; + // trace[i].mem_segment = segment_id_field; + // trace[i].mem_last_segment = is_last_segment_field; + + // trace[i].addr = F::from_canonical_u64(mem_op.address); // n-byte address, real address = addr * MEM_BYTES + // trace[i].step = F::from_canonical_u64(mem_op.step); + // trace[i].sel = F::one(); + // trace[i].wr = F::from_bool(mem_op.is_write); + + // let value = match mem_op.width { + // 1 => mem_op.value as u8 as u64, + // 2 => mem_op.value as u16 as u64, + // 4 => mem_op.value as u32 as u64, + // 8 => mem_op.value, + // _ => panic!("Invalid width"), + // }; + // let (low_val, high_val) = self.get_u32_values(value); + // trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + // if i == 66587 || i == 66586 { + // println!( + // "mem_op.value: {:?} value: {:?} width: {}", + // mem_op.value, trace[i].value, mem_op.width + // ); + // println!("mem_op: {:?}", mem_op); + // } + // let addr_changes = trace[i - 1].addr != trace[i].addr; + // trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; + + // let same_value = trace[i - 1].value[0] == trace[i].value[0] + // && trace[i - 1].value[1] == trace[i].value[1]; + // trace[i].same_value = if same_value { F::one() } else { F::zero() }; + + // let first_addr_access_is_read = addr_changes && !mem_op.is_write; + // trace[i].first_addr_access_is_read = + // if first_addr_access_is_read { F::one() } else { F::zero() }; + + // if i == 66587 || i == 66586 { + // println!("trace[{}]: {:?}", i, trace[i]); + // } + // } + + // // STEP3. 
Add dummy rows to the output vector to fill the remaining rows + // //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 + // let last_row_idx = mem_ops.len(); + // let addr = trace[last_row_idx].addr; + // let mut step = trace[last_row_idx].step; + // let value = trace[last_row_idx].value; + + // for i in (mem_ops.len() + 1)..air.num_rows() { + // step += F::one(); + + // trace[i].mem_segment = segment_id_field; + // trace[i].mem_last_segment = is_last_segment_field; + + // trace[i].addr = addr; + // trace[i].step = step; + // trace[i].sel = F::zero(); + // trace[i].wr = F::zero(); + + // trace[i].value = value; + + // trace[i].addr_changes = F::zero(); + // trace[i].same_value = F::one(); + // trace[i].first_addr_access_is_read = F::zero(); + // } + + // let air_instance = AirInstance::new( + // self.wcm.get_sctx(), + // MEM_AIRGROUP_ID, + // MEM_AIR_IDS[0], + // Some(segment_id), + // prover_buffer, + // ); + + // pctx.air_instance_repo.add_air_instance(air_instance); + + Ok(()) + } + + fn get_u32_values(&self, value: u64) -> (u32, u32) { + (value as u32, (value >> 32) as u32) + } +} + +impl WitnessComponent for MemSM {} + +impl Provable for MemSM { + fn prove(&self, operations: &[MemOp], drain: bool, scope: &Scope) { + if let Ok(mut inputs) = self.inputs.lock() { + inputs.extend_from_slice(operations); + + while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { + let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); + let _drained_inputs = inputs.drain(..num_drained).collect::>(); + + scope.spawn(move |_| { + // TODO! 
Implement prove drained_inputs (a chunk of operations) + }); + } + } + } +} + +#[cfg(test)] +mod tests { + // use super::*; + // use p3_field::AbstractField; + // use p3_goldilocks::Goldilocks; + // use zisk_core::ZiskRequiredMemory; + + // type GL = Goldilocks; + + // #[test] + // fn test_calculate_witness_rows() { + // let mem_ops = vec![ + // ZiskRequiredMemory::new(0, true, 0, 1, 0), + // ZiskRequiredMemory::new(1, false, 1, 1, 0), + // ZiskRequiredMemory::new(2, true, 2, 1, 0), + // ZiskRequiredMemory::new(3, false, 3, 1, 0), + // ZiskRequiredMemory::new(4, true, 4, 1, 0), + // ZiskRequiredMemory::new(5, false, 5, 1, 0), + // ZiskRequiredMemory::new(6, true, 6, 1, 0), + // ZiskRequiredMemory::new(7, false, 7, 1, 0), + // ZiskRequiredMemory::new(8, true, 8, 1, 0), + // ZiskRequiredMemory::new(9, false, 9, 1, 0), + // ]; + + // let witness_rows = MemWitness::calculate_witness_rows::(mem_ops, 10, 0, true); + + // assert_eq!(witness_rows.len(), 10); + + // // Check the dummy row + // assert_eq!(witness_rows[0].mem_segment, GL::from_canonical_u64(0)); + // assert_eq!(witness_rows[0].mem_last_segment, GL::from_bool(true)); + // assert_eq!(witness_rows[0].addr, GL::default()); + // assert_eq!(witness_rows[0].step, GL::default()); + // assert_eq!(witness_rows[0].sel, GL::default()); + // assert_eq!(witness_rows[0].wr, GL::default()); + // assert_eq!(witness_rows[0].value, [GL::default(), GL::default()]); + // assert_eq!(witness_rows[0].addr_changes, GL::default()); + // assert_eq!(witness_rows[0].same_value, GL::default()); + // assert_eq!(witness_rows[0].first_addr_access_is_read, GL::default()); + + // // Check the remaining rows + // for i in 1..10 { + // assert_eq!(witness_rows[i].mem_segment, GL::from_canonical_u64(0)); + // // ... 
+ // } + // } +} diff --git a/state-machines/mem/src/mem_traces.rs b/state-machines/mem/src/mem_traces.rs deleted file mode 100644 index e4830fc6..00000000 --- a/state-machines/mem/src/mem_traces.rs +++ /dev/null @@ -1,5 +0,0 @@ -use proofman_common as common; -pub use proofman_macros::trace; - -trace!(MemALigned0Row, MemALigned0Trace { fake: F }); -trace!(MemUnaLigned0Row, MemUnaLigned0Trace { fake: F}); diff --git a/state-machines/mem/src/mem_unaligned.rs b/state-machines/mem/src/mem_unaligned.rs deleted file mode 100644 index fde238e3..00000000 --- a/state-machines/mem/src/mem_unaligned.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{MemUnalignedOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct MemUnalignedSM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, -} - -#[allow(unused, unused_variables)] -impl MemUnalignedSM { - pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let mem_aligned_sm = Arc::new(mem_aligned_sm); - - wcm.register_component( - mem_aligned_sm.clone(), - Some(MEM_AIRGROUP_ID), - Some(MEM_UNALIGNED_AIR_IDS), - ); - - mem_aligned_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); - } - } - - fn read( - &self, - _addr: u64, - _width: usize, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } - - fn write( - &self, - _addr: u64, 
- _width: usize, - _val: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } -} - -impl WitnessComponent for MemUnalignedSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for MemUnalignedSM { - fn calculate(&self, operation: MemUnalignedOp) -> Result> { - match operation { - MemUnalignedOp::Read(addr, width) => self.read(addr, width), - MemUnalignedOp::Write(addr, width, val) => self.write(addr, width, val), - } - } - - fn prove(&self, operations: &[MemUnalignedOp], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: MemUnalignedOp, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/rom/src/rom.rs b/state-machines/rom/src/rom.rs index ac135e83..0700bef8 100644 --- a/state-machines/rom/src/rom.rs +++ b/state-machines/rom/src/rom.rs @@ -31,30 +31,25 @@ impl RomSM { &self, rom: &ZiskRom, pc_histogram: ZiskPcHistogram, + instance_gid: usize, ) -> Result<(), Box> { - let buffer_allocator = self.wcm.get_ectx().buffer_allocator.clone(); - let sctx = self.wcm.get_sctx(); - if pc_histogram.end_pc == 0 { panic!("RomSM::prove() detected pc_histogram.end_pc == 0"); // TODO: return an error } - let main_trace_len = - self.wcm.get_pctx().pilout.get_air(MAIN_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows() as u64; + let buffer_allocator = self.wcm.get_ectx().buffer_allocator.clone(); + let 
sctx = self.wcm.get_sctx(); - let (prover_buffer, _, air_id) = - Self::compute_trace_rom(rom, buffer_allocator, &sctx, pc_histogram, main_trace_len)?; + let num_rows = + self.wcm.get_pctx().pilout.get_air(MAIN_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows(); + + let prover_buffer = + Self::compute_trace_rom(rom, buffer_allocator, &sctx, pc_histogram, num_rows as u64)?; let air_instance = - AirInstance::new(sctx.clone(), ROM_AIRGROUP_ID, air_id, None, prover_buffer); - let (is_mine, instance_gid) = - self.wcm.get_ectx().dctx.write().unwrap().add_instance(ROM_AIRGROUP_ID, air_id, 1); - if is_mine { - self.wcm - .get_pctx() - .air_instance_repo - .add_air_instance(air_instance, Some(instance_gid)); - } + AirInstance::new(sctx.clone(), ROM_AIRGROUP_ID, MAIN_AIR_IDS[0], None, prover_buffer); + + self.wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, Some(instance_gid)); Ok(()) } @@ -62,7 +57,7 @@ impl RomSM { rom_path: PathBuf, buffer_allocator: Arc, sctx: &SetupCtx, - ) -> Result<(Vec, u64, usize), Box> { + ) -> Result, Box> { // Get the ELF file path as a string let elf_filename: String = rom_path.to_str().unwrap().into(); println!("Proving ROM for ELF file={}", elf_filename); @@ -91,69 +86,12 @@ impl RomSM { sctx: &SetupCtx, pc_histogram: ZiskPcHistogram, main_trace_len: u64, - ) -> Result<(Vec, u64, usize), Box> { + ) -> Result, Box> { let pilout = Pilout::pilout(); - let sizes = ( - pilout.get_air(ROM_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(), - // pilout.get_air(ROM_AIRGROUP_ID, ROM_M_AIR_IDS[0]).num_rows(), - // pilout.get_air(ROM_AIRGROUP_ID, ROM_L_AIR_IDS[0]).num_rows(), - ); + let num_rows = pilout.get_air(ROM_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(); let number_of_instructions = rom.insts.len(); - Self::create_rom_s( - sizes.0, - rom, - number_of_instructions, - buffer_allocator, - sctx, - pc_histogram, - main_trace_len, - ) - // match number_of_instructions { - // n if n <= sizes.0 => Self::create_rom_s( - // sizes.0, - // rom, - // n, - // 
buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // n if n <= sizes.1 => Self::create_rom_m( - // sizes.1, - // rom, - // n, - // buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // n if n < sizes.2 => Self::create_rom_l( - // sizes.2, - // rom, - // n, - // buffer_allocator, - // sctx, - // pc_histogram, - // main_trace_len, - // ), - // _ => panic!("RomSM::compute_trace() found rom too big size={}", - // number_of_instructions), } - } - - fn create_rom_s( - rom_s_size: usize, - rom: &zisk_core::ZiskRom, - number_of_instructions: usize, - buffer_allocator: Arc, - sctx: &SetupCtx, - pc_histogram: ZiskPcHistogram, - main_trace_len: u64, - ) -> Result<(Vec, u64, usize), Box> { - // Set trace size - let trace_size = rom_s_size; - // Allocate a prover buffer let (buffer_size, offsets) = buffer_allocator .get_buffer_info(sctx, ROM_AIRGROUP_ID, ROM_AIR_IDS[0]) @@ -162,7 +100,7 @@ impl RomSM { // Create an empty ROM trace let mut rom_trace = - Rom0Trace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) + Rom0Trace::::map_buffer(&mut prover_buffer, num_rows, offsets[0] as usize) .expect("RomSM::compute_trace() failed mapping buffer to ROMS0Trace"); // For every instruction in the rom, fill its corresponding ROM trace @@ -232,229 +170,15 @@ impl RomSM { rom_trace[i].jmp_offset2 = jmp_offset2; rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - /*println!( - "ROM SM [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], {}", - inst.paddr, - inst.a_offset_imm0, - if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 }, - inst.b_offset_imm0, - if inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 }, - if inst.b_src == SRC_IND { 1 } else { 0 }, - inst.ind_width, - inst.op, - inst.store_offset as u64, - inst.jmp_offset1 as u64, - inst.jmp_offset2 as u64, - inst.get_flags(), - multiplicity, - );*/ } // Padd with zeroes - for i 
in number_of_instructions..trace_size { + for i in number_of_instructions..num_rows { rom_trace[i] = Rom0Row::default(); } - Ok((prover_buffer, offsets[0], ROM_AIR_IDS[0])) + Ok(prover_buffer) } - - // fn create_rom_m( - // rom_m_size: usize, - // rom: &zisk_core::ZiskRom, - // number_of_instructions: usize, - // buffer_allocator: Arc, - // sctx: &SetupCtx, - // pc_histogram: ZiskPcHistogram, - // main_trace_len: u64, - // ) -> Result<(Vec, u64, usize), Box> { - // // Set trace size - // let trace_size = rom_m_size; - - // // Allocate a prover buffer - // let (buffer_size, offsets) = buffer_allocator - // .get_buffer_info(sctx, ROM_AIRGROUP_ID, ROM_M_AIR_IDS[0]) - // .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); - // let mut prover_buffer = create_buffer_fast(buffer_size as usize); - - // // Create an empty ROM trace - // let mut rom_trace = - // RomM1Trace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) - // .expect("RomSM::compute_trace() failed mapping buffer to ROMM0Trace"); - - // // For every instruction in the rom, fill its corresponding ROM trace - // for (i, inst_builder) in rom.insts.clone().into_iter().enumerate() { - // // Get the Zisk instruction - // let inst = inst_builder.1.i; - - // // Calculate the multiplicity, i.e. 
the number of times this pc is used in this - // // execution - // let mut multiplicity: u64; - // if pc_histogram.map.is_empty() { - // multiplicity = 1; // If the histogram is empty, we use 1 for all pc's - // } else { - // let counter = pc_histogram.map.get(&inst.paddr); - // if counter.is_some() { - // multiplicity = *counter.unwrap(); - // if inst.paddr == pc_histogram.end_pc { - // multiplicity += main_trace_len - 1 - (pc_histogram.steps % - // main_trace_len); } - // } else { - // continue; // We skip those pc's that are not used in this execution - // } - // } - - // // Convert the i64 offsets to F - // let jmp_offset1 = if inst.jmp_offset1 >= 0 { - // F::from_canonical_u64(inst.jmp_offset1 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset1) as u64)) - // }; - // let jmp_offset2 = if inst.jmp_offset2 >= 0 { - // F::from_canonical_u64(inst.jmp_offset2 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset2) as u64)) - // }; - // let store_offset = if inst.store_offset >= 0 { - // F::from_canonical_u64(inst.store_offset as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.store_offset) as u64)) - // }; - // let a_offset_imm0 = if inst.a_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.a_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.a_offset_imm0 as i64)) as u64)) - // }; - // let b_offset_imm0 = if inst.b_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.b_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.b_offset_imm0 as i64)) as u64)) - // }; - - // // Fill the rom trace row fields - // rom_trace[i].line = F::from_canonical_u64(inst.paddr); // TODO: unify names: pc, - // paddr, line rom_trace[i].a_offset_imm0 = a_offset_imm0; - // rom_trace[i].a_imm1 = - // F::from_canonical_u64(if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 - // }); rom_trace[i].b_offset_imm0 = b_offset_imm0; - // rom_trace[i].b_imm1 = - // F::from_canonical_u64(if 
inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 - // }); //rom_trace[i].b_src_ind = - // // F::from_canonical_u64(if inst.b_src == SRC_IND { 1 } else { 0 }); - // rom_trace[i].ind_width = F::from_canonical_u64(inst.ind_width); - // rom_trace[i].op = F::from_canonical_u8(inst.op); - // rom_trace[i].store_offset = store_offset; - // rom_trace[i].jmp_offset1 = jmp_offset1; - // rom_trace[i].jmp_offset2 = jmp_offset2; - // rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); - // rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - // } - - // // Padd with zeroes - // for i in number_of_instructions..trace_size { - // rom_trace[i] = RomM1Row::default(); - // } - - // Ok((prover_buffer, offsets[0], ROM_M_AIR_IDS[0])) - // } - - // fn create_rom_l( - // rom_l_size: usize, - // rom: &zisk_core::ZiskRom, - // number_of_instructions: usize, - // buffer_allocator: Arc, - // sctx: &SetupCtx, - // pc_histogram: ZiskPcHistogram, - // main_trace_len: u64, - // ) -> Result<(Vec, u64, usize), Box> { - // // Set trace size - // let trace_size = rom_l_size; - - // // Allocate a prover buffer - // let (buffer_size, offsets) = buffer_allocator - // .get_buffer_info(sctx, ROM_AIRGROUP_ID, ROM_L_AIR_IDS[0]) - // .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); - // let mut prover_buffer = create_buffer_fast(buffer_size as usize); - - // // Create an empty ROM trace - // let mut rom_trace = - // RomL2Trace::::map_buffer(&mut prover_buffer, trace_size, offsets[0] as usize) - // .expect("RomSM::compute_trace() failed mapping buffer to ROML0Trace"); - - // // For every instruction in the rom, fill its corresponding ROM trace - // for (i, inst_builder) in rom.insts.clone().into_iter().enumerate() { - // // Get the Zisk instruction - // let inst = inst_builder.1.i; - - // // Calculate the multiplicity, i.e. 
the number of times this pc is used in this - // // execution - // let mut multiplicity: u64; - // if pc_histogram.map.is_empty() { - // multiplicity = 1; // If the histogram is empty, we use 1 for all pc's - // } else { - // let counter = pc_histogram.map.get(&inst.paddr); - // if counter.is_some() { - // multiplicity = *counter.unwrap(); - // if inst.paddr == pc_histogram.end_pc { - // multiplicity += main_trace_len - 1 - (pc_histogram.steps % - // main_trace_len); } - // } else { - // continue; // We skip those pc's that are not used in this execution - // } - // } - - // // Convert the i64 offsets to F - // let jmp_offset1 = if inst.jmp_offset1 >= 0 { - // F::from_canonical_u64(inst.jmp_offset1 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset1) as u64)) - // }; - // let jmp_offset2 = if inst.jmp_offset2 >= 0 { - // F::from_canonical_u64(inst.jmp_offset2 as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.jmp_offset2) as u64)) - // }; - // let store_offset = if inst.store_offset >= 0 { - // F::from_canonical_u64(inst.store_offset as u64) - // } else { - // F::neg(F::from_canonical_u64((-inst.store_offset) as u64)) - // }; - // let a_offset_imm0 = if inst.a_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.a_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.a_offset_imm0 as i64)) as u64)) - // }; - // let b_offset_imm0 = if inst.b_offset_imm0 as i64 >= 0 { - // F::from_canonical_u64(inst.b_offset_imm0) - // } else { - // F::neg(F::from_canonical_u64((-(inst.b_offset_imm0 as i64)) as u64)) - // }; - - // // Fill the rom trace row fields - // rom_trace[i].line = F::from_canonical_u64(inst.paddr); // TODO: unify names: pc, - // paddr, line rom_trace[i].a_offset_imm0 = a_offset_imm0; - // rom_trace[i].a_imm1 = - // F::from_canonical_u64(if inst.a_src == SRC_IMM { inst.a_use_sp_imm1 } else { 0 - // }); rom_trace[i].b_offset_imm0 = b_offset_imm0; - // rom_trace[i].b_imm1 = - // F::from_canonical_u64(if 
inst.b_src == SRC_IMM { inst.b_use_sp_imm1 } else { 0 - // }); //rom_trace[i].b_src_ind = - // // F::from_canonical_u64(if inst.b_src == SRC_IND { 1 } else { 0 }); - // rom_trace[i].ind_width = F::from_canonical_u64(inst.ind_width); - // rom_trace[i].op = F::from_canonical_u8(inst.op); - // rom_trace[i].store_offset = store_offset; - // rom_trace[i].jmp_offset1 = jmp_offset1; - // rom_trace[i].jmp_offset2 = jmp_offset2; - // rom_trace[i].flags = F::from_canonical_u64(inst.get_flags()); - // rom_trace[i].multiplicity = F::from_canonical_u64(multiplicity); - // } - - // // Padd with zeroes - // for i in number_of_instructions..trace_size { - // rom_trace[i] = RomL2Row::default(); - // } - - // Ok((prover_buffer, offsets[0], ROM_L_AIR_IDS[0])) - // } } impl WitnessComponent for RomSM {} diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 2398fbd5..1736760d 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -10,17 +10,18 @@ use sm_arith::ArithSM; use sm_binary::BinarySM; use sm_common::create_prover_buffer; use sm_main::{InstanceExtensionCtx, MainSM}; -use sm_mem::MemSM; +use sm_mem::MemProxy; use sm_rom::RomSM; use std::{ fs, path::{Path, PathBuf}, sync::Arc, + thread, }; use zisk_core::{Riscv2zisk, ZiskOperationType, ZiskRom, ZISK_OPERATION_TYPE_VARIANTS}; use zisk_pil::{ BINARY_AIRGROUP_ID, BINARY_AIR_IDS, BINARY_EXTENSION_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS, - MAIN_AIRGROUP_ID, MAIN_AIR_IDS, + MAIN_AIRGROUP_ID, MAIN_AIR_IDS, ROM_AIRGROUP_ID, ROM_AIR_IDS, }; use ziskemu::{EmuOptions, ZiskEmulator}; @@ -35,7 +36,7 @@ pub struct ZiskExecutor { pub rom_sm: Arc>, /// Memory State Machine - pub mem_sm: Arc, + pub mem_proxy: Arc>, /// Binary State Machine pub binary_sm: Arc>, @@ -51,7 +52,7 @@ impl ZiskExecutor { let std = Std::new(wcm.clone()); let rom_sm = RomSM::new(wcm.clone()); - let mem_sm = MemSM::new(wcm.clone()); + let mem_proxy = MemProxy::new(wcm.clone()); let binary_sm = 
BinarySM::new(wcm.clone(), std.clone()); let arith_sm = ArithSM::new(wcm.clone()); @@ -76,6 +77,7 @@ impl ZiskExecutor { // TODO - Remove this when the ZisK ROM is able to be loaded from a file panic!("ROM file must be an ELF file"); }; + let zisk_rom = Arc::new(zisk_rom); let zisk_rom = Arc::new(zisk_rom); @@ -83,9 +85,9 @@ impl ZiskExecutor { // TODO - If there is more than one Main AIR available, the MAX_ACCUMULATED will be the one // with the highest num_rows. It has to be a power of 2. - let main_sm = MainSM::new(wcm.clone(), arith_sm.clone(), binary_sm.clone(), mem_sm.clone()); + let main_sm = MainSM::new(wcm.clone(), arith_sm.clone(), binary_sm.clone()); - Self { zisk_rom, main_sm, rom_sm, mem_sm, binary_sm, arith_sm } + Self { zisk_rom, main_sm, rom_sm, mem_proxy, binary_sm, arith_sm } } /// Executes the MainSM state machine and processes the inputs in batches when the maximum @@ -120,6 +122,7 @@ impl ZiskExecutor { let path = PathBuf::from(public_inputs_path.display().to_string()); fs::read(path).expect("Could not read inputs file") }; + let public_inputs = Arc::new(public_inputs); // During ROM processing, we gather execution data necessary for creating the AIR instances. // This data is collected by the emulator and includes the minimal execution trace, @@ -138,17 +141,36 @@ impl ZiskExecutor { op_sizes[ZiskOperationType::Binary as usize] = air_binary.num_rows() as u64; op_sizes[ZiskOperationType::BinaryE as usize] = air_binary_e.num_rows() as u64; + // STEP 1. 
Generate all inputs + // ============================================== + + // Memory State Machine + // ---------------------------------------------- + let mem_thread = thread::spawn({ + let zisk_rom = self.zisk_rom.clone(); + let public_inputs = public_inputs.clone(); + move || { + ZiskEmulator::par_process_rom_memory::(&zisk_rom, &public_inputs) + .expect("Failed in ZiskEmulator::par_process_rom_memory") + } + }); + // ROM State Machine // ---------------------------------------------- // Run the ROM to compute the ROM witness - let rom_sm = self.rom_sm.clone(); - let zisk_rom = self.zisk_rom.clone(); - let pc_histogram = - ZiskEmulator::process_rom_pc_histogram(&self.zisk_rom, &public_inputs, &emu_options) - .expect( - "MainSM::execute() failed calling ZiskEmulator::process_rom_pc_histogram()", - ); - let handle_rom = std::thread::spawn(move || rom_sm.prove(&zisk_rom, pc_histogram)); + let rom_thread = thread::spawn({ + let zisk_rom = self.zisk_rom.clone(); + let public_inputs = public_inputs.clone(); + let emu_options_cloned = emu_options.clone(); + move || { + ZiskEmulator::process_rom_pc_histogram( + &zisk_rom, + &public_inputs, + &emu_options_cloned, + ) + .expect("MainSM::execute() failed calling ZiskEmulator::process_rom_pc_histogram()") + } + }); // Main, Binary and Arith State Machines // ---------------------------------------------- @@ -165,10 +187,39 @@ impl ZiskExecutor { .expect("Error during emulator execution"); timer_stop_and_log_debug!(PAR_PROCESS_ROM); - emu_slices.points.sort_by(|a, b| a.op_type.partial_cmp(&b.op_type).unwrap()); + // STEP 2. Wait until all inputs are generated + // ============================================== + // Join all the threads to synchronize the execution + let mem_required = mem_thread.join().expect("Error during Memory witness computation"); + let rom_required = rom_thread.join().expect("Error during ROM witness computation"); + + // STEP 3. 
Generate AIRs and Prove + // ============================================== + + // Memory State Machine + // ---------------------------------------------- + let mem_thread = thread::spawn({ + let mem_proxy = self.mem_proxy.clone(); + move || mem_proxy.prove(mem_required).expect("Error during Memory witness computation") + }); - // Join threads to synchronize the execution - handle_rom.join().unwrap().expect("Error during ROM witness computation"); + // ROM State Machine + // ---------------------------------------------- + let (rom_is_mine, rom_instance_gid) = + ectx.dctx.write().unwrap().add_instance(ROM_AIRGROUP_ID, ROM_AIR_IDS[0], 1); + + let rom_thread = if rom_is_mine { + let rom_sm = self.rom_sm.clone(); + let zisk_rom = self.zisk_rom.clone(); + + Some(thread::spawn(move || rom_sm.prove(&zisk_rom, rom_required, rom_instance_gid))) + } else { + None + }; + + // Main, Binary and Arith State Machines + // ---------------------------------------------- + emu_slices.points.sort_by(|a, b| a.op_type.partial_cmp(&b.op_type).unwrap()); // FIXME: Move InstanceExtensionCtx form main SM to another place let mut instances_extension_ctx: Vec> = @@ -234,7 +285,13 @@ impl ZiskExecutor { } timer_stop_and_log_debug!(ADD_INSTANCES_TO_THE_REPO); - // self.mem_sm.unregister_predecessor(scope); + mem_thread.join().expect("Error during Memory witness computation"); + + if let Some(thread) = rom_thread { + let _ = thread.join().expect("Error during ROM witness computation"); + } + + self.mem_proxy.unregister_predecessor(); self.binary_sm.unregister_predecessor(); // self.arith_sm.register_predecessor(scope); } From 36573d9273c992c140210e88cb95367d556670ce Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Tue, 5 Nov 2024 05:27:31 +0000 Subject: [PATCH 02/44] fix Arc --- witness-computation/src/executor.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 
1736760d..f06d0c87 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -77,7 +77,6 @@ impl ZiskExecutor { // TODO - Remove this when the ZisK ROM is able to be loaded from a file panic!("ROM file must be an ELF file"); }; - let zisk_rom = Arc::new(zisk_rom); let zisk_rom = Arc::new(zisk_rom); From 9d4be190b5415fa668bd564a13f4812dc2bfb7c0 Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Tue, 5 Nov 2024 07:38:45 +0000 Subject: [PATCH 03/44] wip --- pil/src/pil_helpers/pilout.rs | 57 ++- pil/src/pil_helpers/traces.rs | 24 +- pil/zisk.pil | 33 +- state-machines/binary/src/binary.rs | 13 +- state-machines/binary/src/binary_basic.rs | 12 +- .../binary/src/binary_basic_table.rs | 10 +- state-machines/binary/src/binary_extension.rs | 40 +- .../binary/src/binary_extension_table.rs | 10 +- state-machines/mem/pil/mem.pil | 46 ++- state-machines/mem/pil/mem_align.pil | 173 +++++++++ state-machines/mem/pil/mem_align_rom.pil | 299 +++++++++++++++ state-machines/mem/src/mem_proxy.rs | 15 +- state-machines/mem/src/mem_sm.rs | 350 ++++++++++-------- state-machines/rom/src/rom.rs | 12 +- 14 files changed, 791 insertions(+), 303 deletions(-) create mode 100644 state-machines/mem/pil/mem_align_rom.pil diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 449b8417..2df7024e 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -2,68 +2,55 @@ // Manual modifications are not recommended and may be overwritten. 
use proofman_common::WitnessPilout; -pub const PILOUT_HASH: &[u8] = b"Zisk-hash"; +pub const PILOUT_HASH: &[u8] = b"ZiskMem-hash"; //AIRGROUP CONSTANTS -pub const MAIN_AIRGROUP_ID: usize = 0; +pub const ZISK_AIRGROUP_ID: usize = 0; -pub const ROM_AIRGROUP_ID: usize = 1; +//AIR CONSTANTS -pub const BINARY_AIRGROUP_ID: usize = 2; +pub const MAIN_AIR_IDS: &[usize] = &[0]; -pub const BINARY_TABLE_AIRGROUP_ID: usize = 3; +pub const ROM_AIR_IDS: &[usize] = &[1]; -pub const BINARY_EXTENSION_AIRGROUP_ID: usize = 4; +pub const MEM_AIR_IDS: &[usize] = &[2]; -pub const BINARY_EXTENSION_TABLE_AIRGROUP_ID: usize = 5; +pub const MEM_ALIGN_AIR_IDS: &[usize] = &[3]; -//AIR CONSTANTS - -pub const MAIN_AIR_IDS: &[usize] = &[0]; +pub const MEM_ALIGN_ROM_AIR_IDS: &[usize] = &[4]; -pub const ROM_AIR_IDS: &[usize] = &[0]; +pub const BINARY_AIR_IDS: &[usize] = &[5]; -pub const BINARY_AIR_IDS: &[usize] = &[0]; +pub const BINARY_TABLE_AIR_IDS: &[usize] = &[6]; -pub const BINARY_TABLE_AIR_IDS: &[usize] = &[0]; +pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[7]; -pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[0]; +pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[8]; -pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[1]; +pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[9]; -pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[0]; +pub const U_8_AIR_AIR_IDS: &[usize] = &[10]; pub struct Pilout; impl Pilout { pub fn pilout() -> WitnessPilout { - let mut pilout = WitnessPilout::new("Zisk", 2, PILOUT_HASH.to_vec()); + let mut pilout = WitnessPilout::new("ZiskMem", 2, PILOUT_HASH.to_vec()); - let air_group = pilout.add_air_group(Some("Main")); + let air_group = pilout.add_air_group(Some("Zisk")); air_group.add_air(Some("Main"), 2097152); - - let air_group = pilout.add_air_group(Some("Rom")); - - air_group.add_air(Some("Rom"), 1048576); - - let air_group = pilout.add_air_group(Some("Binary")); - + air_group.add_air(Some("Rom"), 4194304); + air_group.add_air(Some("Mem"), 2097152); + 
air_group.add_air(Some("MemAlign"), 2097152); + air_group.add_air(Some("MemAlignRom"), 256); air_group.add_air(Some("Binary"), 2097152); - - let air_group = pilout.add_air_group(Some("BinaryTable")); - air_group.add_air(Some("BinaryTable"), 4194304); - - let air_group = pilout.add_air_group(Some("BinaryExtension")); - air_group.add_air(Some("BinaryExtension"), 2097152); - air_group.add_air(Some("SpecifiedRanges"), 16777216); - - let air_group = pilout.add_air_group(Some("BinaryExtensionTable")); - air_group.add_air(Some("BinaryExtensionTable"), 4194304); + air_group.add_air(Some("SpecifiedRanges"), 16777216); + air_group.add_air(Some("U8Air"), 256); pilout } diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 5fb91f20..f1af3709 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -11,6 +11,18 @@ trace!(Rom0Row, Rom0Trace { line: F, a_offset_imm0: F, a_imm1: F, b_offset_imm0: F, b_imm1: F, ind_width: F, op: F, store_offset: F, jmp_offset1: F, jmp_offset2: F, flags: F, multiplicity: F, }); +trace!(Mem0Row, Mem0Trace { + addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, same_value: F, first_addr_access_is_read: F, +}); + +trace!(MemAlign0Row, MemAlign0Trace { + addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], sel_prove: F, step: F, +}); + +trace!(MemAlignRom0Row, MemAlignRom0Trace { + multiplicity: F, +}); + trace!(Binary0Row, Binary0Trace { m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, multiplicity: F, main_step: F, }); @@ -23,10 +35,14 @@ trace!(BinaryExtension0Row, BinaryExtension0Trace { op: F, in1: [F; 8], in2_low: F, out: [[F; 2]; 8], op_is_shift: F, in2: [F; 2], main_step: F, multiplicity: F, }); -trace!(SpecifiedRanges1Row, SpecifiedRanges1Trace { - mul: [F; 1], -}); - trace!(BinaryExtensionTable0Row, BinaryExtensionTable0Trace { 
multiplicity: F, }); + +trace!(SpecifiedRanges0Row, SpecifiedRanges0Trace { + mul: [F; 2], +}); + +trace!(U8Air0Row, U8Air0Trace { + mul: F, +}); diff --git a/pil/zisk.pil b/pil/zisk.pil index 69cd09bb..dbb215f9 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -1,38 +1,25 @@ - -require "constants.pil" -require "rom/pil/rom.pil" require "main/pil/main.pil" +require "rom/pil/rom.pil" +require "mem/pil/mem.pil" +require "mem/pil/mem_align.pil" +require "mem/pil/mem_align_rom.pil" require "binary/pil/binary.pil" require "binary/pil/binary_table.pil" require "binary/pil/binary_extension.pil" require "binary/pil/binary_extension_table.pil" -// require "mem/pil/mem.pil" const int OPERATION_BUS_ID = 5000; -airgroup Main { - Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); -} -airgroup Rom { - Rom(N: 2**20); -} +airgroup Zisk { + Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); + Rom(N: 2**22); -// airgroup Mem { -// Mem(N: 2**21, RC: 2); -// } + Mem(N: 2**21, RC: 2); + MemAlign(N: 2**21); + MemAlignRom(disable_fixed: 0); -airgroup Binary { Binary(N: 2**21, operation_bus_id: OPERATION_BUS_ID); -} - -airgroup BinaryTable { BinaryTable(disable_fixed: 0); -} - -airgroup BinaryExtension { BinaryExtension(N: 2**21, operation_bus_id: OPERATION_BUS_ID); -} - -airgroup BinaryExtensionTable { BinaryExtensionTable(disable_fixed: 0); } \ No newline at end of file diff --git a/state-machines/binary/src/binary.rs b/state-machines/binary/src/binary.rs index c8b128f9..9b020312 100644 --- a/state-machines/binary/src/binary.rs +++ b/state-machines/binary/src/binary.rs @@ -11,9 +11,8 @@ use rayon::Scope; use sm_common::{OpResult, Provable}; use zisk_core::ZiskRequiredOperation; use zisk_pil::{ - BINARY_AIRGROUP_ID, BINARY_AIR_IDS, BINARY_EXTENSION_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS, - BINARY_EXTENSION_TABLE_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS, BINARY_TABLE_AIRGROUP_ID, - BINARY_TABLE_AIR_IDS, + BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, 
BINARY_EXTENSION_TABLE_AIR_IDS, BINARY_TABLE_AIR_IDS, + ZISK_AIRGROUP_ID, }; const PROVE_CHUNK_SIZE: usize = 1 << 16; @@ -35,24 +34,24 @@ pub struct BinarySM { impl BinarySM { pub fn new(wcm: Arc>, std: Arc>) -> Arc { let binary_basic_table_sm = - BinaryBasicTableSM::new(wcm.clone(), BINARY_TABLE_AIRGROUP_ID, BINARY_TABLE_AIR_IDS); + BinaryBasicTableSM::new(wcm.clone(), ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS); let binary_basic_sm = BinaryBasicSM::new( wcm.clone(), binary_basic_table_sm, - BINARY_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_AIR_IDS, ); let binary_extension_table_sm = BinaryExtensionTableSM::new( wcm.clone(), - BINARY_EXTENSION_TABLE_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS, ); let binary_extension_sm = BinaryExtensionSM::new( wcm.clone(), std, binary_extension_table_sm, - BINARY_EXTENSION_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS, ); diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index ac4ddcb1..3755609b 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -10,9 +10,9 @@ use proofman_common::AirInstance; use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; use rayon::Scope; use sm_common::{create_prover_buffer, OpResult, Provable}; +use zisk_pil::{Binary0Row, Binary0Trace, BINARY_AIR_IDS, BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; use std::cmp::Ordering as CmpOrdering; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::*; use crate::{BinaryBasicTableOp, BinaryBasicTableSM}; @@ -657,9 +657,9 @@ impl BinaryBasicSM { ) { timer_start_trace!(BINARY_TRACE); let pctx = wcm.get_pctx(); - let air = pctx.pilout.get_air(BINARY_AIRGROUP_ID, BINARY_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]); let air_binary_table = - pctx.pilout.get_air(BINARY_TABLE_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0]); + pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0]); 
assert!(operations.len() <= air.num_rows()); info!( @@ -729,7 +729,7 @@ impl Provable for BinaryBasicSM { inputs.extend_from_slice(operations); let pctx = self.wcm.get_pctx(); - let air = pctx.pilout.get_air(BINARY_AIRGROUP_ID, BINARY_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]); while inputs.len() >= air.num_rows() || (drain && !inputs.is_empty()) { let num_drained = std::cmp::min(air.num_rows(), inputs.len()); @@ -743,7 +743,7 @@ impl Provable for BinaryBasicSM { let (mut prover_buffer, offset) = create_prover_buffer( &wcm.get_ectx(), &wcm.get_sctx(), - BINARY_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0], ); @@ -757,7 +757,7 @@ impl Provable for BinaryBasicSM { let air_instance = AirInstance::new( sctx, - BINARY_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0], None, prover_buffer, diff --git a/state-machines/binary/src/binary_basic_table.rs b/state-machines/binary/src/binary_basic_table.rs index df028e77..7f2070d9 100644 --- a/state-machines/binary/src/binary_basic_table.rs +++ b/state-machines/binary/src/binary_basic_table.rs @@ -10,7 +10,7 @@ use proofman_common::AirInstance; use rayon::prelude::*; use sm_common::create_prover_buffer; use zisk_core::{zisk_ops::ZiskOp, P2_16, P2_17, P2_18, P2_19, P2_8}; -use zisk_pil::{BINARY_TABLE_AIRGROUP_ID, BINARY_TABLE_AIR_IDS}; +use zisk_pil::{ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS}; #[derive(Debug, Clone, PartialEq, Copy)] #[repr(u8)] @@ -53,7 +53,7 @@ impl BinaryBasicTableSM { pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { let pctx = wcm.get_pctx(); - let air = pctx.pilout.get_air(BINARY_TABLE_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0]); let binary_basic_table = Self { wcm: wcm.clone(), @@ -222,7 +222,7 @@ impl BinaryBasicTableSM { let mut multiplicity = self.multiplicity.lock().unwrap(); let (is_myne, instance_global_idx) = - dctx.add_instance(BINARY_TABLE_AIRGROUP_ID, 
BINARY_TABLE_AIR_IDS[0], 1); + dctx.add_instance(ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0], 1); let owner: usize = dctx.owner(instance_global_idx); let mut multiplicity_ = std::mem::take(&mut *multiplicity); @@ -233,7 +233,7 @@ impl BinaryBasicTableSM { let (mut prover_buffer, offset) = create_prover_buffer( &self.wcm.get_ectx(), &self.wcm.get_sctx(), - BINARY_TABLE_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0], ); prover_buffer[offset as usize..offset as usize + self.num_rows] @@ -248,7 +248,7 @@ impl BinaryBasicTableSM { ); let air_instance = AirInstance::new( self.wcm.get_sctx(), - BINARY_TABLE_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0], None, prover_buffer, diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index dbd8f409..d53414e9 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -17,7 +17,7 @@ use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use rayon::Scope; use sm_common::{create_prover_buffer, OpResult, Provable}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::*; +use zisk_pil::{BinaryExtension0Row, BinaryExtension0Trace, BINARY_EXTENSION_AIR_IDS, BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; const MASK_32: u64 = 0xFFFFFFFF; const MASK_64: u64 = 0xFFFFFFFFFFFFFFFF; @@ -117,12 +117,12 @@ impl BinaryExtensionSM { fn opcode_is_shift(opcode: ZiskOp) -> bool { match opcode { - ZiskOp::Sll | - ZiskOp::Srl | - ZiskOp::Sra | - ZiskOp::SllW | - ZiskOp::SrlW | - ZiskOp::SraW => true, + ZiskOp::Sll + | ZiskOp::Srl + | ZiskOp::Sra + | ZiskOp::SllW + | ZiskOp::SrlW + | ZiskOp::SraW => true, ZiskOp::SignExtendB | ZiskOp::SignExtendH | ZiskOp::SignExtendW => false, @@ -134,12 +134,12 @@ impl BinaryExtensionSM { match opcode { ZiskOp::SllW | ZiskOp::SrlW | ZiskOp::SraW => true, - ZiskOp::Sll | - ZiskOp::Srl | - ZiskOp::Sra | - ZiskOp::SignExtendB | - ZiskOp::SignExtendH | - 
ZiskOp::SignExtendW => false, + ZiskOp::Sll + | ZiskOp::Srl + | ZiskOp::Sra + | ZiskOp::SignExtendB + | ZiskOp::SignExtendH + | ZiskOp::SignExtendW => false, _ => panic!("BinaryExtensionSM::opcode_is_shift() got invalid opcode={:?}", opcode), } @@ -389,10 +389,9 @@ impl BinaryExtensionSM { timer_start_debug!(BINARY_EXTENSION_TRACE); let pctx = wcm.get_pctx(); - let air = pctx.pilout.get_air(BINARY_EXTENSION_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); - let air_binary_extension_table = pctx - .pilout - .get_air(BINARY_EXTENSION_TABLE_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); + let air_binary_extension_table = + pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0]); assert!(operations.len() <= air.num_rows()); info!( @@ -467,8 +466,7 @@ impl Provable for BinaryExtensio inputs.extend_from_slice(operations); let pctx = self.wcm.get_pctx(); - let air = - pctx.pilout.get_air(BINARY_EXTENSION_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); while inputs.len() >= air.num_rows() || (drain && !inputs.is_empty()) { let num_drained = std::cmp::min(air.num_rows(), inputs.len()); @@ -484,7 +482,7 @@ impl Provable for BinaryExtensio let (mut prover_buffer, offset) = create_prover_buffer( &wcm.get_ectx(), &wcm.get_sctx(), - BINARY_EXTENSION_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0], ); @@ -499,7 +497,7 @@ impl Provable for BinaryExtensio let air_instance = AirInstance::new( sctx, - BINARY_EXTENSION_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0], None, prover_buffer, diff --git a/state-machines/binary/src/binary_extension_table.rs b/state-machines/binary/src/binary_extension_table.rs index 8c88018b..8fbc1de3 100644 --- a/state-machines/binary/src/binary_extension_table.rs +++ b/state-machines/binary/src/binary_extension_table.rs @@ -10,7 +10,7 @@ use 
proofman_common::AirInstance; use rayon::prelude::*; use sm_common::create_prover_buffer; use zisk_core::{zisk_ops::ZiskOp, P2_11, P2_19, P2_8}; -use zisk_pil::{BINARY_EXTENSION_TABLE_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS}; +use zisk_pil::{ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS}; #[derive(Debug, Clone, PartialEq, Copy)] #[repr(u8)] @@ -49,7 +49,7 @@ impl BinaryExtensionTableSM { let pctx = wcm.get_pctx(); let air = pctx .pilout - .get_air(BINARY_EXTENSION_TABLE_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0]); + .get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0]); let binary_extension_table = Self { wcm: wcm.clone(), @@ -131,7 +131,7 @@ impl BinaryExtensionTableSM { let mut multiplicity = self.multiplicity.lock().unwrap(); let (is_myne, instance_global_idx) = dctx.add_instance( - BINARY_EXTENSION_TABLE_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0], 1, ); @@ -145,7 +145,7 @@ impl BinaryExtensionTableSM { let (mut prover_buffer, offset) = create_prover_buffer( &self.wcm.get_ectx(), &self.wcm.get_sctx(), - BINARY_EXTENSION_TABLE_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0], ); @@ -162,7 +162,7 @@ impl BinaryExtensionTableSM { let air_instance = AirInstance::new( self.wcm.get_sctx(), - BINARY_EXTENSION_TABLE_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0], None, prover_buffer, diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index 50bd652e..50da226d 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -10,20 +10,21 @@ const int MEMORY_MAX_DIFF = 2**22; const int MAX_MEM_STEP_OFFSET = 3; -airtemplate Mem (int N = 2**21, int RC = 2, int id = MEMORY_ID, int MAX_STEP = 2 ** 23, int MEM_BYTES = 8 ) { +airtemplate Mem(const int N = 2**21, const int RC = 2, const int id = MEMORY_ID, const int MAX_STEP = 2 ** 23, const int MEM_BYTES = 8) { col fixed SEGMENT_L1 = [1,0...]; const expr SEGMENT_LAST = SEGMENT_L1'; airval mem_segment; 
airval mem_last_segment; - col witness addr; // n-byte address, real address = addr * MEM_BYTES + col witness addr; // n-byte address, real address = addr * MEM_BYTES col witness step; - col witness sel, wr; + col witness sel; + col witness wr; col witness value[RC]; col witness addr_changes; - const expr rd = (1 - wr); + const expr rd = 1 - wr; sel * (1 - sel) === 0; wr * (1 - wr) === 0; @@ -42,15 +43,18 @@ airtemplate Mem (int N = 2**21, int RC = 2, int id = MEMORY_ID, int MAX_STEP = 2 // setting mem_last_segment = 1 // if addr_changes == 0 means that addr and previous address are the same - (1 - addr_changes) * ('addr - addr) === 0; + const expr same_addr = 1 - SEGMENT_L1 - addr_changes; + same_addr * ('addr - addr) === 0; col witness same_value; - (1 - same_value) * (1 - wr) * (1 - addr_changes) === 0; + same_value * (1 - same_value) === 0; + (1 - same_value) * (1 - wr) * same_addr === 0; col witness first_addr_access_is_read; - (1 - first_addr_access_is_read) * rd * (1 - addr_changes) === 0; + first_addr_access_is_read * (1 - first_addr_access_is_read) === 0; + (1 - first_addr_access_is_read) * rd * same_addr === 0; - for (int index = 0; index < length(value); index = index + 1) { + for (int index = 0; index < length(value); index++) { same_value * (value[index] - 'value[index]) === 0; first_addr_access_is_read * value[index] === 0; } @@ -85,22 +89,26 @@ airtemplate Mem (int N = 2**21, int RC = 2, int id = MEMORY_ID, int MAX_STEP = 2 // permutation_proves(MEMORY_CONT_ID, [(mem_segment + 1), addr, step, ...value], sel: mem_last_segment * 'SEGMENT_L1); // last row // permutation_assumes(MEMORY_CONT_ID, [mem_segment, 0, addr, step, ...value], sel: SEGMENT_L1); // first row - permutation_proves(MEMORY_ID, cols: [wr, addr * MEM_BYTES, step, MEM_BYTES, ...value], sel: sel); + // The Memory component is only able to prove aligned memory access, since we force the bus address to be a multiple of MEM_BYTES + // and the width to be exactly MEM_BYTES + // Notice, 
however, that the main can also use widths of 4, 2, 1 and addresses that are not multiples of MEM_BYTES. + // These are handled with the Memory Align component + permutation_proves(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * MEM_BYTES, step, MEM_BYTES, ...value], sel: sel); } -// TODO: detect non default value but not called, mandatory parameter. -function mem_load(int id = MEMORY_ID, expr sel = 1, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { - if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); - } - // adding one for first continuation - permutation_assumes(id, [MEMORY_LOAD_OP, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step) + step_offset, bytes, ...value], sel:sel); +function mem_load(int id = MEMORY_ID, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[], expr sel = 1) { + mem_assumes(id, MEMORY_LOAD_OP, addr, step, step_offset, bytes, value, sel); } -function mem_store(int id = MEMORY_ID, expr sel = 1, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { +function mem_store(int id = MEMORY_ID, expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[], expr sel = 1) { + mem_assumes(id, MEMORY_STORE_OP, addr, step, step_offset, bytes, value, sel); +} + +private function mem_assumes(int id, int mem_op, expr addr, expr step, expr step_offset, expr bytes, expr value[], expr sel) { if (step_offset > MAX_MEM_STEP_OFFSET) { error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); } - // adding one for first continuation - permutation_assumes(id, [MEMORY_STORE_OP, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step), bytes, ...value], sel:sel); + + // adding 1 at step for first continuation + permutation_assumes(id, [mem_op, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step) + step_offset, bytes, ...value], sel: sel); } \ No newline at end of file 
diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index e69de29b..184b86a2 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -0,0 +1,173 @@ +require "std_permutation.pil" +require "std_lookup.pil" +require "std_range_check.pil" + +// Problem to solve: +// ================= +// We are given an op (rd,wr), an addr, a step and a bytes-width (8,4,2,1) and we should prove that the memory access is correct. +// Note: Either the original addr is not a multiple of 8 or width < 8 to ensure it is a non-aligned access that should be +// handled by this component. + +/* + We will model it as a very specified processor with 8 registers and a very limited instruction set. + + This processor is limited to 4 possible subprograms: + + 1] Read operation that spans one memory word w = [w_0, w_1]: + w_0 w_1 + +---+===+===+===+ +===+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+===+===+===+ +===+---+---+---+ + |<------ v ------>| + + [R] In the first clock cycle, we perform an aligned read to w + [V] In the second clock cycle, we return the demanded value v from w + + 2] Write operation that spans one memory word w = [w_0, w_1]: + w_0 w_1 + +---+---+---+---+ +---+===+===+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+---+---+---+ +---+===+===+---+ + |<- v ->| + + [R] In the first clock cycle, we perform an aligned read to w + [W] In the second clock cycle, we compute an aligned write of v to w + [V] In the third clock cycle, we restore the demanded value from w + + 3] Read operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: + w1_0 w1_1 w2_0 w2_1 + +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ + |<---------------- v ---------------->| + + [R] In the first clock cycle, we perform an aligned 
read to w1 + [V] In the second clock cycle, we return the demanded value v from w1 and w2 + [R] In the third clock cycle, we perform an aligned read to w2 + + 4] Write operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: + w1_0 w1_1 w2_0 w2_1 + +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ + | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | + +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ + |<---------------- v ---------------->| + + [R] In the first clock cycle, we perform an aligned read to w1 + [W] In the second clock cycle, we compute an aligned write of v to w1 + [V] In the third clock cycle, we restore the demanded value from w1 and w2 + [R] In the fourth clock cycle, we perform an aligned read to w2 + [W] In the fifth clock cycle, we compute an aligned write of v to w2 + + Example: + ========================================================== + (offset = 6, width = 4) + +----+----+----+----+----+----+----+----+ + | R7 | R6 | R5 | R4 | R3 | R2 | R1 | R0 | [R1] (assume, up_to_down) sel = [1,1,1,1,1,1,0,0] + +----+----+----+----+----+----+----+----+ + ⇓ + +----+----+----+----+----+----+====+====+ + | W7 | W6 | W5 | W4 | W3 | W2 | W1 | W0 | [W1] (assume, up_to_down) sel = [0,0,0,0,0,0,1,1] + +----+----+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+====+====+ + | V1 | V0 | V7 | V6 | V5 | V4 | V3 | V2 | [V] (prove) (shift (offset + width) % 8) sel = [0,0,0,0,0,0,1,0] (*) + +====+====+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+----+----+ + | W7 | W6 | W5 | W4 | W3 | W2 | W1 | W0 | [W2] (assume, down_to_up) sel = [1,1,0,0,0,0,0,0] + +====+====+----+----+----+----+----+----+ + ⇓ + +----+----+----+----+----+----+----+----+ + | R7 | R6 | R5 | R4 | R3 | R2 | R1 | R0 | [R2] (assume, down_to_up) sel = [0,0,1,1,1,1,1,1] + +----+----+----+----+----+----+----+----+ + + (*) In this step, we use the selectors to indicate the
"scanning" needed to form the bus value: + v_0 = sel[0] * [V1,V0,V7,V6] + sel[1] * [V0,V7,V6,V5] + sel[2] * [V7,V6,V5,V4] + sel[3] * [V6,V5,V4,V3] + v_1 = sel[4] * [V5,V4,V3,V2] + sel[5] * [V4,V3,V2,V1] + sel[6] * [V3,V2,V1,V0] + sel[7] * [V2,V1,V0,V7] + Notice that it is enough with 8 combinations. +*/ + +airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES = 8, const int CHUNK_BITS = 8) { + const int MEM_HALF_BYTES = MEM_BYTES / 2; + + col witness addr; // MEM_BYTES-byte address, real address = addr * MEM_BYTES + col witness offset; // 0..7, position at which the operation starts + col witness width; // 1,2,4,8, width of the operation + col witness wr; // 1 if the operation is a write, 0 otherwise + col witness pc; // line of the program to execute + col witness reset; // 1 at the beginning of the operation (indicating an address reset), 0 otherwise + col witness sel_up_to_down; // 1 if the next value is the current value (e.g. R -> W) + col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R) + col witness reg[MEM_BYTES]; // Register values, 1 byte each + col witness sel[MEM_BYTES]; // Selectors, 1 if the value is used, 0 otherwise + + // 1] Ensure the MemAlign follows the program + + // Registers should be bytes and be shuch that: + // - reg' == reg in transitions R -> V, R -> W, W -> V, + // - 'reg == reg in transitions V <- W, W <- R, + // in any case, sel_up_to_down,sel_down_to_up are 0 in [V] steps. 
+ for (int i = 0; i < MEM_BYTES; i++) { + range_check(reg[i], 0, 2**CHUNK_BITS-1); + + (reg[i]' - reg[i]) * sel[i] * sel_up_to_down === 0; + ('reg[i] - reg[i]) * sel[i] * sel_down_to_up === 0; + } + + col fixed L1 = [1,0...]; + L1 * pc === 0; // The program should start at the first line + + // We compress selectors, so we should ensure they are binary + for (int i = 0; i < MEM_BYTES; i++) { + sel[i] * (1 - sel[i]) === 0; + } + wr * (1 - wr) === 0; + reset * (1 - reset) === 0; + sel_up_to_down * (1 - sel_up_to_down) === 0; + sel_down_to_up * (1 - sel_down_to_up) === 0; + + expr flags = 0; + for (int i = 0; i < MEM_BYTES; i++) { + flags += sel[i] * 2**i; + } + flags += wr * 2**MEM_BYTES + reset * 2**(MEM_BYTES + 1) + sel_up_to_down * 2**(MEM_BYTES + 2) + sel_down_to_up * 2**(MEM_BYTES + 3); + + lookup_assumes(MEM_ALIGN_ROM_ID, [pc, pc'-pc, (addr-'addr)*(1-reset), offset, width, flags]); + + // 2] Assume aligned memory accesses against the Memory component + const expr sel_assume = sel_up_to_down + sel_down_to_up; + + // Offset should be 0 in aligned memory accesses, but this is ensured by the rom + // Width should be 8 in aligned memory accesses, but this is ensured by the rom + + // On assume steps, we reconstruct the value from the registers directly + expr assume_val[RC]; + for (int i = 0; i < RC; i++) { + assume_val[i] = 0; + for (int j = 0; j < MEM_HALF_BYTES; j++) { + assume_val[i] += reg[j + i * MEM_HALF_BYTES] * 2**j; + } + } + + // 3] Prove unaligned memory accesses against the Main component + col witness sel_prove; + + sel_prove * sel_assume === 0; // Disjoint selectors + + // On prove steps, we reconstruct the value in the correct manner chosen by the selectors + expr prove_val[RC]; + for (int i = 0; i < RC; i++) { + prove_val[i] = 0; + for (int j = 0; j < MEM_HALF_BYTES; j++) { + expr _prove_val = 0; + for (int k = j; k < j + MEM_HALF_BYTES; k++) { + _prove_val += reg[(k + i * MEM_HALF_BYTES) % MEM_BYTES] * 2**(k-j); + } + prove_val[i] += sel[j + i * 
MEM_HALF_BYTES] * _prove_val; + } + } + + // We prove and assume with the same permutation check but with disjoint and different sign selectors + col witness step; + permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * MEM_BYTES + offset, step, width, ...prove_val], sel: sel_prove - sel_assume); +} \ No newline at end of file diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil new file mode 100644 index 00000000..211da1dd --- /dev/null +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -0,0 +1,299 @@ +require "std_lookup.pil" +require "constants.pil" + +const int MEM_ALIGN_ROM_ID = 133; +const int MEM_ALIGN_ROM_SIZE = P2_8; + +// PROGRAM SIZE +// RV 0 2 +// RWV 1 3 +// RVR 2 3 +// RWVWR 3 5 +// +// Note1: The offset and width are sufficient to group programs with the same number of operations. +// Note2: The first instruction is always a read. + +airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = 8, const int disable_fixed = 0) { + if (N < MEM_ALIGN_ROM_SIZE) { + error(`N must be at least ${MEM_ALIGN_ROM_SIZE}, but N=${N} was provided`); + } + + col witness multiplicity; + + if (disable_fixed) { + col fixed _K = [0...]; + multiplicity * _K === 0; + + println("*** DISABLE_FIXED ***"); + return; + } + + // Not all combinations of offset and width are valid for each program. + // Moreover, offset is set to 0 and width to 8 in aligned memory accesses. 
+ // size + col fixed OFFSET = [[[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 40 + [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 100 + [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 133 + [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3]]...; // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 + + col fixed WIDTH = [[[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV + [[8,8,1,8,8,2,8,8,4], [8,8,1,8,8,2,8,8,4]:4, [8,8,1,8,8,2]:2, [8,8,1]], // RWV + [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR + [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]]]...; // RWVWR + + const int psize1 = 40; + const int psize2 = 60; + const int psize3 = 33; + const int psize4 = 55; + + // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | + // 0 | 0 | 1 | 1 | X1 | 0 | // (RV) + // 1 | 1 | -1 | 0 | X1 | 0 | + // 2 | 0 | 3 | 1 | X2 | 0 | // (RV) + // 3 | 3 | -3 | 0 | X2 | 0 | + // 4 | 0 | 5 | 1 | X3 | 0 | // (RV) + // 5 | 5 | -5 | 0 | X3 | 0 | + // 6 | 0 | 7 | 1 | ⋮ | ⋮ | // (RV) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 40 | 0 | 41 | 1 | X4 | 0 | // (RWV) + // 41 | 41 | 1 | 0 | X4 | 0 | + // 42 | 42 | -42 | 0 | X4 | 0 | + // 43 | 0 | 44 | 1 | X5 | 0 | // (RWV) + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 100 | 0 | 101 | 1 | X6 | 0 | // (RVR) + // 101 |101 | 1 | 0 | X6 | 0 | + // 102 |102 | -102 | 0 | X6+1 | 1 | + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 133 | 0 | 134 | 1 | X7 | 0 | // (RWVWR) + // 134 |134 | 1 | 0 | X7 | 0 | + // 135 |135 | 1 | 0 | X7 | 0 | + // 136 |136 | 1 | 0 | X7+1 | 1 | + // 137 |137 | -137 | 0 | X7+1 | 1 | + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 188 | 0 | 0 | 0 | 0 | 0 | // for padding + + col fixed PC; + col fixed DELTA_PC; + col fixed DELTA_ADDR; + col fixed FLAGS; + for (int i = 0; i < N; i++) { + const int [offset, width] = [OFFSET[i], WIDTH[i]]; + int pc 
= 0; + int delta_pc = 0; + int delta_addr = 0; + int is_write = 0; + int reset = 0; + int sel[MEM_BYTES]; + for (int j = 0; j < MEM_BYTES; j++) { + sel[j] = 0; + } + int sel_up_to_down = 0; + int sel_down_to_up = 0; + + const int line = i; + const int next = i+1; + if (line < psize1) // RV + { + if (line % 2 == 0) { + // pc = 0; + delta_pc = next; + // delta_addr = 0; + // is_write = 0; + reset = 1; + sel = get_selectors(offset, width, program: 0); + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { + pc = line; + delta_pc = -pc; + delta_addr = 1; + // is_write = 0; + // reset = 0; + // sel = [0:MEM_BYTES] + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + } + else if (line < psize1+psize2) // RWV + { + if (line % 3 == 0) { // R + // pc = 0; + delta_pc = next; + // delta_addr = 0; + // is_write = 0; + reset = 1; + sel = get_selectors(offset, width, program: 1); + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 3 == 1) { // W + pc = line; + delta_pc = 1; + delta_addr = 1; + is_write = 0; + // reset = 0; + sel = get_selectors(offset, width, program: 1, is_write: 1); + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { // V + pc = line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + // sel = [0:MEM_BYTES] + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } + } + else if (line < psize1+psize2+psize3) + { + if (line % 3 == 0) { // R + // pc = 0; + delta_pc = next; + // delta_addr = 0; + // is_write = 0; + reset = 1; + sel = get_selectors(offset, width, program: 2); + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (line % 3 == 1) { // V + pc = line; + delta_pc = 1; + delta_addr = 1; + // is_write = 0; + // reset = 0; + // sel = [0:MEM_BYTES] + // sel_up_to_down = 1; + // sel_down_to_up = 0; + } else { // R + pc = line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + sel = get_selectors(offset, width, program: 2); + // sel_up_to_down = 0; + sel_down_to_up = 1; + } + } + 
else if (line < psize1+psize2+psize3+psize4) + { + if (next % 5 == 0) { // R + // pc = 0; + delta_pc = next; + // delta_addr = 0; + // is_write = 0; + reset = 1; + sel = get_selectors(offset, width, program: 3, is_write: 0, is_first: 1); + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (next % 5 == 1) { // W + pc = line; + delta_pc = 1; + delta_addr = 1; + is_write = 1; + // reset = 0; + sel = get_selectors(offset, width, program: 3, is_write: 1, is_first: 1); + sel_up_to_down = 1; + // sel_down_to_up = 0; + } else if (next % 5 == 2) { // V + pc = line; + delta_pc = 1; + delta_addr = 1; + // is_write = 0; + // reset = 0; + // sel = [0:MEM_BYTES] + // sel_up_to_down = 0; + // sel_down_to_up = 0; + } else if (next % 5 == 3) { // W + pc = line; + delta_pc = 1; + delta_addr = 1; + is_write = 1; + // reset = 0; + sel = get_selectors(offset, width, program: 3, is_write: 1, is_first: 0); + // sel_up_to_down = 0; + sel_down_to_up = 1; + } else { // R + pc = line; + delta_pc = -pc; + // delta_addr = 0; + // is_write = 0; + // reset = 0; + sel = get_selectors(offset, width, program: 3, is_write: 0, is_first: 0); + // sel_up_to_down = 0; + sel_down_to_up = 1; + } + } + + PC[i] = pc; + DELTA_PC[i] = delta_pc; + DELTA_ADDR[i] = delta_addr; + FLAGS[i] = 0; + for (int j = 0; j < MEM_BYTES; j++) { + FLAGS[i] += sel[j] * 2**j; + } + FLAGS[i] += is_write * 2**MEM_BYTES + reset * 2**(MEM_BYTES + 1) + sel_up_to_down * 2**(MEM_BYTES + 2) + sel_down_to_up * 2**(MEM_BYTES + 3); + } + + lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); +} + +private function get_selectors(const int offset, const int width, const int program, const int is_write = 0, const int is_first = 0, const int bytes = 8): int[] { + int _sel[bytes]; + for (int j = 0; j < bytes; j++) { + _sel[j] = 0; + } + + switch (program) { + case 0: // RV + for (int j = 0; j < offset; j++) { + _sel[j] = 1; + } + + case 1: // RWV + if (!is_write) { + for (int j = 0; j < 
offset; j++) { + _sel[j] = 1; + } + } else { + for (int j = offset; j < offset + width; j++) { + _sel[j] = 1; + } + } + + case 2: // RVR + for (int j = 0; j < offset; j++) { + _sel[j] = 1; + } + + case 3: // RWVWR + if (is_first) { + if (!is_write) { + for (int j = 0; j < offset; j++) { + _sel[j] = 1; + } + } else { + for (int j = offset; j < bytes; j++) { + _sel[j] = 1; + } + } + } else { + const int rem = (offset + width) % bytes; + if (is_write) { + for (int j = 0; j < rem; j++) { + _sel[j] = 1; + } + } else { + for (int j = rem; j < bytes; j++) { + _sel[j] = 1; + } + } + } + + default: + error(`Invalid program ${program}`); + } + + return _sel; +} \ No newline at end of file diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index c7a4bce9..73dc7394 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -4,7 +4,7 @@ use std::sync::{ }; use crate::{MemAlignSM, MemSM}; -use p3_field::{Field, PrimeField}; +use p3_field::PrimeField; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use sm_common::{MemOp, MemUnalignedOp}; use zisk_core::ZiskRequiredMemory; @@ -57,7 +57,7 @@ impl MemProxy { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - // self.mem_sm.unregister_predecessor(); + self.mem_sm.unregister_predecessor(); // self.mem_align_sm.unregister_predecessor::(); } } @@ -89,16 +89,9 @@ impl MemProxy { // Step 3. Concatenate the new aligned memory accesses with the original aligned memory accesses aligned.extend(new_aligned); - // Step 4. Sort the (full) aligned memory accesses - timer_start_debug!(MEM_SORT_2); - aligned.sort_by_key(|mem| mem.address); - timer_stop_and_log_debug!(MEM_SORT_2); - - // Step 5. Prove the aligned memory accesses using mem state machine + // Step 4. 
Prove the aligned memory accesses using mem state machine + self.mem_sm.prove(&mut aligned); - println!("Proving MemSM"); - println!("Aligned: {:?}", operations[0].len()); - println!("Non aligned: {:?}", operations[1].len()); Ok(()) } diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index b0febe2f..145dd873 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -5,13 +5,13 @@ use std::sync::{ use p3_field::PrimeField; use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{AirInstance, ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{MemOp, OpResult, Provable}; -use zisk_core::ZiskRequiredMemory; -// use zisk_pil::{Mem0Trace, MEM_AIRGROUP_ID, MEM_AIR_IDS}; +use proofman_common::AirInstance; +use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; +use rayon::prelude::*; -const PROVE_CHUNK_SIZE: usize = 1 << 12; +use sm_common::{create_prover_buffer, MemOp}; +use zisk_core::ZiskRequiredMemory; +use zisk_pil::{Mem0Trace, MEM_AIRGROUP_ID, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct MemSM { // Witness computation manager @@ -37,7 +37,7 @@ impl MemSM { }; let mem_sm = Arc::new(mem_sm); - // wcm.register_component(mem_sm.clone(), Some(MEM_AIRGROUP_ID), Some(MEM_AIR_IDS)); + wcm.register_component(mem_sm.clone(), Some(MEM_AIRGROUP_ID), Some(MEM_AIR_IDS)); mem_sm } @@ -47,9 +47,61 @@ impl MemSM { } pub fn unregister_predecessor(&self) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - // as Provable>::prove(self, &[], true, scope); + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 {} + } + + pub fn prove(&self, mem_accesses: &mut [ZiskRequiredMemory]) { + // Sort the (full) aligned memory accesses + timer_start_debug!(MEM_SORT_2); + mem_accesses.sort_by_key(|mem| mem.address); + timer_stop_and_log_debug!(MEM_SORT_2); + + let air = self.wcm.get_pctx().pilout.get_air(MEM_AIRGROUP_ID, MEM_AIR_IDS[0]); + + let num_chunks 
= (mem_accesses.len() as f64 / (air.num_rows() - 1) as f64).ceil() as usize; + + let mut prover_buffers = vec![Vec::new(); num_chunks]; + let mut offsets = vec![0; num_chunks]; + let mut global_idxs = vec![0; num_chunks]; + + let pctx = self.wcm.get_pctx(); + let ectx = self.wcm.get_ectx(); + let sctx = self.wcm.get_sctx(); + + for i in 0..num_chunks { + if let (true, global_idx) = self.wcm.get_ectx().dctx.write().unwrap().add_instance( + ZISK_AIRGROUP_ID, + MEM_AIR_IDS[0], + 1, + ) { + let (buffer, offset) = + create_prover_buffer::(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + + prover_buffers.push(buffer); + offsets.push(offset); + global_idxs.push(global_idx); + } } + + mem_accesses.par_chunks(air.num_rows() - 1).enumerate().for_each( + |(segment_id, mem_ops)| { + let mem_first_row = if segment_id == 0 { + mem_accesses.last().unwrap().clone() + } else { + mem_accesses[segment_id * ((air.num_rows() - 1) - 1)].clone() + }; + + self.prove_instance( + mem_ops, + mem_first_row, + segment_id, + segment_id == mem_accesses.len() - 1, + prover_buffers[segment_id], + offsets[segment_id], + global_idxs[segment_id], + ); + }, + ); } /// Finalizes the witness accumulation process and triggers the proof generation. @@ -67,146 +119,139 @@ impl MemSM { is_last_segment: bool, mut prover_buffer: Vec, offset: u64, - pctx: Arc>, - ectx: Arc, - sctx: Arc, + global_idx: usize, ) -> Result<(), Box> { - // STEP2: Process the memory inputs and convert them to AIR instances - // let air = pctx.pilout.get_air(MEM_AIRGROUP_ID, MEM_AIR_IDS[0]); - - // let max_rows_per_segment = air.num_rows() - 1; - - // assert!(mem_ops.len() > 0 && mem_ops.len() <= max_rows_per_segment); - - // // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR segments - // // In a Memory AIR instance, the first row is reserved as a dummy row. - // // This dummy row is used to facilitate the continuation state between different AIR segments. 
- // // It ensures seamless transitions when multiple AIR segments are processed consecutively. - // // This design avoids discontinuities in memory access patterns and ensures that the memory trace is continuous, - // // For this reason we use AIR num_rows - 1 as the number of rows in each memory AIR instance - - // // Create a vector of Mem0Row instances, one for each memory operation - // // Recall that first row is a dummy row used for the continuations between AIR segments - // // The length of the vector is the number of input memory operations plus one because - // // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows - - // let mut trace = - // Mem0Trace::::map_buffer(&mut prover_buffer, air.num_rows(), offset as usize) - // .unwrap(); - - // let segment_id_field = F::from_canonical_u64(segment_id as u64); - // let is_last_segment_field = F::from_bool(is_last_segment); - - // // STEP1. Add the first row to the output vector as equal to the last row of the previous segment - // // CASE: last row of segment is read - // // - // // S[n-1] wr = 0, sel = 1, addr, step, value - // // S+1[0] wr = 0, sel = 0, addr, step, value - // // - // // CASE: last row of segment is write - // // - // // S[n-1] wr = 1, sel = 1, addr, step, value - // // S+1[0] wr = 0, sel = 0, addr, step, value + let pctx = self.wcm.get_pctx(); + // STEP2: Process the memory inputs and convert them to AIR instances + let air = pctx.pilout.get_air(MEM_AIRGROUP_ID, MEM_AIR_IDS[0]); + + let max_rows_per_segment = air.num_rows() - 1; + + assert!(mem_ops.len() > 0 && mem_ops.len() <= max_rows_per_segment); + + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR segments + // In a Memory AIR instance, the first row is reserved as a dummy row. + // This dummy row is used to facilitate the continuation state between different AIR segments. 
+ // It ensures seamless transitions when multiple AIR segments are processed consecutively. + // This design avoids discontinuities in memory access patterns and ensures that the memory trace is continuous, + // For this reason we use AIR num_rows - 1 as the number of rows in each memory AIR instance + + // Create a vector of Mem0Row instances, one for each memory operation + // Recall that first row is a dummy row used for the continuations between AIR segments + // The length of the vector is the number of input memory operations plus one because + // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows + + let mut trace = + Mem0Trace::::map_buffer(&mut prover_buffer, air.num_rows(), offset as usize) + .unwrap(); + + let segment_id_field = F::from_canonical_u64(segment_id as u64); + let is_last_segment_field = F::from_bool(is_last_segment); + + // STEP1. Add the first row to the output vector as equal to the last row of the previous segment + // CASE: last row of segment is read + // + // S[n-1] wr = 0, sel = 1, addr, step, value + // S+1[0] wr = 0, sel = 0, addr, step, value + // + // CASE: last row of segment is write + // + // S[n-1] wr = 1, sel = 1, addr, step, value + // S+1[0] wr = 0, sel = 0, addr, step, value + + // TODO CHECK // trace[0].mem_segment = segment_id_field; // trace[0].mem_last_segment = is_last_segment_field; - // trace[0].addr = F::from_canonical_u64(mem_first_row.address); - // trace[0].step = F::from_canonical_u64(mem_first_row.step); - // trace[0].sel = F::zero(); - // trace[0].wr = F::zero(); - - // let value = match mem_first_row.width { - // 1 => mem_first_row.value as u8 as u64, - // 2 => mem_first_row.value as u16 as u64, - // 4 => mem_first_row.value as u32 as u64, - // 8 => mem_first_row.value, - // _ => panic!("Invalid width"), - // }; - // let (low_val, high_val) = self.get_u32_values(value); - // trace[0].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; - // 
trace[0].addr_changes = F::zero(); - - // trace[0].same_value = F::zero(); - // trace[0].first_addr_access_is_read = F::zero(); - - // // STEP2. Add all the memory operations to the buffer - // for (idx, mem_op) in mem_ops.iter().enumerate() { - // let i = idx + 1; - // trace[i].mem_segment = segment_id_field; - // trace[i].mem_last_segment = is_last_segment_field; - - // trace[i].addr = F::from_canonical_u64(mem_op.address); // n-byte address, real address = addr * MEM_BYTES - // trace[i].step = F::from_canonical_u64(mem_op.step); - // trace[i].sel = F::one(); - // trace[i].wr = F::from_bool(mem_op.is_write); - - // let value = match mem_op.width { - // 1 => mem_op.value as u8 as u64, - // 2 => mem_op.value as u16 as u64, - // 4 => mem_op.value as u32 as u64, - // 8 => mem_op.value, - // _ => panic!("Invalid width"), - // }; - // let (low_val, high_val) = self.get_u32_values(value); - // trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; - // if i == 66587 || i == 66586 { - // println!( - // "mem_op.value: {:?} value: {:?} width: {}", - // mem_op.value, trace[i].value, mem_op.width - // ); - // println!("mem_op: {:?}", mem_op); - // } - // let addr_changes = trace[i - 1].addr != trace[i].addr; - // trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; - - // let same_value = trace[i - 1].value[0] == trace[i].value[0] - // && trace[i - 1].value[1] == trace[i].value[1]; - // trace[i].same_value = if same_value { F::one() } else { F::zero() }; - - // let first_addr_access_is_read = addr_changes && !mem_op.is_write; - // trace[i].first_addr_access_is_read = - // if first_addr_access_is_read { F::one() } else { F::zero() }; - - // if i == 66587 || i == 66586 { - // println!("trace[{}]: {:?}", i, trace[i]); - // } - // } - - // // STEP3. 
Add dummy rows to the output vector to fill the remaining rows - // //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 - // let last_row_idx = mem_ops.len(); - // let addr = trace[last_row_idx].addr; - // let mut step = trace[last_row_idx].step; - // let value = trace[last_row_idx].value; - - // for i in (mem_ops.len() + 1)..air.num_rows() { - // step += F::one(); - - // trace[i].mem_segment = segment_id_field; - // trace[i].mem_last_segment = is_last_segment_field; - - // trace[i].addr = addr; - // trace[i].step = step; - // trace[i].sel = F::zero(); - // trace[i].wr = F::zero(); - - // trace[i].value = value; - - // trace[i].addr_changes = F::zero(); - // trace[i].same_value = F::one(); - // trace[i].first_addr_access_is_read = F::zero(); - // } - - // let air_instance = AirInstance::new( - // self.wcm.get_sctx(), - // MEM_AIRGROUP_ID, - // MEM_AIR_IDS[0], - // Some(segment_id), - // prover_buffer, - // ); - - // pctx.air_instance_repo.add_air_instance(air_instance); + trace[0].addr = F::from_canonical_u64(mem_first_row.address); + trace[0].step = F::from_canonical_u64(mem_first_row.step); + trace[0].sel = F::zero(); + trace[0].wr = F::zero(); + + let value = match mem_first_row.width { + 1 => mem_first_row.value as u8 as u64, + 2 => mem_first_row.value as u16 as u64, + 4 => mem_first_row.value as u32 as u64, + 8 => mem_first_row.value, + _ => panic!("Invalid width"), + }; + let (low_val, high_val) = self.get_u32_values(value); + trace[0].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + trace[0].addr_changes = F::zero(); + + trace[0].same_value = F::zero(); + trace[0].first_addr_access_is_read = F::zero(); + + // STEP2. 
Add all the memory operations to the buffer + for (idx, mem_op) in mem_ops.iter().enumerate() { + let i = idx + 1; + // TODO CHECK + // trace[i].mem_segment = segment_id_field; + // trace[i].mem_last_segment = is_last_segment_field; + + trace[i].addr = F::from_canonical_u64(mem_op.address); // n-byte address, real address = addr * MEM_BYTES + trace[i].step = F::from_canonical_u64(mem_op.step); + trace[i].sel = F::one(); + trace[i].wr = F::from_bool(mem_op.is_write); + + let value = match mem_op.width { + 1 => mem_op.value as u8 as u64, + 2 => mem_op.value as u16 as u64, + 4 => mem_op.value as u32 as u64, + 8 => mem_op.value, + _ => panic!("Invalid width"), + }; + let (low_val, high_val) = self.get_u32_values(value); + trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + + let addr_changes = trace[i - 1].addr != trace[i].addr; + trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; + + let same_value = trace[i - 1].value[0] == trace[i].value[0] + && trace[i - 1].value[1] == trace[i].value[1]; + trace[i].same_value = if same_value { F::one() } else { F::zero() }; + + let first_addr_access_is_read = addr_changes && !mem_op.is_write; + trace[i].first_addr_access_is_read = + if first_addr_access_is_read { F::one() } else { F::zero() }; + } + + // STEP3. 
Add dummy rows to the output vector to fill the remaining rows + //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 + let last_row_idx = mem_ops.len(); + let addr = trace[last_row_idx].addr; + let mut step = trace[last_row_idx].step; + let value = trace[last_row_idx].value; + + for i in (mem_ops.len() + 1)..air.num_rows() { + step += F::one(); + + // TODO CHECK + // trace[i].mem_segment = segment_id_field; + // trace[i].mem_last_segment = is_last_segment_field; + + trace[i].addr = addr; + trace[i].step = step; + trace[i].sel = F::zero(); + trace[i].wr = F::zero(); + + trace[i].value = value; + + trace[i].addr_changes = F::zero(); + trace[i].same_value = F::one(); + trace[i].first_addr_access_is_read = F::zero(); + } + + let air_instance = AirInstance::new( + self.wcm.get_sctx(), + MEM_AIRGROUP_ID, + MEM_AIR_IDS[0], + Some(segment_id), + prover_buffer, + ); + + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); Ok(()) } @@ -218,23 +263,6 @@ impl MemSM { impl WitnessComponent for MemSM {} -impl Provable for MemSM { - fn prove(&self, operations: &[MemOp], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! 
Implement prove drained_inputs (a chunk of operations) - }); - } - } - } -} - #[cfg(test)] mod tests { // use super::*; diff --git a/state-machines/rom/src/rom.rs b/state-machines/rom/src/rom.rs index 0700bef8..9fe9aa4d 100644 --- a/state-machines/rom/src/rom.rs +++ b/state-machines/rom/src/rom.rs @@ -7,7 +7,7 @@ use proofman_util::create_buffer_fast; use zisk_core::{Riscv2zisk, ZiskPcHistogram, ZiskRom, SRC_IMM}; use zisk_pil::{ - Pilout, Rom0Row, Rom0Trace, MAIN_AIRGROUP_ID, MAIN_AIR_IDS, ROM_AIRGROUP_ID, ROM_AIR_IDS, + Pilout, Rom0Row, Rom0Trace, ZISK_AIRGROUP_ID, MAIN_AIR_IDS, ROM_AIR_IDS, }; //use ziskemu::ZiskEmulatorErr; use std::error::Error; @@ -22,7 +22,7 @@ impl RomSM { let rom_sm = Arc::new(rom_sm); let rom_air_ids = ROM_AIR_IDS; - wcm.register_component(rom_sm.clone(), Some(ROM_AIRGROUP_ID), Some(rom_air_ids)); + wcm.register_component(rom_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(rom_air_ids)); rom_sm } @@ -41,13 +41,13 @@ impl RomSM { let sctx = self.wcm.get_sctx(); let num_rows = - self.wcm.get_pctx().pilout.get_air(MAIN_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows(); + self.wcm.get_pctx().pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows(); let prover_buffer = Self::compute_trace_rom(rom, buffer_allocator, &sctx, pc_histogram, num_rows as u64)?; let air_instance = - AirInstance::new(sctx.clone(), ROM_AIRGROUP_ID, MAIN_AIR_IDS[0], None, prover_buffer); + AirInstance::new(sctx.clone(), ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0], None, prover_buffer); self.wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, Some(instance_gid)); @@ -88,13 +88,13 @@ impl RomSM { main_trace_len: u64, ) -> Result, Box> { let pilout = Pilout::pilout(); - let num_rows = pilout.get_air(ROM_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(); + let num_rows = pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(); let number_of_instructions = rom.insts.len(); // Allocate a prover buffer let (buffer_size, offsets) = buffer_allocator - .get_buffer_info(sctx, ROM_AIRGROUP_ID, 
ROM_AIR_IDS[0]) + .get_buffer_info(sctx, ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]) .unwrap_or_else(|err| panic!("Error getting buffer info: {}", err)); let mut prover_buffer = create_buffer_fast(buffer_size as usize); From c53b5001d533a0503536d58f6b0f213e36b2eef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Tue, 5 Nov 2024 08:13:48 +0000 Subject: [PATCH 04/44] First commit --- state-machines/main/pil/main.pil | 30 +- state-machines/mem/src/lib.rs | 2 + state-machines/mem/src/mem_align_rom_sm.rs | 190 ++++++++ state-machines/mem/src/mem_align_sm.rs | 495 +++++++++++++++++++-- 4 files changed, 676 insertions(+), 41 deletions(-) create mode 100644 state-machines/mem/src/mem_align_rom_sm.rs diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index 027bee94..2d6c3d24 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -79,7 +79,7 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope col witness air.b_imm1; } col witness b_src_ind; - col witness ind_width; // 8 , 4, 2, 1 + col witness ind_width; // 8, 4, 2, 1 // Operations related @@ -135,17 +135,17 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope } // Mem.load - //mem_load(sel: a_src_mem, - // step: addr_step, - // addr: addr0, - // value: a); + mem_load(sel: a_src_mem, + step: addr_step, + addr: addr0, + value: a); // Mem.load - //mem_load(sel: sel_mem_b, - // step: addr_step + 1, - // bytes: ind_width, - // addr: addr1, - // value: b); + mem_load(sel: sel_mem_b, + step: addr_step + 1, + bytes: ind_width, + addr: addr1, + value: b); const expr store_value[2]; @@ -153,11 +153,11 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope store_value[1] = (1 - store_ra) * c[1]; // Mem.store - //mem_store(sel: store_mem + store_ind, - // step: addr_step + 2, - // bytes: ind_width, - // addr: addr2, - // value: store_value); + mem_store(sel: store_mem + 
store_ind, + step: addr_step + 2, + bytes: ind_width, + addr: addr2, + value: store_value); // Operation.assume => how organize software col witness __debug_operation_bus_enabled; diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 47dd31fd..a2de5bc2 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,7 +1,9 @@ mod mem_align_sm; +mod mem_align_rom_sm; mod mem_sm; mod mem_proxy; pub use mem_align_sm::*; +pub use mem_align_rom_sm::*; pub use mem_sm::*; pub use mem_proxy::*; diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs new file mode 100644 index 00000000..cca9dea6 --- /dev/null +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -0,0 +1,190 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use log::info; +use p3_field::Field; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::AirInstance; +use rayon::prelude::*; +use sm_common::create_prover_buffer; +use zisk_core::{zisk_ops::ZiskOp, P2_11, P2_19, P2_8}; +use zisk_pil::{MEM_UNALIGNED_ROM_AIRGROUP_ID, MEM_UNALIGNED_ROM_AIR_IDS}; + +const MEM_WIDTHS: [u64; 4] = [1, 2, 4, 8]; +const PROGRAM_SIZES: [u64; 4] = [2, 3, 3, 5]; + +pub struct MemUnalignedRomSM { + wcm: Arc>, + + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Rom data + num_rows: usize, + line: Mutex, + multiplicity: Mutex>, +} + +#[derive(Debug)] +pub enum ExtensionTableSMErr { + InvalidOpcode, +} + +impl MemUnalignedRomSM { + const MY_NAME: &'static str = "MemUnalignedRom"; + + pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let pctx = wcm.get_pctx(); + let air = pctx + .pilout + .get_air(MEM_UNALIGNED_ROM_AIRGROUP_ID, MEM_UNALIGNED_ROM_AIR_IDS[0]); + + let mem_unaligned_rom = Self { + wcm: wcm.clone(), + registered_predecessors: AtomicU32::new(0), + num_rows: air.num_rows(), + line: 0, + multiplicity: Mutex::new(vec![0; air.num_rows()]), + 
}; + let mem_unaligned_rom = Arc::new(mem_unaligned_rom); + wcm.register_component(mem_unaligned_rom.clone(), Some(airgroup_id), Some(air_ids)); + + mem_unaligned_rom + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + self.create_air_instance(); + } + } + + pub fn process_slice(&self, input: &[u64]) { + let mut multiplicity = self.multiplicity.lock().unwrap(); + + for (i, val) in input.iter().enumerate() { + multiplicity[i] += *val; + } + } + + //lookup_proves(MEM_UNALIGNED_ROM_ID, [OP, OFFSET, A, B, C0, C1], multiplicity); + // lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); + pub fn calculate_rom_row(opcode: MemUnalignedRomOp, offset: u64, a: u64, b: u64) -> u64 { + // Calculate the different row offset contributors, according to the PIL + assert!(a <= 0xff); + let offset_a: u64 = a; + assert!(offset < 0x08); + let offset_offset: u64 = offset * P2_8; + assert!(b <= 0x3f); + let offset_b: u64 = b * P2_11; + let offset_opcode: u64 = Self::offset_opcode(opcode); + + offset_a + offset_offset + offset_b + offset_opcode + } + + pub fn get_program(offset: u64, width: u64, is_wr: bool) -> usize { + match (is_wr, offset + width > 8) { + (false, false) => 0, // RV // TODO: Use an enum instead! 
+ (true, false) => 1, // RWV + (false, true) => 2, // RVR + (true, true) => 3, // RWVWR + } + } + + pub fn get_program_size(offset: u64, width: u64, is_wr: bool) -> usize { + PROGRAM_SIZES[Self::get_program(offset, width, is_wr)] + } + + // TODO + pub fn calculate_next_pc(offset: u8, width: u8, is_wr: bool) -> u64 { + match (offset, width) { + (x,1) if x < 5 => (x+1) * PROGRAM_SIZES[0] - 1, + (x,2) => 2 * PROGRAM_SIZES[0] - 1, + (x,4) => 3 * PROGRAM_SIZES[0] - 1, + (x,8) => panic!("Aligned Memory access: offset=0, width=8"), + + (1,1) => 4 * PROGRAM_SIZES[0] - 1, + (1,2) => 5 * PROGRAM_SIZES[0] - 1, + (1,4) => 6 * PROGRAM_SIZES[0] - 1, + // (1,8) => 7 * PROGRAM_SIZES[0] - 1, // Two words + + (2,1) => 4 * PROGRAM_SIZES[0] - 1, + (2,2) => 5 * PROGRAM_SIZES[0] - 1, + (2,4) => 6 * PROGRAM_SIZES[0] - 1, + // (2,8) => 7 * PROGRAM_SIZES[0] - 1, // Two words + } + } + + fn offset_opcode(opcode: MemUnalignedRomOp) -> u64 { + match opcode { + MemUnalignedRomOp::Sll => 0, + MemUnalignedRomOp::Srl => P2_19, + MemUnalignedRomOp::Sra => 2 * P2_19, + MemUnalignedRomOp::SllW => 3 * P2_19, + MemUnalignedRomOp::SrlW => 4 * P2_19, + MemUnalignedRomOp::SraW => 5 * P2_19, + MemUnalignedRomOp::SignExtendB => 6 * P2_19, + MemUnalignedRomOp::SignExtendH => 6 * P2_19 + P2_11, + MemUnalignedRomOp::SignExtendW => 6 * P2_19 + 2 * P2_11, + //_ => panic!("MemUnalignedRomSM::offset_opcode() got invalid opcode={:?}", opcode), + } + } + + pub fn create_air_instance(&self) { + let ectx = self.wcm.get_ectx(); + let mut dctx: std::sync::RwLockWriteGuard<'_, proofman_common::DistributionCtx> = + ectx.dctx.write().unwrap(); + + let mut multiplicity = self.multiplicity.lock().unwrap(); + + let (is_myne, instance_global_idx) = dctx.add_instance( + MEM_UNALIGNED_ROM_AIRGROUP_ID, + MEM_UNALIGNED_ROM_AIR_IDS[0], + 1, + ); + let owner = dctx.owner(instance_global_idx); + + let mut multiplicity_ = std::mem::take(&mut *multiplicity); + dctx.distribute_multiplicity(&mut multiplicity_, owner); + + if is_myne { + 
// Create the prover buffer + let (mut prover_buffer, offset) = create_prover_buffer( + &self.wcm.get_ectx(), + &self.wcm.get_sctx(), + MEM_UNALIGNED_ROM_AIRGROUP_ID, + MEM_UNALIGNED_ROM_AIR_IDS[0], + ); + + prover_buffer[offset as usize..offset as usize + self.num_rows] + .par_iter_mut() + .enumerate() + .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); + + info!( + "{}: ··· Creating Binary extension table instance [{} rows filled 100%]", + Self::MY_NAME, + self.num_rows, + ); + + let air_instance = AirInstance::new( + self.wcm.get_sctx(), + MEM_UNALIGNED_ROM_AIRGROUP_ID, + MEM_UNALIGNED_ROM_AIR_IDS[0], + None, + prover_buffer, + ); + self.wcm + .get_pctx() + .air_instance_repo + .add_air_instance(air_instance, Some(instance_global_idx)); + } + } +} + +impl WitnessComponent for MemUnalignedRomSM {} \ No newline at end of file diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 0eeb4c38..a2f0105c 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -7,33 +7,37 @@ use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; -use sm_common::{MemOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_ALIGN_AIR_IDS}; +use sm_common::{MemUnalignedOp, OpResult, Provable}; +use zisk_core::ZiskRequiredMemory; +use zisk_pil::{MEM_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS}; const PROVE_CHUNK_SIZE: usize = 1 << 12; +const CHUNKS: u64 = 8; -pub struct MemAlignSM { +pub struct MemUnalignedSM { // Count of registered predecessors registered_predecessors: AtomicU32, // Inputs - inputs: Mutex>, + inputs: Mutex>, } #[allow(unused, unused_variables)] -impl MemAlignSM { +impl MemUnalignedSM { + const MY_NAME: &'static str = "MemUnaligned"; + pub fn new(wcm: Arc>) -> Arc { - let mem_aligned_sm = + let mem_unaligned_sm = Self { registered_predecessors: AtomicU32::new(0), inputs: 
Mutex::new(Vec::new()) }; - let mem_aligned_sm = Arc::new(mem_aligned_sm); + let mem_unaligned_sm = Arc::new(mem_unaligned_sm); wcm.register_component( - mem_aligned_sm.clone(), + mem_unaligned_sm.clone(), Some(MEM_AIRGROUP_ID), - Some(MEM_ALIGN_AIR_IDS), + Some(MEM_UNALIGNED_AIR_IDS), ); - mem_aligned_sm + mem_unaligned_sm } pub fn register_predecessor(&self) { @@ -42,13 +46,14 @@ impl MemAlignSM { pub fn unregister_predecessor(&self, scope: &Scope) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); + >::prove(self, &[], true, scope); } } fn read( &self, - _addr: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ + _addr: u64, + _width: usize, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ ) -> Result> { Ok((0, true)) } @@ -56,13 +61,417 @@ impl MemAlignSM { fn write( &self, _addr: u64, + _width: usize, _val: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ ) -> Result> { Ok((0, true)) } + + pub fn process_slice( + input: &Vec, + multiplicity: &mut [u64], + range_check: &mut HashMap, + ) -> Vec> { + // Is a write or a read operation + let wr = input[0].is_write; + + // Get the address + let addr = input[0].address; + let addr_prior = input[1].address; // addr / CHUNKS; + let addr_next = input[2].address; // addr / CHUNKS + CHUNKS; + + // Get the value + let value = input[0].value; + let value_first_read = input[1].value; + let value_first_write = input[2].value; + let value_second_read = input[3].value; + let value_second_write = input[4].value; + + // Get the step + let step = input[0].step; + let step_first_read = input[1].step; + let step_first_write = input[2].step; + let step_second_read = input[3].step; + let step_second_write = input[4].step; + + // Get the offset + let offset = addr % CHUNKS; + + // Get the width + let width = input[0].width; + + // Compute the shift + let shift = (offset + width) % CHUNKS; + + // Get the program to be executed, its size and the pc to jump to + let 
program = MemUnalignedRomSM::get_program(offset, width, wr); + let program_size = MemUnalignedRomSM::get_program_size(offset, width, wr); + let next_pc = MemUnalignedRomSM::calculate_next_pc(offset, width, wr); + + // Initialize and set the rows of the corresponding program + let mut rows: Vec> = Vec::with_capacity(program_size); + match program { + 0 => { // RV + let mut read_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step_first_read), + addr: F::from_canonical_u64(addr_prior), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr), + offset: F::from_canonical_u64(offset), + width: F::from_canonical_u64(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNKS { + read_row.reg[i] = F::from_canonical_u64(value_first_read[i]); + read_row.sel[i] = F::from_bool(true); + + value_row.reg[i] = F::from_canonical_u64(value[shift + i]); + value_row.sel[i] = F::from_bool(i == offset); + + // Store the range check + *range_check.entry(read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + } + + // Store the rows + rows.push(read_row); + rows.push(value_row); + }, + 1 => { // RWV + let mut read_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step_first_read), + addr: F::from_canonical_u64(addr_prior), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut write_row = MemUnalign0Row:: { + step: 
F::from_canonical_u64(step_first_write), + addr: F::from_canonical_u64(addr_prior), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr), + offset: F::from_canonical_u64(offset), + width: F::from_canonical_u64(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNKS { + read_row.reg[i] = F::from_canonical_u64(value_first_read[i]); + read_row.sel[i] = F::from_bool(i < offset); + + write_row.reg[i] = F::from_canonical_u64(value_first_write[i]); + write_row.sel[i] = F::from_bool(i >= offset); + + value_row.reg[i] = F::from_canonical_u64(value[shift + i]); + value_row.sel[i] = F::from_bool(i == offset); + + // Store the range check + *range_check.entry(read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(write_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + } + + // Store the rows + rows.push(read_row); + rows.push(write_row); + rows.push(value_row); + } + 2 => { + // RVR + let mut first_read_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step_first_read), + addr: F::from_canonical_u64(addr_prior), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr), + offset: F::from_canonical_u64(offset), + width: F::from_canonical_u64(width), + // wr: F::from_bool(false), + pc: 
F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step_second_read), + addr: F::from_canonical_u64(addr_next), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNKS { + first_read_row.reg[i] = F::from_canonical_u64(value_first_read[i]); + first_read_row.sel[i] = F::from_bool(true); + + value_row.reg[i] = F::from_canonical_u64(value[shift + i]); + value_row.sel[i] = F::from_bool(i == offset); + + second_read_row.reg[i] = F::from_canonical_u64(value_second_read[i]); + second_read_row.sel[i] = F::from_bool(true); + + // Store the range check + *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + } + + // Store the rows + rows.push(first_read_row); + rows.push(value_row); + rows.push(second_read_row); + } + 3 => { + // RWVWR + let mut first_read_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step_first_read), + addr: F::from_canonical_u64(addr_prior), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut first_write_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step_first_write), + addr: F::from_canonical_u64(addr_prior), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + 
let mut value_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr), + offset: F::from_canonical_u64(offset), + width: F::from_canonical_u64(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_write_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step_second_write), + addr: F::from_canonical_u64(addr_next), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 2), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemUnalign0Row:: { + step: F::from_canonical_u64(step_second_read), + addr: F::from_canonical_u64(addr_next), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNKS), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 3), + reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNKS { + first_read_row.reg[i] = F::from_canonical_u64(value_first_read[i]); + first_read_row.sel[i] = F::from_bool(i < offset); + + first_write_row.reg[i] = F::from_canonical_u64(value_first_write[i]); + first_write_row.sel[i] = F::from_bool(i >= offset); + + value_row.reg[i] = F::from_canonical_u64(value[shift + i]); + value_row.sel[i] = F::from_bool(i == offset); + + second_write_row.reg[i] = F::from_canonical_u64(value_second_write[i]); + second_write_row.sel[i] = F::from_bool(i < shift); + + second_read_row.reg[i] = F::from_canonical_u64(value_second_read[i]); + second_read_row.sel[i] = F::from_bool(i >= shift); + + // Store the range check + *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + 
*range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; + *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + } + + // Store the rows + rows.push(first_read_row); + rows.push(first_write_row); + rows.push(value_row); + rows.push(second_write_row); + rows.push(second_read_row); + } + _ => panic!("MemUnalignedSM::process_slice() got invalid program={}", program), + } + + // TBD + // for (i, a_byte) in a_bytes.iter().enumerate() { + // let row = MemUnalignedRomSM::::calculate_table_row( + // mem_unaligned_rom_op, + // i as u64, + // *a_byte as u64, + // in2_low, + // ); + // multiplicity[row as usize] += 1; + // } + + // Return successfully + rows + } + + pub fn prove_instance( + &self, + operations: Vec, + prover_buffer: &mut [F], + offset: u64, + ) { + Self::prove_internal( + &self.wcm, + &self.mem_unaligned_rom_sm, + &self.std, + operations, + prover_buffer, + offset, + ); + } + + fn prove_internal( + wcm: &WitnessManager, + mem_unaligned_rom_sm: &MemUnalignedRomSM, + std: &Std, + operations: Vec, + prover_buffer: &mut [F], + offset: u64, + ) { + let pctx = wcm.get_pctx(); + + let air = pctx.pilout.get_air(MEM_UNALIGNED_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS[0]); + let air_mem_unaligned_rom = pctx + .pilout + .get_air(MEM_UNALIGNED_ROM_AIRGROUP_ID, MEM_UNALIGNED_ROM_AIR_IDS[0]); + assert!(operations.len() <= air.num_rows()); + + info!( + "{}: ··· Creating Binary extension instance [{} / {} rows filled {:.2}%]", + Self::MY_NAME, + operations.len(), + air.num_rows(), + operations.len() as f64 / air.num_rows() as f64 * 100.0 + ); + + let mut multiplicity_table = vec![0u64; air_mem_unaligned_rom.num_rows()]; + let mut range_check: HashMap = HashMap::new(); + let mut trace_buffer = + BinaryExtension0Trace::::map_buffer(prover_buffer, air.num_rows(), offset as usize) + .unwrap(); + + for (i, operation) in operations.iter().enumerate() { + let row = Self::process_slice(operation, &mut multiplicity_table, &mut range_check); + trace_buffer[i] = row; + } + + 
let padding_row = + BinaryExtension0Row:: { op: F::from_canonical_u64(0x25), ..Default::default() }; + + for i in operations.len()..air.num_rows() { + trace_buffer[i] = padding_row; + } + + let padding_size = air.num_rows() - operations.len(); + for i in 0..8 { + let multiplicity = padding_size as u64; + let row = MemUnalignedRomSM::::calculate_table_row( + BinaryExtensionTableOp::SignExtendW, + i, + 0, + 0, + ); + multiplicity_table[row as usize] += multiplicity; + } + + mem_unaligned_rom_sm.process_slice(&multiplicity_table); + + let range_id = std.get_range(BigInt::from(0), BigInt::from(0xFFFFFF), None); + + for (value, multiplicity) in &range_check { + std.range_check( + F::from_canonical_u64(*value), + F::from_canonical_u64(*multiplicity), + range_id, + ); + } + + + std::thread::spawn(move || { + drop(operations); + drop(multiplicity_table); + drop(range_check); + }); + } } -impl WitnessComponent for MemAlignSM { +impl WitnessComponent for MemUnalignedSM { fn calculate_witness( &self, _stage: u32, @@ -74,32 +483,66 @@ impl WitnessComponent for MemAlignSM { } } -impl Provable for MemAlignSM { - fn calculate(&self, operation: MemOp) -> Result> { +impl Provable for MemUnalignedSM { + fn calculate(&self, operation: MemUnalignedOp) -> Result> { + // TODO: Perform the aligned read/writes + match operation { - MemOp::Read(addr) => self.read(addr), - MemOp::Write(addr, val) => self.write(addr, val), + MemUnalignedOp::Read(addr, width) => self.read(addr, width), + MemUnalignedOp::Write(addr, width, val) => self.write(addr, width, val), } } - fn prove(&self, operations: &[MemOp], drain: bool, scope: &Scope) { + fn prove(&self, operations: &[ZiskRequiredMemory], drain: bool, _scope: &Scope) { if let Ok(mut inputs) = self.inputs.lock() { inputs.extend_from_slice(operations); - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = 
inputs.drain(..num_drained).collect::>(); + let pctx = self.wcm.get_pctx(); + let air = + pctx.pilout.get_air(MEM_UNALIGNED_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS[0]); + + while inputs.len() >= air.num_rows() || (drain && !inputs.is_empty()) { + let num_drained = std::cmp::min(air.num_rows(), inputs.len()); + let drained_inputs = inputs.drain(..num_drained).collect::>(); + + let mem_unaligned_rom_sm = self.mem_unaligned_rom_sm.clone(); + let wcm = self.wcm.clone(); - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); + let std = self.std.clone(); + + let sctx = self.wcm.get_sctx().clone(); + + let (mut prover_buffer, offset) = create_prover_buffer( + &wcm.get_ectx(), + &wcm.get_sctx(), + MEM_UNALIGNED_AIRGROUP_ID, + MEM_UNALIGNED_AIR_IDS[0], + ); + + Self::prove_internal( + &wcm, + &mem_unaligned_rom_sm, + &std, + drained_inputs, + &mut prover_buffer, + offset, + ); + + let air_instance = AirInstance::new( + sctx, + MEM_UNALIGNED_AIRGROUP_ID, + MEM_UNALIGNED_AIR_IDS[0], + None, + prover_buffer, + ); + wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, None); } } } fn calculate_prove( &self, - operation: MemOp, + operation: MemUnalignedOp, drain: bool, scope: &Scope, ) -> Result> { @@ -109,4 +552,4 @@ impl Provable for MemAlignSM { result } -} +} \ No newline at end of file From 0f4441fbcf7fe63f78733594b5b31ef6fcb51d09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Tue, 5 Nov 2024 17:24:52 +0000 Subject: [PATCH 05/44] First version of mem_align executor done --- state-machines/mem/pil/mem_align_rom.pil | 2 +- state-machines/mem/src/lib.rs | 8 +- state-machines/mem/src/mem_align_rom_sm.rs | 174 ++++---- state-machines/mem/src/mem_align_sm.rs | 472 +++++++++++---------- 4 files changed, 341 insertions(+), 315 deletions(-) diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil index 211da1dd..5f5d8b6b 100644 --- a/state-machines/mem/pil/mem_align_rom.pil 
+++ b/state-machines/mem/pil/mem_align_rom.pil @@ -37,7 +37,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3]]...; // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 col fixed WIDTH = [[[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV - [[8,8,1,8,8,2,8,8,4], [8,8,1,8,8,2,8,8,4]:4, [8,8,1,8,8,2]:2, [8,8,1]], // RWV + [[8,8,1,8,8,2,8,8,4]:5, [8,8,1,8,8,2]:2, [8,8,1]], // RWV [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]]]...; // RWVWR diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index a2de5bc2..f117ca7c 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,9 +1,9 @@ -mod mem_align_sm; mod mem_align_rom_sm; -mod mem_sm; +mod mem_align_sm; mod mem_proxy; +mod mem_sm; -pub use mem_align_sm::*; pub use mem_align_rom_sm::*; -pub use mem_sm::*; +pub use mem_align_sm::*; pub use mem_proxy::*; +pub use mem_sm::*; diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index cca9dea6..3406fc64 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -8,14 +8,18 @@ use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use rayon::prelude::*; + use sm_common::create_prover_buffer; -use zisk_core::{zisk_ops::ZiskOp, P2_11, P2_19, P2_8}; -use zisk_pil::{MEM_UNALIGNED_ROM_AIRGROUP_ID, MEM_UNALIGNED_ROM_AIR_IDS}; +use zisk_pil::{MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; + +use crate::MemOps; +const CHUNKS: usize = 8; const MEM_WIDTHS: [u64; 4] = [1, 2, 4, 8]; -const PROGRAM_SIZES: [u64; 4] = [2, 3, 3, 5]; +const OP_SIZES: [usize; 4] = [2, 3, 3, 5]; -pub struct MemUnalignedRomSM { +pub struct MemAlignRomSM { + // Witness computation manager wcm: Arc>, // Count of registered 
predecessors @@ -32,26 +36,24 @@ pub enum ExtensionTableSMErr { InvalidOpcode, } -impl MemUnalignedRomSM { - const MY_NAME: &'static str = "MemUnalignedRom"; +impl MemAlignRomSM { + const MY_NAME: &'static str = "MemAlignRom"; pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { let pctx = wcm.get_pctx(); - let air = pctx - .pilout - .get_air(MEM_UNALIGNED_ROM_AIRGROUP_ID, MEM_UNALIGNED_ROM_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); - let mem_unaligned_rom = Self { + let mem_align_rom = Self { wcm: wcm.clone(), registered_predecessors: AtomicU32::new(0), num_rows: air.num_rows(), - line: 0, + line: Mutex::new(0), multiplicity: Mutex::new(vec![0; air.num_rows()]), }; - let mem_unaligned_rom = Arc::new(mem_unaligned_rom); - wcm.register_component(mem_unaligned_rom.clone(), Some(airgroup_id), Some(air_ids)); + let mem_align_rom = Arc::new(mem_align_rom); + wcm.register_component(mem_align_rom.clone(), Some(airgroup_id), Some(air_ids)); - mem_unaligned_rom + mem_align_rom } pub fn register_predecessor(&self) { @@ -64,74 +66,85 @@ impl MemUnalignedRomSM { } } - pub fn process_slice(&self, input: &[u64]) { - let mut multiplicity = self.multiplicity.lock().unwrap(); - - for (i, val) in input.iter().enumerate() { - multiplicity[i] += *val; - } + pub fn get_op_size(op: MemOps) -> usize { + OP_SIZES[op as usize] } - //lookup_proves(MEM_UNALIGNED_ROM_ID, [OP, OFFSET, A, B, C0, C1], multiplicity); - // lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); - pub fn calculate_rom_row(opcode: MemUnalignedRomOp, offset: u64, a: u64, b: u64) -> u64 { - // Calculate the different row offset contributors, according to the PIL - assert!(a <= 0xff); - let offset_a: u64 = a; - assert!(offset < 0x08); - let offset_offset: u64 = offset * P2_8; - assert!(b <= 0x3f); - let offset_b: u64 = b * P2_11; - let offset_opcode: u64 = Self::offset_opcode(opcode); - - offset_a + offset_offset + 
offset_b + offset_opcode + pub fn calculate_rom_rows(opcode: MemOps, offset: usize, width: usize) -> Vec { + match opcode { + MemOps::OneRead | MemOps::OneWrite => { + // Sanity check + assert!(offset + width <= CHUNKS); + let possible_widths = match offset { + x if x <= 4 => vec![1, 2, 4], + x if x <= 6 => vec![1, 2], + x if x == 7 => vec![1], + _ => panic!("Invalid offset={}", offset), + }; + Self::get_rows(opcode, possible_widths, offset, width) + } + MemOps::TwoReads | MemOps::TwoWrites => { + // Sanity check + assert!(offset + width > CHUNKS); + let possible_widths = match offset { + x if x == 0 => panic!("Invalid offset={}", offset), + x if x <= 4 => vec![8], + x if x <= 6 => vec![4, 8], + x if x == 7 => vec![2, 4, 8], + _ => panic!("Invalid offset={}", offset), + }; + Self::get_rows(opcode, possible_widths, offset, width) + } + } } - pub fn get_program(offset: u64, width: u64, is_wr: bool) -> usize { - match (is_wr, offset + width > 8) { - (false, false) => 0, // RV // TODO: Use an enum instead! 
- (true, false) => 1, // RWV - (false, true) => 2, // RVR - (true, true) => 3, // RWVWR + fn get_rows( + opcode: MemOps, + possible_widths: Vec, + offset: usize, + width: usize, + ) -> Vec { + // Sanity check + assert!(possible_widths.contains(&width)); + + let width_idx = possible_widths.iter().position(|&w| w == width).unwrap(); + let opcode_idx = opcode as usize; + match opcode { + MemOps::OneRead | MemOps::OneWrite => { + let value_row = offset * possible_widths.len() * OP_SIZES[opcode_idx] + + (offset + width_idx + 1) * OP_SIZES[opcode_idx] + - 1; + match opcode { + MemOps::OneRead => vec![value_row - 1, value_row], + MemOps::OneWrite => vec![value_row - 2, value_row - 1, value_row], + _ => unreachable!(), + } + } + MemOps::TwoReads => { + let value_row = offset * possible_widths.len() * OP_SIZES[opcode_idx] + + (offset + width_idx + 1) * OP_SIZES[opcode_idx] + - 2; + return vec![value_row - 1, value_row, value_row + 1]; + } + MemOps::TwoWrites => { + let value_row = offset * possible_widths.len() * OP_SIZES[opcode_idx] + + (offset + width_idx + 1) * OP_SIZES[opcode_idx] + - 3; + return vec![value_row - 2, value_row - 1, value_row, value_row + 1, value_row + 2]; + } } } - pub fn get_program_size(offset: u64, width: u64, is_wr: bool) -> usize { - PROGRAM_SIZES[Self::get_program(offset, width, is_wr)] + pub fn calculate_next_pc(op: MemOps, offset: usize, width: usize) -> usize { + let rows = Self::calculate_rom_rows(op, offset, width); + rows[1] } - // TODO - pub fn calculate_next_pc(offset: u8, width: u8, is_wr: bool) -> u64 { - match (offset, width) { - (x,1) if x < 5 => (x+1) * PROGRAM_SIZES[0] - 1, - (x,2) => 2 * PROGRAM_SIZES[0] - 1, - (x,4) => 3 * PROGRAM_SIZES[0] - 1, - (x,8) => panic!("Aligned Memory access: offset=0, width=8"), - - (1,1) => 4 * PROGRAM_SIZES[0] - 1, - (1,2) => 5 * PROGRAM_SIZES[0] - 1, - (1,4) => 6 * PROGRAM_SIZES[0] - 1, - // (1,8) => 7 * PROGRAM_SIZES[0] - 1, // Two words - - (2,1) => 4 * PROGRAM_SIZES[0] - 1, - (2,2) => 5 * 
PROGRAM_SIZES[0] - 1, - (2,4) => 6 * PROGRAM_SIZES[0] - 1, - // (2,8) => 7 * PROGRAM_SIZES[0] - 1, // Two words - } - } + pub fn process_slice(&self, input: &[u8]) { + let mut multiplicity = self.multiplicity.lock().unwrap(); - fn offset_opcode(opcode: MemUnalignedRomOp) -> u64 { - match opcode { - MemUnalignedRomOp::Sll => 0, - MemUnalignedRomOp::Srl => P2_19, - MemUnalignedRomOp::Sra => 2 * P2_19, - MemUnalignedRomOp::SllW => 3 * P2_19, - MemUnalignedRomOp::SrlW => 4 * P2_19, - MemUnalignedRomOp::SraW => 5 * P2_19, - MemUnalignedRomOp::SignExtendB => 6 * P2_19, - MemUnalignedRomOp::SignExtendH => 6 * P2_19 + P2_11, - MemUnalignedRomOp::SignExtendW => 6 * P2_19 + 2 * P2_11, - //_ => panic!("MemUnalignedRomSM::offset_opcode() got invalid opcode={:?}", opcode), + for (i, val) in input.iter().enumerate() { + multiplicity[i] += *val as u64; } } @@ -142,11 +155,8 @@ impl MemUnalignedRomSM { let mut multiplicity = self.multiplicity.lock().unwrap(); - let (is_myne, instance_global_idx) = dctx.add_instance( - MEM_UNALIGNED_ROM_AIRGROUP_ID, - MEM_UNALIGNED_ROM_AIR_IDS[0], - 1, - ); + let (is_myne, instance_global_idx) = + dctx.add_instance(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0], 1); let owner = dctx.owner(instance_global_idx); let mut multiplicity_ = std::mem::take(&mut *multiplicity); @@ -157,8 +167,8 @@ impl MemUnalignedRomSM { let (mut prover_buffer, offset) = create_prover_buffer( &self.wcm.get_ectx(), &self.wcm.get_sctx(), - MEM_UNALIGNED_ROM_AIRGROUP_ID, - MEM_UNALIGNED_ROM_AIR_IDS[0], + ZISK_AIRGROUP_ID, + MEM_ALIGN_ROM_AIR_IDS[0], ); prover_buffer[offset as usize..offset as usize + self.num_rows] @@ -174,8 +184,8 @@ impl MemUnalignedRomSM { let air_instance = AirInstance::new( self.wcm.get_sctx(), - MEM_UNALIGNED_ROM_AIRGROUP_ID, - MEM_UNALIGNED_ROM_AIR_IDS[0], + ZISK_AIRGROUP_ID, + MEM_ALIGN_ROM_AIR_IDS[0], None, prover_buffer, ); @@ -187,4 +197,4 @@ impl MemUnalignedRomSM { } } -impl WitnessComponent for MemUnalignedRomSM {} \ No newline at end of file +impl 
WitnessComponent for MemAlignRomSM {} diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index a2f0105c..61b197b5 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -1,91 +1,132 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, + }, }; -use p3_field::Field; +use log::info; +use num_bigint::BigInt; +use p3_field::PrimeField; +use pil_std_lib::Std; use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; +use proofman_common::AirInstance; use rayon::Scope; -use sm_common::{MemUnalignedOp, OpResult, Provable}; + +use sm_common::{create_prover_buffer, OpResult, Provable}; use zisk_core::ZiskRequiredMemory; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS}; +use zisk_pil::{ + MemAlign3Row, MemAlign3Trace, MEM_ALIGN_AIR_IDS, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID, +}; + +use crate::MemAlignRomSM; + +#[derive(Debug, Clone, Copy)] +pub enum MemOps { + OneRead, + OneWrite, + TwoReads, + TwoWrites, +} const PROVE_CHUNK_SIZE: usize = 1 << 12; -const CHUNKS: u64 = 8; +const CHUNKS: usize = 8; +const CHUNKS_U64: u64 = CHUNKS as u64; + +pub struct MemAlignSM { + // Witness computation manager + wcm: Arc>, + + // STD + std: Arc>, -pub struct MemUnalignedSM { // Count of registered predecessors registered_predecessors: AtomicU32, // Inputs - inputs: Mutex>, -} + inputs: Mutex>, -#[allow(unused, unused_variables)] -impl MemUnalignedSM { - const MY_NAME: &'static str = "MemUnaligned"; + // Secondary State machines + mem_align_rom_sm: Arc>, +} - pub fn new(wcm: Arc>) -> Arc { - let mem_unaligned_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let mem_unaligned_sm = Arc::new(mem_unaligned_sm); +impl MemAlignSM { + const MY_NAME: &'static str = "MemAlign"; + + pub fn new( + wcm: Arc>, 
+ std: Arc>, + mem_align_rom_sm: Arc>, + ) -> Arc { + let mem_align_sm = Self { + wcm: wcm.clone(), + std: std.clone(), + registered_predecessors: AtomicU32::new(0), + inputs: Mutex::new(Vec::new()), + mem_align_rom_sm, + }; + let mem_align_sm = Arc::new(mem_align_sm); wcm.register_component( - mem_unaligned_sm.clone(), - Some(MEM_AIRGROUP_ID), - Some(MEM_UNALIGNED_AIR_IDS), + mem_align_sm.clone(), + Some(ZISK_AIRGROUP_ID), + Some(MEM_ALIGN_AIR_IDS), ); - mem_unaligned_sm + // Register the predecessors + std.register_predecessor(); + mem_align_sm.mem_align_rom_sm.register_predecessor(); + + mem_align_sm } pub fn register_predecessor(&self) { self.registered_predecessors.fetch_add(1, Ordering::SeqCst); } - pub fn unregister_predecessor(&self, scope: &Scope) { + pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); + self.mem_align_rom_sm.unregister_predecessor(); + self.std.unregister_predecessor(self.wcm.get_pctx(), None); } } - fn read( - &self, - _addr: u64, - _width: usize, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) - } + #[inline(always)] + pub fn get_mem_ops(unaligned_input: &ZiskRequiredMemory) -> MemOps { + let addr = unaligned_input.address; + let width = unaligned_input.width; - fn write( - &self, - _addr: u64, - _width: usize, - _val: u64, /* , _ctx: &mut ProofCtx, _ectx: &ExecutionCtx */ - ) -> Result> { - Ok((0, true)) + let offset = addr % 8; + + match (unaligned_input.is_write, offset + width > 8) { + (false, false) => MemOps::OneRead, + (true, false) => MemOps::OneWrite, + (false, true) => MemOps::TwoReads, + (true, true) => MemOps::TwoWrites, + } } + #[inline(always)] pub fn process_slice( input: &Vec, - multiplicity: &mut [u64], - range_check: &mut HashMap, - ) -> Vec> { + multiplicity: &mut [u8], + range_check: &mut HashMap, + ) -> Vec> { // Is a write or a read operation - let wr = input[0].is_write; + let _wr = 
input[0].is_write; // Get the address let addr = input[0].address; let addr_prior = input[1].address; // addr / CHUNKS; - let addr_next = input[2].address; // addr / CHUNKS + CHUNKS; + let addr_next = input[2].address; // addr / CHUNKS + CHUNKS; // Get the value - let value = input[0].value; - let value_first_read = input[1].value; - let value_first_write = input[2].value; - let value_second_read = input[3].value; - let value_second_write = input[4].value; + let value = input[0].value.to_be_bytes(); + let value_first_read = input[1].value.to_be_bytes(); + let value_first_write = input[2].value.to_be_bytes(); + let value_second_read = input[3].value.to_be_bytes(); + let value_second_write = input[4].value.to_be_bytes(); // Get the step let step = input[0].step; @@ -95,28 +136,38 @@ impl MemUnalignedSM { let step_second_write = input[4].step; // Get the offset - let offset = addr % CHUNKS; + let offset = addr % CHUNKS_U64; + let offset = if offset <= usize::MAX as u64 { + offset as usize + } else { + panic!("MemAlignSM::process_slice() got invalid offset={}", offset) + }; // Get the width let width = input[0].width; + let width = if width <= CHUNKS_U64 { + width as usize + } else { + panic!("MemAlignSM::process_slice() got invalid width={}", width) + }; // Compute the shift let shift = (offset + width) % CHUNKS; - // Get the program to be executed, its size and the pc to jump to - let program = MemUnalignedRomSM::get_program(offset, width, wr); - let program_size = MemUnalignedRomSM::get_program_size(offset, width, wr); - let next_pc = MemUnalignedRomSM::calculate_next_pc(offset, width, wr); - - // Initialize and set the rows of the corresponding program - let mut rows: Vec> = Vec::with_capacity(program_size); - match program { - 0 => { // RV - let mut read_row = MemUnalign0Row:: { + // Get the op to be executed, its size and the pc to jump to + let op = Self::get_mem_ops(&input[0]); + let op_size = MemAlignRomSM::::get_op_size(op); + let next_pc = 
MemAlignRomSM::::calculate_next_pc(op, offset, width); + + // Initialize and set the rows of the corresponding op + let mut rows: Vec> = Vec::with_capacity(op_size); + match op { + MemOps::OneRead => { + // RV + let mut read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_read), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -124,40 +175,41 @@ impl MemUnalignedSM { ..Default::default() }; - let mut value_row = MemUnalign0Row:: { + let mut value_row = MemAlign3Row:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr), - offset: F::from_canonical_u64(offset), - width: F::from_canonical_u64(width), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc), + pc: F::from_canonical_usize(next_pc), // reset: F::from_bool(false), sel_prove: F::from_bool(true), ..Default::default() }; for i in 0..CHUNKS { - read_row.reg[i] = F::from_canonical_u64(value_first_read[i]); + read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); read_row.sel[i] = F::from_bool(true); - value_row.reg[i] = F::from_canonical_u64(value[shift + i]); + value_row.reg[i] = F::from_canonical_u8(value[shift + i]); value_row.sel[i] = F::from_bool(i == offset); // Store the range check - *range_check.entry(read_row.reg[i]).or_insert(0) += 1; - *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_first_read[i]).or_insert(0) += 1; + *range_check.entry(value[shift + i]).or_insert(0) += 1; } // Store the rows rows.push(read_row); rows.push(value_row); - }, - 1 => { // RWV - let mut read_row = MemUnalign0Row:: { + } + MemOps::OneWrite => { + // RWV + let mut read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_read), addr: F::from_canonical_u64(addr_prior), // offset: 
F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), + width: F::from_canonical_u64(CHUNKS_U64), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -165,44 +217,44 @@ impl MemUnalignedSM { ..Default::default() }; - let mut write_row = MemUnalign0Row:: { + let mut write_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_write), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), + width: F::from_canonical_u64(CHUNKS_U64), wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc), + pc: F::from_canonical_usize(next_pc), // reset: F::from_bool(false), sel_up_to_down: F::from_bool(true), ..Default::default() }; - let mut value_row = MemUnalign0Row:: { + let mut value_row = MemAlign3Row:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr), - offset: F::from_canonical_u64(offset), - width: F::from_canonical_u64(width), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 1), + pc: F::from_canonical_usize(next_pc + 1), // reset: F::from_bool(false), sel_prove: F::from_bool(true), ..Default::default() }; for i in 0..CHUNKS { - read_row.reg[i] = F::from_canonical_u64(value_first_read[i]); + read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); read_row.sel[i] = F::from_bool(i < offset); - write_row.reg[i] = F::from_canonical_u64(value_first_write[i]); + write_row.reg[i] = F::from_canonical_u8(value_first_write[i]); write_row.sel[i] = F::from_bool(i >= offset); - value_row.reg[i] = F::from_canonical_u64(value[shift + i]); + value_row.reg[i] = F::from_canonical_u8(value[shift + i]); value_row.sel[i] = F::from_bool(i == offset); // Store the range check - *range_check.entry(read_row.reg[i]).or_insert(0) += 1; - *range_check.entry(write_row.reg[i]).or_insert(0) += 1; - *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + 
*range_check.entry(value_first_read[i]).or_insert(0) += 1; + *range_check.entry(value_first_write[i]).or_insert(0) += 1; + *range_check.entry(value[shift + i]).or_insert(0) += 1; } // Store the rows @@ -210,13 +262,13 @@ impl MemUnalignedSM { rows.push(write_row); rows.push(value_row); } - 2 => { + MemOps::TwoReads => { // RVR - let mut first_read_row = MemUnalign0Row:: { + let mut first_read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_read), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), + width: F::from_canonical_u64(CHUNKS_U64), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -224,44 +276,44 @@ impl MemUnalignedSM { ..Default::default() }; - let mut value_row = MemUnalign0Row:: { + let mut value_row = MemAlign3Row:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr), - offset: F::from_canonical_u64(offset), - width: F::from_canonical_u64(width), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc), + pc: F::from_canonical_usize(next_pc), // reset: F::from_bool(false), sel_prove: F::from_bool(true), ..Default::default() }; - let mut second_read_row = MemUnalign0Row:: { + let mut second_read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_second_read), addr: F::from_canonical_u64(addr_next), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), + width: F::from_canonical_u64(CHUNKS_U64), // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 1), + pc: F::from_canonical_usize(next_pc + 1), // reset: F::from_bool(false), sel_down_to_up: F::from_bool(true), ..Default::default() }; for i in 0..CHUNKS { - first_read_row.reg[i] = F::from_canonical_u64(value_first_read[i]); + first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); first_read_row.sel[i] = F::from_bool(true); - 
value_row.reg[i] = F::from_canonical_u64(value[shift + i]); + value_row.reg[i] = F::from_canonical_u8(value[shift + i]); value_row.sel[i] = F::from_bool(i == offset); - second_read_row.reg[i] = F::from_canonical_u64(value_second_read[i]); + second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); second_read_row.sel[i] = F::from_bool(true); // Store the range check - *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_first_read[i]).or_insert(0) += 1; + *range_check.entry(value[shift + i]).or_insert(0) += 1; + *range_check.entry(value_second_read[i]).or_insert(0) += 1; } // Store the rows @@ -269,13 +321,13 @@ impl MemUnalignedSM { rows.push(value_row); rows.push(second_read_row); } - 3 => { + MemOps::TwoWrites => { // RWVWR - let mut first_read_row = MemUnalign0Row:: { + let mut first_read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_read), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), + width: F::from_canonical_u64(CHUNKS_U64), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -283,76 +335,76 @@ impl MemUnalignedSM { ..Default::default() }; - let mut first_write_row = MemUnalign0Row:: { + let mut first_write_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_write), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), + width: F::from_canonical_u64(CHUNKS_U64), wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc), + pc: F::from_canonical_usize(next_pc), // reset: F::from_bool(false), sel_up_to_down: F::from_bool(true), ..Default::default() }; - let mut value_row = MemUnalign0Row:: { + let mut value_row = MemAlign3Row:: { step: F::from_canonical_u64(step), addr: 
F::from_canonical_u64(addr), - offset: F::from_canonical_u64(offset), - width: F::from_canonical_u64(width), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 1), + pc: F::from_canonical_usize(next_pc + 1), // reset: F::from_bool(false), sel_prove: F::from_bool(true), ..Default::default() }; - let mut second_write_row = MemUnalign0Row:: { + let mut second_write_row = MemAlign3Row:: { step: F::from_canonical_u64(step_second_write), addr: F::from_canonical_u64(addr_next), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), + width: F::from_canonical_u64(CHUNKS_U64), wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc + 2), + pc: F::from_canonical_usize(next_pc + 2), // reset: F::from_bool(false), sel_down_to_up: F::from_bool(true), ..Default::default() }; - let mut second_read_row = MemUnalign0Row:: { + let mut second_read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_second_read), addr: F::from_canonical_u64(addr_next), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS), + width: F::from_canonical_u64(CHUNKS_U64), // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 3), + pc: F::from_canonical_usize(next_pc + 3), reset: F::from_bool(false), sel_down_to_up: F::from_bool(true), ..Default::default() }; for i in 0..CHUNKS { - first_read_row.reg[i] = F::from_canonical_u64(value_first_read[i]); + first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); first_read_row.sel[i] = F::from_bool(i < offset); - first_write_row.reg[i] = F::from_canonical_u64(value_first_write[i]); + first_write_row.reg[i] = F::from_canonical_u8(value_first_write[i]); first_write_row.sel[i] = F::from_bool(i >= offset); - value_row.reg[i] = F::from_canonical_u64(value[shift + i]); + value_row.reg[i] = F::from_canonical_u8(value[shift + i]); value_row.sel[i] = F::from_bool(i == offset); - second_write_row.reg[i] 
= F::from_canonical_u64(value_second_write[i]); + second_write_row.reg[i] = F::from_canonical_u8(value_second_write[i]); second_write_row.sel[i] = F::from_bool(i < shift); - second_read_row.reg[i] = F::from_canonical_u64(value_second_read[i]); + second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); second_read_row.sel[i] = F::from_bool(i >= shift); // Store the range check - *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; - *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; - *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_first_read[i]).or_insert(0) += 1; + *range_check.entry(value_first_write[i]).or_insert(0) += 1; + *range_check.entry(value[shift + i]).or_insert(0) += 1; + *range_check.entry(value_second_write[i]).or_insert(0) += 1; + *range_check.entry(value_second_read[i]).or_insert(0) += 1; } // Store the rows @@ -362,19 +414,13 @@ impl MemUnalignedSM { rows.push(second_write_row); rows.push(second_read_row); } - _ => panic!("MemUnalignedSM::process_slice() got invalid program={}", program), } - // TBD - // for (i, a_byte) in a_bytes.iter().enumerate() { - // let row = MemUnalignedRomSM::::calculate_table_row( - // mem_unaligned_rom_op, - // i as u64, - // *a_byte as u64, - // in2_low, - // ); - // multiplicity[row as usize] += 1; - // } + // Compute and store the ROM row multiplicity + let rom_rows = MemAlignRomSM::::calculate_rom_rows(op, offset, width); + for &row in rom_rows.iter() { + multiplicity[row] += 1; + } // Return successfully rows @@ -382,15 +428,15 @@ impl MemUnalignedSM { pub fn prove_instance( &self, - operations: Vec, + inputs: Vec, prover_buffer: &mut [F], offset: u64, ) { Self::prove_internal( &self.wcm, - &self.mem_unaligned_rom_sm, + &self.mem_align_rom_sm, &self.std, - operations, + inputs, prover_buffer, offset, ); @@ -398,114 +444,97 
@@ impl MemUnalignedSM { fn prove_internal( wcm: &WitnessManager, - mem_unaligned_rom_sm: &MemUnalignedRomSM, + mem_align_rom_sm: &MemAlignRomSM, std: &Std, - operations: Vec, + inputs: Vec, prover_buffer: &mut [F], offset: u64, ) { let pctx = wcm.get_pctx(); - let air = pctx.pilout.get_air(MEM_UNALIGNED_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS[0]); - let air_mem_unaligned_rom = pctx - .pilout - .get_air(MEM_UNALIGNED_ROM_AIRGROUP_ID, MEM_UNALIGNED_ROM_AIR_IDS[0]); - assert!(operations.len() <= air.num_rows()); + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + let air_mem_align_rom = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + assert!(inputs.len() <= air_mem_align.num_rows()); info!( - "{}: ··· Creating Binary extension instance [{} / {} rows filled {:.2}%]", + "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", Self::MY_NAME, - operations.len(), - air.num_rows(), - operations.len() as f64 / air.num_rows() as f64 * 100.0 + inputs.len(), + air_mem_align.num_rows(), + inputs.len() as f64 / air_mem_align.num_rows() as f64 * 100.0 ); - let mut multiplicity_table = vec![0u64; air_mem_unaligned_rom.num_rows()]; - let mut range_check: HashMap = HashMap::new(); - let mut trace_buffer = - BinaryExtension0Trace::::map_buffer(prover_buffer, air.num_rows(), offset as usize) - .unwrap(); - - for (i, operation) in operations.iter().enumerate() { - let row = Self::process_slice(operation, &mut multiplicity_table, &mut range_check); - trace_buffer[i] = row; + // let mut multiplicity_table = vec![0u8; air_mem_align_rom.num_rows()]; + let mut multiplicity_table = vec![0u8; air_mem_align_rom.num_rows()]; + let mut range_check: HashMap = HashMap::new(); + let mut trace_buffer = MemAlign3Trace::::map_buffer( + prover_buffer, + air_mem_align.num_rows(), + offset as usize, + ) + .unwrap(); + + // Process the inputs while saving the multiplcities and range checks + let mut rows_processed = 0; + let rows = 
Self::process_slice(&inputs, &mut multiplicity_table, &mut range_check); + for (i, &row) in rows.iter().enumerate() { + trace_buffer[rows_processed + i] = row; } + rows_processed += rows.len(); - let padding_row = - BinaryExtension0Row:: { op: F::from_canonical_u64(0x25), ..Default::default() }; + // Pad the remaining rows with trivailly satisfying rows + let padding_row = MemAlign3Row::::default(); - for i in operations.len()..air.num_rows() { + for i in rows_processed..air_mem_align.num_rows() { trace_buffer[i] = padding_row; } - let padding_size = air.num_rows() - operations.len(); - for i in 0..8 { - let multiplicity = padding_size as u64; - let row = MemUnalignedRomSM::::calculate_table_row( - BinaryExtensionTableOp::SignExtendW, - i, - 0, - 0, - ); - multiplicity_table[row as usize] += multiplicity; - } - - mem_unaligned_rom_sm.process_slice(&multiplicity_table); + // TODO: Store the padding multiplicity + let _padding_size = air_mem_align.num_rows() - rows_processed; + // for i in 0..8 { + // let multiplicity = padding_size as u64; + // let row = MemAlignRomSM::::calculate_rom_row( + // op, offset, width + // ); + // multiplicity_table[row as usize] += multiplicity; + // } - let range_id = std.get_range(BigInt::from(0), BigInt::from(0xFFFFFF), None); + // Compute the ROM multiplicities + mem_align_rom_sm.process_slice(&multiplicity_table); - for (value, multiplicity) in &range_check { + // Perform the range checks + let range_id = std.get_range(BigInt::from(0), BigInt::from(0xFF), None); + for (&value, &multiplicity) in range_check.iter() { std.range_check( - F::from_canonical_u64(*value), - F::from_canonical_u64(*multiplicity), + F::from_canonical_u8(value), + F::from_canonical_u64(multiplicity), range_id, ); } - std::thread::spawn(move || { - drop(operations); + drop(inputs); drop(multiplicity_table); drop(range_check); }); } } -impl WitnessComponent for MemUnalignedSM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: 
Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for MemUnalignedSM { - fn calculate(&self, operation: MemUnalignedOp) -> Result> { - // TODO: Perform the aligned read/writes - - match operation { - MemUnalignedOp::Read(addr, width) => self.read(addr, width), - MemUnalignedOp::Write(addr, width, val) => self.write(addr, width, val), - } - } +impl WitnessComponent for MemAlignSM {} +impl Provable for MemAlignSM { fn prove(&self, operations: &[ZiskRequiredMemory], drain: bool, _scope: &Scope) { if let Ok(mut inputs) = self.inputs.lock() { inputs.extend_from_slice(operations); let pctx = self.wcm.get_pctx(); - let air = - pctx.pilout.get_air(MEM_UNALIGNED_AIRGROUP_ID, MEM_UNALIGNED_AIR_IDS[0]); + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - while inputs.len() >= air.num_rows() || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(air.num_rows(), inputs.len()); + while inputs.len() >= air_mem_align.num_rows() || (drain && !inputs.is_empty()) { + let num_drained = std::cmp::min(air_mem_align.num_rows(), inputs.len()); let drained_inputs = inputs.drain(..num_drained).collect::>(); - let mem_unaligned_rom_sm = self.mem_unaligned_rom_sm.clone(); + let mem_align_rom_sm = self.mem_align_rom_sm.clone(); let wcm = self.wcm.clone(); let std = self.std.clone(); @@ -515,13 +544,13 @@ impl Provable for MemUnalignedSM { let (mut prover_buffer, offset) = create_prover_buffer( &wcm.get_ectx(), &wcm.get_sctx(), - MEM_UNALIGNED_AIRGROUP_ID, - MEM_UNALIGNED_AIR_IDS[0], + ZISK_AIRGROUP_ID, + MEM_ALIGN_AIR_IDS[0], ); Self::prove_internal( &wcm, - &mem_unaligned_rom_sm, + &mem_align_rom_sm, &std, drained_inputs, &mut prover_buffer, @@ -530,8 +559,8 @@ impl Provable for MemUnalignedSM { let air_instance = AirInstance::new( sctx, - MEM_UNALIGNED_AIRGROUP_ID, - MEM_UNALIGNED_AIR_IDS[0], + ZISK_AIRGROUP_ID, + MEM_ALIGN_AIR_IDS[0], None, prover_buffer, ); @@ -539,17 +568,4 @@ impl Provable for MemUnalignedSM { } } } - - fn 
calculate_prove( - &self, - operation: MemUnalignedOp, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} \ No newline at end of file +} From 3ccfde21a826137cd63fab49590c88348d9b78c6 Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Wed, 6 Nov 2024 08:32:27 +0000 Subject: [PATCH 06/44] wip --- Cargo.lock | 10 ++ Cargo.toml | 24 ++-- state-machines/main/src/main_sm.rs | 18 ++- state-machines/mem/src/mem_proxy.rs | 182 +++++++++++++++++++++++++--- state-machines/mem/src/mem_sm.rs | 33 ++--- witness-computation/src/executor.rs | 17 ++- 6 files changed, 214 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c9705649..ffe2fb46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1474,6 +1474,7 @@ dependencies = [ [[package]] name = "pil-std-lib" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "log", "num-bigint", @@ -1491,6 +1492,7 @@ dependencies = [ [[package]] name = "pilout" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "bytes", "log", @@ -1610,6 +1612,7 @@ dependencies = [ [[package]] name = "proofman" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "colored", "env_logger", @@ -1630,6 +1633,7 @@ dependencies = [ [[package]] name = "proofman-common" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "env_logger", "log", @@ -1647,6 +1651,7 @@ dependencies = [ [[package]] name = "proofman-hints" version = "0.1.0" +source = 
"git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "p3-field", "proofman-common", @@ -1656,6 +1661,7 @@ dependencies = [ [[package]] name = "proofman-macros" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "proc-macro2", "quote", @@ -1665,6 +1671,7 @@ dependencies = [ [[package]] name = "proofman-starks-lib-c" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "log", ] @@ -1672,6 +1679,7 @@ dependencies = [ [[package]] name = "proofman-util" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "colored", "sysinfo 0.31.4", @@ -2306,6 +2314,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stark" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "log", "p3-field", @@ -2635,6 +2644,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" dependencies = [ "proofman-starks-lib-c", ] diff --git a/Cargo.toml b/Cargo.toml index da96aae8..f5ce1892 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,19 +26,19 @@ opt-level = 3 opt-level = 3 [workspace.dependencies] -# proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -# proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -# proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -# 
proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -# pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } -# stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } +stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", branch ="develop" } #Local development -proofman-common = { path = "../pil2-proofman/common" } -proofman-macros = { path = "../pil2-proofman/macros" } -proofman-util = { path = "../pil2-proofman/util" } -proofman = { path = "../pil2-proofman/proofman" } -pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } -stark = { path = "../pil2-proofman/provers/stark" } +# proofman-common = { path = "../pil2-proofman/common" } +# proofman-macros = { path = "../pil2-proofman/macros" } +# proofman-util = { path = "../pil2-proofman/util" } +# proofman = { path = "../pil2-proofman/proofman" } +# pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } +# stark = { path = "../pil2-proofman/provers/stark" } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "c3d754ef77b9fce585b46b972af751fe6e7a9803" } log = "0.4" diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index c7d15619..3283e69a 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -12,10 +12,8 @@ use proofman_common::{AirInstance, ProofCtx}; use 
proofman::WitnessComponent; use sm_arith::ArithSM; -use sm_mem::MemSM; use zisk_pil::{ - Main0Row, Main0Trace, BINARY_AIRGROUP_ID, BINARY_AIR_IDS, BINARY_EXTENSION_AIRGROUP_ID, - BINARY_EXTENSION_AIR_IDS, MAIN_AIRGROUP_ID, MAIN_AIR_IDS, + Main0Row, Main0Trace, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, }; use ziskemu::{Emu, EmuTrace, ZiskEmulator}; @@ -56,7 +54,7 @@ impl MainSM { ) -> Arc { let main_sm = Arc::new(Self { wcm: wcm.clone(), arith_sm, binary_sm }); - wcm.register_component(main_sm.clone(), Some(MAIN_AIRGROUP_ID), Some(MAIN_AIR_IDS)); + wcm.register_component(main_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MAIN_AIR_IDS)); // For all the secondary state machines, register the main state machine as a predecessor main_sm.binary_sm.register_predecessor(); @@ -76,7 +74,7 @@ impl MainSM { let segment_trace = &vec_traces[segment_id]; let offset = iectx.offset; - let air = pctx.pilout.get_air(MAIN_AIRGROUP_ID, MAIN_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]); let filled = segment_trace.steps.len() + 1; info!( "{}: ··· Creating Main segment #{} [{} / {} rows filled {:.2}%]", @@ -170,7 +168,7 @@ impl MainSM { let sctx = self.wcm.get_sctx(); let mut air_instance = AirInstance::new( sctx.clone(), - MAIN_AIRGROUP_ID, + ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0], Some(segment_id), buffer, @@ -192,7 +190,7 @@ impl MainSM { iectx: &mut InstanceExtensionCtx, pctx: &ProofCtx, ) { - let air = pctx.pilout.get_air(BINARY_AIRGROUP_ID, BINARY_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]); timer_start_debug!(PROCESS_BINARY); let inputs = ZiskEmulator::process_slice_required::( @@ -212,7 +210,7 @@ impl MainSM { let buffer = std::mem::take(&mut iectx.prover_buffer); iectx.air_instance = Some(AirInstance::new( self.wcm.get_sctx(), - BINARY_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0], None, buffer, @@ -227,7 +225,7 @@ impl MainSM { iectx: &mut InstanceExtensionCtx, pctx: &ProofCtx, ) 
{ - let air = pctx.pilout.get_air(BINARY_EXTENSION_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); let inputs = ZiskEmulator::process_slice_required::( zisk_rom, @@ -242,7 +240,7 @@ impl MainSM { let buffer = std::mem::take(&mut iectx.prover_buffer); iectx.air_instance = Some(AirInstance::new( self.wcm.get_sctx(), - BINARY_EXTENSION_AIRGROUP_ID, + ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0], None, buffer, diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 73dc7394..4a3466f8 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,28 +1,26 @@ use std::sync::{ atomic::{AtomicU32, Ordering}, - Arc, Mutex, + Arc, }; use crate::{MemAlignSM, MemSM}; use p3_field::PrimeField; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; -use sm_common::{MemOp, MemUnalignedOp}; use zisk_core::ZiskRequiredMemory; use proofman::{WitnessComponent, WitnessManager}; -#[allow(dead_code)] -const PROVE_CHUNK_SIZE: usize = 1 << 12; +pub enum MemOps { + OneRead, + OneWrite, + TwoReads, + TwoWrites, +} -#[allow(dead_code)] pub struct MemProxy { // Count of registered predecessors registered_predecessors: AtomicU32, - // Inputs - inputs_aligned: Mutex>, - inputs_unaligned: Mutex>, - // Secondary State machines mem_sm: Arc>, mem_align_sm: Arc, @@ -35,8 +33,6 @@ impl MemProxy { let mem_proxy = Self { registered_predecessors: AtomicU32::new(0), - inputs_aligned: Mutex::new(Vec::new()), - inputs_unaligned: Mutex::new(Vec::new()), mem_sm: mem_sm.clone(), mem_align_sm: mem_align_sm.clone(), }; @@ -68,7 +64,7 @@ impl MemProxy { ) -> Result<(), Box> { let mut aligned = std::mem::take(&mut operations[0]); let non_aligned = std::mem::take(&mut operations[1]); - let new_aligned = Vec::new(); + let mut new_aligned = Vec::new(); // Step 1. 
Sort the aligned memory accesses timer_start_debug!(MEM_SORT); @@ -76,14 +72,17 @@ impl MemProxy { timer_stop_and_log_debug!(MEM_SORT); // Step 2. For each non-aligned memory access - non_aligned.iter().for_each(|mem| { + non_aligned.iter().for_each(|unaligned_access| { + let mem_ops = Self::get_mem_ops(unaligned_access); + // Step 2.1 Find the possible aligned memory access - let potential_aligned_mem = self.get_potential_aligned_mem(&aligned, &mem); + let aligned_accesses = self.get_aligned_accesses(&unaligned_access, mem_ops, &aligned); // Step 2.2 Align memory access using mem_align state machine // self.mem_aligned_sm.align_mem_accesses(potential_aligned_mem, mem, &mut new_aligned); // Step 2.3 Store the new aligned memory access(es) + new_aligned.extend(aligned_accesses); }); // Step 3. Concatenate the new aligned memory accesses with the original aligned memory accesses @@ -95,13 +94,160 @@ impl MemProxy { Ok(()) } - fn get_potential_aligned_mem( + #[inline(always)] + fn get_aligned_accesses( &self, - aligned_accesses: &[ZiskRequiredMemory], unaligned_access: &ZiskRequiredMemory, + mem_ops: MemOps, + aligned_accesses: &[ZiskRequiredMemory], ) -> Vec { - let mut aligned_mem = Vec::new(); - aligned_mem + // Align down to a 8 byte addres + let addr = unaligned_access.address & !7; + match mem_ops { + MemOps::OneRead => { + // Look for last write to the same address + let last_write_addr = Self::get_last_write(addr, aligned_accesses); + let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + step: unaligned_access.step, + is_write: false, + address: addr, + width: 8, + value: 0, + }); + vec![last_write_addr] + } + MemOps::OneWrite => { + // Look for last write to the same address + let last_write_addr = Self::get_last_write(addr, aligned_accesses); + + // Modify the value of the write to the same address + let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + step: unaligned_access.step, + is_write: true, + address: addr, + 
width: 8, + value: 0, + }); + + Self::write_value(&unaligned_access, &mut last_write_addr); + vec![last_write_addr] + } + MemOps::TwoReads => { + // Look for last write to the same address and same address + 8 + let last_write_addr = Self::get_last_write(addr, aligned_accesses); + let last_write_addr_p = Self::get_last_write(addr + 8, aligned_accesses); + + let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + step: unaligned_access.step, + is_write: false, + address: addr, + width: 8, + value: 0, + }); + + let last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { + step: unaligned_access.step, + is_write: false, + address: addr + 8, + width: 8, + value: 0, + }); + + vec![last_write_addr, last_write_addr_p] + } + MemOps::TwoWrites => { + // Look for last write to the same address and same address + 8 + let last_write_addr = Self::get_last_write(addr, aligned_accesses); + let last_write_addr_p = Self::get_last_write(addr + 8, aligned_accesses); + + let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + step: unaligned_access.step, + is_write: true, + address: addr, + width: 8, + value: 1, + }); + + let mut last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { + step: unaligned_access.step, + is_write: true, + address: addr + 8, + width: 8, + value: 1, + }); + + Self::write_values(&unaligned_access, &mut last_write_addr, &mut last_write_addr_p); + vec![last_write_addr, last_write_addr_p] + } + } + } + + #[inline(always)] + fn get_last_write( + addr: u64, + aligned_accesses: &[ZiskRequiredMemory], + ) -> Option { + Some( + aligned_accesses + .iter() + .rev() + .find(|mem| mem.address == addr && mem.is_write) + .unwrap() + .clone(), + ) + } + + #[inline(always)] + fn write_value(unaligned: &ZiskRequiredMemory, aligned: &mut ZiskRequiredMemory) { + let offset = 8 - (unaligned.address & 7); + let width_in_bits = unaligned.width * 8; + + let mask = !(((1u64 << width_in_bits) - 1) << ((offset - 
unaligned.width) * 8)); + + aligned.value = + (aligned.value & mask) | (unaligned.value << ((offset - unaligned.width) * 8)); + } + + #[inline(always)] + fn write_values( + unaligned: &ZiskRequiredMemory, + aligned: &mut ZiskRequiredMemory, + aligned_next: &mut ZiskRequiredMemory, + ) { + let offset = unaligned.address & 7; + let bytes_to_write = 8 - offset; + let right_bits = (unaligned.width - bytes_to_write) * 8; + + // Left write + let left_value = unaligned.value >> right_bits; + let left_memory = + ZiskRequiredMemory { width: bytes_to_write, value: left_value, ..*unaligned }; + Self::write_value(&left_memory, aligned); + + // Right write + let mask = (1u64 << right_bits) - 1; + let right_value = unaligned.value & mask; + + let right_memory = ZiskRequiredMemory { + address: 0, + width: unaligned.width - bytes_to_write, + value: right_value, + ..*unaligned + }; + Self::write_value(&right_memory, aligned_next); + } + + #[inline(always)] + pub fn get_mem_ops(input: &ZiskRequiredMemory) -> MemOps { + let addr = input.address; + let width = input.width; + let offset = addr & 7; + match (input.is_write, offset + width > 8) { + (false, false) => MemOps::OneRead, + (true, false) => MemOps::OneWrite, + (false, true) => MemOps::TwoReads, + (true, true) => MemOps::TwoWrites, + } } } diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index 145dd873..d16fa680 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -9,7 +9,7 @@ use proofman_common::AirInstance; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use rayon::prelude::*; -use sm_common::{create_prover_buffer, MemOp}; +use sm_common::create_prover_buffer; use zisk_core::ZiskRequiredMemory; use zisk_pil::{Mem0Trace, MEM_AIRGROUP_ID, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; @@ -19,22 +19,12 @@ pub struct MemSM { // Count of registered predecessors registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, - - _phantom: 
std::marker::PhantomData, } #[allow(unused, unused_variables)] impl MemSM { pub fn new(wcm: Arc>) -> Arc { - let mem_sm = Self { - wcm: wcm.clone(), - registered_predecessors: AtomicU32::new(0), - inputs: Mutex::new(Vec::new()), - _phantom: std::marker::PhantomData, - }; + let mem_sm = Self { wcm: wcm.clone(), registered_predecessors: AtomicU32::new(0) }; let mem_sm = Arc::new(mem_sm); wcm.register_component(mem_sm.clone(), Some(MEM_AIRGROUP_ID), Some(MEM_AIR_IDS)); @@ -56,18 +46,18 @@ impl MemSM { mem_accesses.sort_by_key(|mem| mem.address); timer_stop_and_log_debug!(MEM_SORT_2); - let air = self.wcm.get_pctx().pilout.get_air(MEM_AIRGROUP_ID, MEM_AIR_IDS[0]); + let pctx = self.wcm.get_pctx(); + let ectx = self.wcm.get_ectx(); + let sctx = self.wcm.get_sctx(); + + let air = pctx.pilout.get_air(MEM_AIRGROUP_ID, MEM_AIR_IDS[0]); let num_chunks = (mem_accesses.len() as f64 / (air.num_rows() - 1) as f64).ceil() as usize; - let mut prover_buffers = vec![Vec::new(); num_chunks]; + let mut prover_buffers = Mutex::new(vec![Vec::new(); num_chunks]); let mut offsets = vec![0; num_chunks]; let mut global_idxs = vec![0; num_chunks]; - let pctx = self.wcm.get_pctx(); - let ectx = self.wcm.get_ectx(); - let sctx = self.wcm.get_sctx(); - for i in 0..num_chunks { if let (true, global_idx) = self.wcm.get_ectx().dctx.write().unwrap().add_instance( ZISK_AIRGROUP_ID, @@ -76,8 +66,8 @@ impl MemSM { ) { let (buffer, offset) = create_prover_buffer::(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); - - prover_buffers.push(buffer); + + prover_buffers.lock().unwrap().push(buffer); offsets.push(offset); global_idxs.push(global_idx); } @@ -91,12 +81,13 @@ impl MemSM { mem_accesses[segment_id * ((air.num_rows() - 1) - 1)].clone() }; + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); self.prove_instance( mem_ops, mem_first_row, segment_id, segment_id == mem_accesses.len() - 1, - prover_buffers[segment_id], + prover_buffer, offsets[segment_id], 
global_idxs[segment_id], ); diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index f06d0c87..9262ee62 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -20,8 +20,7 @@ use std::{ }; use zisk_core::{Riscv2zisk, ZiskOperationType, ZiskRom, ZISK_OPERATION_TYPE_VARIANTS}; use zisk_pil::{ - BINARY_AIRGROUP_ID, BINARY_AIR_IDS, BINARY_EXTENSION_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS, - MAIN_AIRGROUP_ID, MAIN_AIR_IDS, ROM_AIRGROUP_ID, ROM_AIR_IDS, + BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ROM_AIR_IDS, ZISK_AIRGROUP_ID, }; use ziskemu::{EmuOptions, ZiskEmulator}; @@ -105,7 +104,7 @@ impl ZiskExecutor { ectx: Arc, sctx: Arc, ) { - let air_main = pctx.pilout.get_air(MAIN_AIRGROUP_ID, MAIN_AIR_IDS[0]); + let air_main = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]); // Prepare the settings for the emulator let emu_options = EmuOptions { @@ -130,9 +129,9 @@ impl ZiskExecutor { // machine. We aim to track the starting point of execution for every N instructions // across different operation types. Currently, we are only collecting data for // Binary and BinaryE operations. 
- let air_binary = pctx.pilout.get_air(BINARY_AIRGROUP_ID, BINARY_AIR_IDS[0]); + let air_binary = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]); let air_binary_e = - pctx.pilout.get_air(BINARY_EXTENSION_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); + pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); let mut op_sizes = [0u64; ZISK_OPERATION_TYPE_VARIANTS]; // The starting points for the Main is allocated using None operation @@ -205,7 +204,7 @@ impl ZiskExecutor { // ROM State Machine // ---------------------------------------------- let (rom_is_mine, rom_instance_gid) = - ectx.dctx.write().unwrap().add_instance(ROM_AIRGROUP_ID, ROM_AIR_IDS[0], 1); + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0], 1); let rom_thread = if rom_is_mine { let rom_sm = self.rom_sm.clone(); @@ -228,10 +227,10 @@ impl ZiskExecutor { let mut main_segnent_id = 0; for emu_slice in emu_slices.points.iter() { let (airgroup_id, air_id) = match emu_slice.op_type { - ZiskOperationType::None => (MAIN_AIRGROUP_ID, MAIN_AIR_IDS[0]), - ZiskOperationType::Binary => (BINARY_AIRGROUP_ID, BINARY_AIR_IDS[0]), + ZiskOperationType::None => (ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]), + ZiskOperationType::Binary => (ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]), ZiskOperationType::BinaryE => { - (BINARY_EXTENSION_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]) + (ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]) } _ => panic!("Invalid operation type"), }; From c7a6cf9bc0dbe641329e9aca6aae8ddc8f44fea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Wed, 6 Nov 2024 10:04:30 +0000 Subject: [PATCH 07/44] Working in the new version --- pil/src/pil_helpers/pilout.rs | 4 +- pil/src/pil_helpers/traces.rs | 20 +-- state-machines/mem/pil/mem_align.pil | 32 ++-- state-machines/mem/pil/mem_align_rom.pil | 18 +-- state-machines/mem/src/mem_align_rom_sm.rs | 86 ++++++----- state-machines/mem/src/mem_align_sm.rs | 161 ++++++++++----------- 6 files changed, 162 insertions(+), 
159 deletions(-) diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 2df7024e..bae5828f 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -2,7 +2,7 @@ // Manual modifications are not recommended and may be overwritten. use proofman_common::WitnessPilout; -pub const PILOUT_HASH: &[u8] = b"ZiskMem-hash"; +pub const PILOUT_HASH: &[u8] = b"Zisk-hash"; //AIRGROUP CONSTANTS @@ -36,7 +36,7 @@ pub struct Pilout; impl Pilout { pub fn pilout() -> WitnessPilout { - let mut pilout = WitnessPilout::new("ZiskMem", 2, PILOUT_HASH.to_vec()); + let mut pilout = WitnessPilout::new("Zisk", 2, PILOUT_HASH.to_vec()); let air_group = pilout.add_air_group(Some("Zisk")); diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index f1af3709..7d45888d 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -7,42 +7,42 @@ trace!(Main0Row, Main0Trace { a: [F; 2], b: [F; 2], c: [F; 2], flag: F, pc: F, a_src_imm: F, a_src_mem: F, a_offset_imm0: F, a_imm1: F, a_src_step: F, b_src_imm: F, b_src_mem: F, b_offset_imm0: F, b_imm1: F, b_src_ind: F, ind_width: F, is_external_op: F, op: F, store_ra: F, store_mem: F, store_ind: F, store_offset: F, set_pc: F, jmp_offset1: F, jmp_offset2: F, m32: F, addr1: F, __debug_operation_bus_enabled: F, }); -trace!(Rom0Row, Rom0Trace { +trace!(Rom1Row, Rom1Trace { line: F, a_offset_imm0: F, a_imm1: F, b_offset_imm0: F, b_imm1: F, ind_width: F, op: F, store_offset: F, jmp_offset1: F, jmp_offset2: F, flags: F, multiplicity: F, }); -trace!(Mem0Row, Mem0Trace { +trace!(Mem2Row, Mem2Trace { addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, same_value: F, first_addr_access_is_read: F, }); -trace!(MemAlign0Row, MemAlign0Trace { +trace!(MemAlign3Row, MemAlign3Trace { addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], sel_prove: F, step: F, }); -trace!(MemAlignRom0Row, MemAlignRom0Trace { 
+trace!(MemAlignRom4Row, MemAlignRom4Trace { multiplicity: F, }); -trace!(Binary0Row, Binary0Trace { +trace!(Binary5Row, Binary5Trace { m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, multiplicity: F, main_step: F, }); -trace!(BinaryTable0Row, BinaryTable0Trace { +trace!(BinaryTable6Row, BinaryTable6Trace { multiplicity: F, }); -trace!(BinaryExtension0Row, BinaryExtension0Trace { +trace!(BinaryExtension7Row, BinaryExtension7Trace { op: F, in1: [F; 8], in2_low: F, out: [[F; 2]; 8], op_is_shift: F, in2: [F; 2], main_step: F, multiplicity: F, }); -trace!(BinaryExtensionTable0Row, BinaryExtensionTable0Trace { +trace!(BinaryExtensionTable8Row, BinaryExtensionTable8Trace { multiplicity: F, }); -trace!(SpecifiedRanges0Row, SpecifiedRanges0Trace { +trace!(SpecifiedRanges9Row, SpecifiedRanges9Trace { mul: [F; 2], }); -trace!(U8Air0Row, U8Air0Trace { +trace!(U8Air10Row, U8Air10Trace { mul: F, }); diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index 184b86a2..2293ba35 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -87,10 +87,10 @@ require "std_range_check.pil" Notice that it is enough with 8 combinations. 
*/ -airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES = 8, const int CHUNK_BITS = 8) { - const int MEM_HALF_BYTES = MEM_BYTES / 2; +airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM = 8, const int CHUNK_BITS = 8) { + const int CHUNK_NUM_HALF = CHUNK_NUM / 2; - col witness addr; // MEM_BYTES-byte address, real address = addr * MEM_BYTES + col witness addr; // CHUNK_NUM-byte address, real address = addr * CHUNK_NUM col witness offset; // 0..7, position at which the operation starts col witness width; // 1,2,4,8, width of the operation col witness wr; // 1 if the operation is a write, 0 otherwise @@ -98,8 +98,8 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES col witness reset; // 1 at the beginning of the operation (indicating an address reset), 0 otherwise col witness sel_up_to_down; // 1 if the next value is the current value (e.g. R -> W) col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R) - col witness reg[MEM_BYTES]; // Register values, 1 byte each - col witness sel[MEM_BYTES]; // Selectors, 1 if the value is used, 0 otherwise + col witness reg[CHUNK_NUM]; // Register values, 1 byte each + col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise // 1] Ensure the MemAlign follows the program @@ -107,7 +107,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES // - reg' == reg in transitions R -> V, R -> W, W -> V, // - 'reg == reg in transitions V <- W, W <- R, // in any case, sel_up_to_down,sel_down_to_up are 0 in [V] steps. 
- for (int i = 0; i < MEM_BYTES; i++) { + for (int i = 0; i < CHUNK_NUM; i++) { range_check(reg[i], 0, 2**CHUNK_BITS-1); (reg[i]' - reg[i]) * sel[i] * sel_up_to_down === 0; @@ -118,7 +118,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES L1 * pc === 0; // The program should start at the first line // We compress selectors, so we should ensure they are binary - for (int i = 0; i < MEM_BYTES; i++) { + for (int i = 0; i < CHUNK_NUM; i++) { sel[i] * (1 - sel[i]) === 0; } wr * (1 - wr) === 0; @@ -127,10 +127,10 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES sel_down_to_up * (1 - sel_down_to_up) === 0; expr flags = 0; - for (int i = 0; i < MEM_BYTES; i++) { + for (int i = 0; i < CHUNK_NUM; i++) { flags += sel[i] * 2**i; } - flags += wr * 2**MEM_BYTES + reset * 2**(MEM_BYTES + 1) + sel_up_to_down * 2**(MEM_BYTES + 2) + sel_down_to_up * 2**(MEM_BYTES + 3); + flags += wr * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); lookup_assumes(MEM_ALIGN_ROM_ID, [pc, pc'-pc, (addr-'addr)*(1-reset), offset, width, flags]); @@ -144,8 +144,8 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES expr assume_val[RC]; for (int i = 0; i < RC; i++) { assume_val[i] = 0; - for (int j = 0; j < MEM_HALF_BYTES; j++) { - assume_val[i] += reg[j + i * MEM_HALF_BYTES] * 2**j; + for (int j = 0; j < CHUNK_NUM_HALF; j++) { + assume_val[i] += reg[j + i * CHUNK_NUM_HALF] * 2**j; } } @@ -158,16 +158,16 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES expr prove_val[RC]; for (int i = 0; i < RC; i++) { prove_val[i] = 0; - for (int j = 0; j < MEM_HALF_BYTES; j++) { + for (int j = 0; j < CHUNK_NUM_HALF; j++) { expr _prove_val = 0; - for (int k = j; k < j + MEM_HALF_BYTES; k++) { - _prove_val += reg[(k + i * MEM_HALF_BYTES) % MEM_BYTES] * 2**(k-j); + for (int k = j; k < j + CHUNK_NUM_HALF; k++) { + _prove_val 
+= reg[(k + i * CHUNK_NUM_HALF) % CHUNK_NUM] * 2**(k-j); } - prove_val[i] += sel[j + i * MEM_HALF_BYTES] * _prove_val; + prove_val[i] += sel[j + i * CHUNK_NUM_HALF] * _prove_val; } } // We prove and assume with the same permutation check but with disjoint and different sign selectors col witness step; - permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * MEM_BYTES + offset, step, width, ...prove_val], sel: sel_prove - sel_assume); + permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val], sel: sel_prove - sel_assume); } \ No newline at end of file diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil index 5f5d8b6b..863959a5 100644 --- a/state-machines/mem/pil/mem_align_rom.pil +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -13,7 +13,7 @@ const int MEM_ALIGN_ROM_SIZE = P2_8; // Note1: The offset and width are sufficient to group programs with the same number of operations. // Note2: The first instruction is always a read. 
-airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = 8, const int disable_fixed = 0) { +airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = 8, const int disable_fixed = 0) { if (N < MEM_ALIGN_ROM_SIZE) { error(`N must be at least ${MEM_ALIGN_ROM_SIZE}, but N=${N} was provided`); } @@ -83,8 +83,8 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = int delta_addr = 0; int is_write = 0; int reset = 0; - int sel[MEM_BYTES]; - for (int j = 0; j < MEM_BYTES; j++) { + int sel[CHUNK_NUM]; + for (int j = 0; j < CHUNK_NUM; j++) { sel[j] = 0; } int sel_up_to_down = 0; @@ -109,7 +109,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = delta_addr = 1; // is_write = 0; // reset = 0; - // sel = [0:MEM_BYTES] + // sel = [0:CHUNK_NUM] // sel_up_to_down = 0; // sel_down_to_up = 0; } @@ -140,7 +140,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = // delta_addr = 0; // is_write = 0; // reset = 0; - // sel = [0:MEM_BYTES] + // sel = [0:CHUNK_NUM] // sel_up_to_down = 0; // sel_down_to_up = 0; } @@ -162,7 +162,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = delta_addr = 1; // is_write = 0; // reset = 0; - // sel = [0:MEM_BYTES] + // sel = [0:CHUNK_NUM] // sel_up_to_down = 1; // sel_down_to_up = 0; } else { // R @@ -202,7 +202,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = delta_addr = 1; // is_write = 0; // reset = 0; - // sel = [0:MEM_BYTES] + // sel = [0:CHUNK_NUM] // sel_up_to_down = 0; // sel_down_to_up = 0; } else if (next % 5 == 3) { // W @@ -230,10 +230,10 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = DELTA_PC[i] = delta_pc; DELTA_ADDR[i] = delta_addr; FLAGS[i] = 0; - for (int j = 0; j < MEM_BYTES; j++) { + for (int j = 0; j < CHUNK_NUM; j++) { FLAGS[i] += sel[j] * 2**j; } - FLAGS[i] += is_write * 2**MEM_BYTES + reset 
* 2**(MEM_BYTES + 1) + sel_up_to_down * 2**(MEM_BYTES + 2) + sel_down_to_up * 2**(MEM_BYTES + 3); + FLAGS[i] += is_write * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); } lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 3406fc64..2f7027aa 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -1,6 +1,9 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, + }, }; use log::info; @@ -12,7 +15,7 @@ use rayon::prelude::*; use sm_common::create_prover_buffer; use zisk_pil::{MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; -use crate::MemOps; +use crate::MemOp; const CHUNKS: usize = 8; const MEM_WIDTHS: [u64; 4] = [1, 2, 4, 8]; @@ -27,8 +30,7 @@ pub struct MemAlignRomSM { // Rom data num_rows: usize, - line: Mutex, - multiplicity: Mutex>, + multiplicity: Mutex>, // row_num -> multiplicity } #[derive(Debug)] @@ -42,13 +44,13 @@ impl MemAlignRomSM { pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { let pctx = wcm.get_pctx(); let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + let num_rows = air.num_rows(); let mem_align_rom = Self { wcm: wcm.clone(), registered_predecessors: AtomicU32::new(0), - num_rows: air.num_rows(), - line: Mutex::new(0), - multiplicity: Mutex::new(vec![0; air.num_rows()]), + num_rows, + multiplicity: Mutex::new(HashMap::with_capacity(num_rows)), }; let mem_align_rom = Arc::new(mem_align_rom); wcm.register_component(mem_align_rom.clone(), Some(airgroup_id), Some(air_ids)); @@ -66,13 +68,14 @@ impl MemAlignRomSM { } } - pub fn get_op_size(op: MemOps) -> usize { + pub fn get_mem_align_op_size(op: MemOp) -> usize { OP_SIZES[op as usize] 
} - pub fn calculate_rom_rows(opcode: MemOps, offset: usize, width: usize) -> Vec { + fn calculate_rom_rows(opcode: MemOp, offset: usize, width: usize) -> Vec { + // Calculate the ROM rows based on the requested opcode, offset, and width match opcode { - MemOps::OneRead | MemOps::OneWrite => { + MemOp::OneRead | MemOp::OneWrite => { // Sanity check assert!(offset + width <= CHUNKS); let possible_widths = match offset { @@ -81,9 +84,9 @@ impl MemAlignRomSM { x if x == 7 => vec![1], _ => panic!("Invalid offset={}", offset), }; - Self::get_rows(opcode, possible_widths, offset, width) + Self::get_row_idxs(opcode, possible_widths, offset, width) } - MemOps::TwoReads | MemOps::TwoWrites => { + MemOp::TwoReads | MemOp::TwoWrites => { // Sanity check assert!(offset + width > CHUNKS); let possible_widths = match offset { @@ -93,58 +96,73 @@ impl MemAlignRomSM { x if x == 7 => vec![2, 4, 8], _ => panic!("Invalid offset={}", offset), }; - Self::get_rows(opcode, possible_widths, offset, width) + Self::get_row_idxs(opcode, possible_widths, offset, width) } } } - fn get_rows( - opcode: MemOps, + fn get_row_idxs( + opcode: MemOp, possible_widths: Vec, offset: usize, width: usize, - ) -> Vec { + ) -> Vec { // Sanity check assert!(possible_widths.contains(&width)); let width_idx = possible_widths.iter().position(|&w| w == width).unwrap(); let opcode_idx = opcode as usize; match opcode { - MemOps::OneRead | MemOps::OneWrite => { - let value_row = offset * possible_widths.len() * OP_SIZES[opcode_idx] + MemOp::OneRead | MemOp::OneWrite => { + let value_row = (offset * possible_widths.len() * OP_SIZES[opcode_idx] + (offset + width_idx + 1) * OP_SIZES[opcode_idx] - - 1; + - 1) as u64; match opcode { - MemOps::OneRead => vec![value_row - 1, value_row], - MemOps::OneWrite => vec![value_row - 2, value_row - 1, value_row], + MemOp::OneRead => vec![value_row - 1, value_row], + MemOp::OneWrite => vec![value_row - 2, value_row - 1, value_row], _ => unreachable!(), } } - MemOps::TwoReads => { - 
let value_row = offset * possible_widths.len() * OP_SIZES[opcode_idx] + MemOp::TwoReads => { + let value_row = (offset * possible_widths.len() * OP_SIZES[opcode_idx] + (offset + width_idx + 1) * OP_SIZES[opcode_idx] - - 2; + - 2) as u64; return vec![value_row - 1, value_row, value_row + 1]; } - MemOps::TwoWrites => { - let value_row = offset * possible_widths.len() * OP_SIZES[opcode_idx] + MemOp::TwoWrites => { + let value_row = (offset * possible_widths.len() * OP_SIZES[opcode_idx] + (offset + width_idx + 1) * OP_SIZES[opcode_idx] - - 3; + - 3) as u64; return vec![value_row - 2, value_row - 1, value_row, value_row + 1, value_row + 2]; } } } - pub fn calculate_next_pc(op: MemOps, offset: usize, width: usize) -> usize { + pub fn calculate_next_pc(op: MemOp, offset: usize, width: usize) -> u64 { let rows = Self::calculate_rom_rows(op, offset, width); + + // The "next" pc is always found on the second row of the program being executed rows[1] } - pub fn process_slice(&self, input: &[u8]) { + pub fn update_multiplicity_by_input(&self, opcode: MemOp, offset: usize, width: usize) { + let row_idxs = Self::calculate_rom_rows(opcode, offset, width); + self.update_multiplicity_by_idx(&row_idxs); + } + + pub fn update_multiplicity_by_idx(&self, idxs: &[u64]) { + let mut multiplicity = self.multiplicity.lock().unwrap(); + + for &i in idxs { + *multiplicity.entry(F::from_canonical_u64(i)).or_insert(0) += 1; + } + } + + pub fn update_multiplicity(&self, inputs: &[u64]) { let mut multiplicity = self.multiplicity.lock().unwrap(); - for (i, val) in input.iter().enumerate() { - multiplicity[i] += *val as u64; + for (idx, mul) in inputs.iter().enumerate() { + *multiplicity.entry(F::from_canonical_usize(idx)).or_insert(0) += *mul; } } @@ -177,7 +195,7 @@ impl MemAlignRomSM { .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); info!( - "{}: ··· Creating Binary extension table instance [{} rows filled 100%]", + "{}: ··· Creating Mem Align ROM instance [{} rows 
filled 100%]", Self::MY_NAME, self.num_rows, ); diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 61b197b5..e2f0ae98 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -16,14 +16,12 @@ use rayon::Scope; use sm_common::{create_prover_buffer, OpResult, Provable}; use zisk_core::ZiskRequiredMemory; -use zisk_pil::{ - MemAlign3Row, MemAlign3Trace, MEM_ALIGN_AIR_IDS, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID, -}; +use zisk_pil::{MemAlign3Row, MemAlign3Trace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::MemAlignRomSM; #[derive(Debug, Clone, Copy)] -pub enum MemOps { +pub enum MemOp { OneRead, OneWrite, TwoReads, @@ -31,8 +29,10 @@ pub enum MemOps { } const PROVE_CHUNK_SIZE: usize = 1 << 12; -const CHUNKS: usize = 8; -const CHUNKS_U64: u64 = CHUNKS as u64; + +const CHUNK_NUM: usize = 8; +const CHUNK_NUM_U64: u64 = CHUNK_NUM as u64; +const CHUNK_BITS: usize = 8; pub struct MemAlignSM { // Witness computation manager @@ -93,33 +93,33 @@ impl MemAlignSM { } #[inline(always)] - pub fn get_mem_ops(unaligned_input: &ZiskRequiredMemory) -> MemOps { + pub fn get_mem_op(unaligned_input: &ZiskRequiredMemory) -> MemOp { let addr = unaligned_input.address; let width = unaligned_input.width; let offset = addr % 8; match (unaligned_input.is_write, offset + width > 8) { - (false, false) => MemOps::OneRead, - (true, false) => MemOps::OneWrite, - (false, true) => MemOps::TwoReads, - (true, true) => MemOps::TwoWrites, + (false, false) => MemOp::OneRead, + (true, false) => MemOp::OneWrite, + (false, true) => MemOp::TwoReads, + (true, true) => MemOp::TwoWrites, } } #[inline(always)] pub fn process_slice( input: &Vec, - multiplicity: &mut [u8], - range_check: &mut HashMap, + mem_align_rom_sm: &MemAlignRomSM, + range_check: &mut HashMap, ) -> Vec> { // Is a write or a read operation let _wr = input[0].is_write; // Get the address let addr = input[0].address; - let addr_prior = input[1].address; 
// addr / CHUNKS; - let addr_next = input[2].address; // addr / CHUNKS + CHUNKS; + let addr_prior = input[1].address; // addr / CHUNK_NUM; + let addr_next = input[2].address; // addr / CHUNK_NUM + CHUNK_NUM; // Get the value let value = input[0].value.to_be_bytes(); @@ -136,33 +136,34 @@ impl MemAlignSM { let step_second_write = input[4].step; // Get the offset - let offset = addr % CHUNKS_U64; + let offset = addr % CHUNK_NUM_U64; let offset = if offset <= usize::MAX as u64 { offset as usize } else { - panic!("MemAlignSM::process_slice() got invalid offset={}", offset) + panic!("Invalid offset={}", offset); }; // Get the width let width = input[0].width; - let width = if width <= CHUNKS_U64 { + let width = if width <= CHUNK_NUM_U64 { width as usize } else { - panic!("MemAlignSM::process_slice() got invalid width={}", width) + panic!("Invalid width={}", width); }; // Compute the shift - let shift = (offset + width) % CHUNKS; + let shift = (offset + width) % CHUNK_NUM; // Get the op to be executed, its size and the pc to jump to - let op = Self::get_mem_ops(&input[0]); - let op_size = MemAlignRomSM::::get_op_size(op); + let op = Self::get_mem_op(&input[0]); + let op_size = MemAlignRomSM::::get_mem_align_op_size(op); let next_pc = MemAlignRomSM::::calculate_next_pc(op, offset, width); // Initialize and set the rows of the corresponding op let mut rows: Vec> = Vec::with_capacity(op_size); + // TODO: Can I detatch the "shape" of the program from the mem_align and do it in the mem_align_rom? 
match op { - MemOps::OneRead => { + MemOp::OneRead => { // RV let mut read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_read), @@ -181,13 +182,13 @@ impl MemAlignSM { offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), - pc: F::from_canonical_usize(next_pc), + pc: F::from_canonical_u64(next_pc), // reset: F::from_bool(false), sel_prove: F::from_bool(true), ..Default::default() }; - for i in 0..CHUNKS { + for i in 0..CHUNK_NUM { read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); read_row.sel[i] = F::from_bool(true); @@ -195,21 +196,21 @@ impl MemAlignSM { value_row.sel[i] = F::from_bool(i == offset); // Store the range check - *range_check.entry(value_first_read[i]).or_insert(0) += 1; - *range_check.entry(value[shift + i]).or_insert(0) += 1; + *range_check.entry(read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_row.reg[i]).or_insert(0) += 1; } // Store the rows rows.push(read_row); rows.push(value_row); } - MemOps::OneWrite => { + MemOp::OneWrite => { // RWV let mut read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_read), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS_U64), + width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -221,9 +222,9 @@ impl MemAlignSM { step: F::from_canonical_u64(step_first_write), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS_U64), + width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), - pc: F::from_canonical_usize(next_pc), + pc: F::from_canonical_u64(next_pc), // reset: F::from_bool(false), sel_up_to_down: F::from_bool(true), ..Default::default() @@ -235,13 +236,13 @@ impl MemAlignSM { offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), 
- pc: F::from_canonical_usize(next_pc + 1), + pc: F::from_canonical_u64(next_pc + 1), // reset: F::from_bool(false), sel_prove: F::from_bool(true), ..Default::default() }; - for i in 0..CHUNKS { + for i in 0..CHUNK_NUM { read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); read_row.sel[i] = F::from_bool(i < offset); @@ -252,9 +253,9 @@ impl MemAlignSM { value_row.sel[i] = F::from_bool(i == offset); // Store the range check - *range_check.entry(value_first_read[i]).or_insert(0) += 1; - *range_check.entry(value_first_write[i]).or_insert(0) += 1; - *range_check.entry(value[shift + i]).or_insert(0) += 1; + *range_check.entry(read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(write_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_row.reg[i]).or_insert(0) += 1; } // Store the rows @@ -262,13 +263,13 @@ impl MemAlignSM { rows.push(write_row); rows.push(value_row); } - MemOps::TwoReads => { + MemOp::TwoReads => { // RVR let mut first_read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_read), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS_U64), + width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -282,7 +283,7 @@ impl MemAlignSM { offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), - pc: F::from_canonical_usize(next_pc), + pc: F::from_canonical_u64(next_pc), // reset: F::from_bool(false), sel_prove: F::from_bool(true), ..Default::default() @@ -292,15 +293,15 @@ impl MemAlignSM { step: F::from_canonical_u64(step_second_read), addr: F::from_canonical_u64(addr_next), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS_U64), + width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), - pc: F::from_canonical_usize(next_pc + 1), + pc: F::from_canonical_u64(next_pc + 1), // reset: 
F::from_bool(false), sel_down_to_up: F::from_bool(true), ..Default::default() }; - for i in 0..CHUNKS { + for i in 0..CHUNK_NUM { first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); first_read_row.sel[i] = F::from_bool(true); @@ -311,9 +312,9 @@ impl MemAlignSM { second_read_row.sel[i] = F::from_bool(true); // Store the range check - *range_check.entry(value_first_read[i]).or_insert(0) += 1; - *range_check.entry(value[shift + i]).or_insert(0) += 1; - *range_check.entry(value_second_read[i]).or_insert(0) += 1; + *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; } // Store the rows @@ -321,13 +322,13 @@ impl MemAlignSM { rows.push(value_row); rows.push(second_read_row); } - MemOps::TwoWrites => { + MemOp::TwoWrites => { // RWVWR let mut first_read_row = MemAlign3Row:: { step: F::from_canonical_u64(step_first_read), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS_U64), + width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -339,9 +340,9 @@ impl MemAlignSM { step: F::from_canonical_u64(step_first_write), addr: F::from_canonical_u64(addr_prior), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS_U64), + width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), - pc: F::from_canonical_usize(next_pc), + pc: F::from_canonical_u64(next_pc), // reset: F::from_bool(false), sel_up_to_down: F::from_bool(true), ..Default::default() @@ -353,7 +354,7 @@ impl MemAlignSM { offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), - pc: F::from_canonical_usize(next_pc + 1), + pc: F::from_canonical_u64(next_pc + 1), // reset: F::from_bool(false), sel_prove: F::from_bool(true), ..Default::default() @@ -363,9 
+364,9 @@ impl MemAlignSM { step: F::from_canonical_u64(step_second_write), addr: F::from_canonical_u64(addr_next), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS_U64), + width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), - pc: F::from_canonical_usize(next_pc + 2), + pc: F::from_canonical_u64(next_pc + 2), // reset: F::from_bool(false), sel_down_to_up: F::from_bool(true), ..Default::default() @@ -375,15 +376,15 @@ impl MemAlignSM { step: F::from_canonical_u64(step_second_read), addr: F::from_canonical_u64(addr_next), // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNKS_U64), + width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), - pc: F::from_canonical_usize(next_pc + 3), + pc: F::from_canonical_u64(next_pc + 3), reset: F::from_bool(false), sel_down_to_up: F::from_bool(true), ..Default::default() }; - for i in 0..CHUNKS { + for i in 0..CHUNK_NUM { first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); first_read_row.sel[i] = F::from_bool(i < offset); @@ -400,11 +401,11 @@ impl MemAlignSM { second_read_row.sel[i] = F::from_bool(i >= shift); // Store the range check - *range_check.entry(value_first_read[i]).or_insert(0) += 1; - *range_check.entry(value_first_write[i]).or_insert(0) += 1; - *range_check.entry(value[shift + i]).or_insert(0) += 1; - *range_check.entry(value_second_write[i]).or_insert(0) += 1; - *range_check.entry(value_second_read[i]).or_insert(0) += 1; + *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; + *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; + *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; } // Store the rows @@ -416,11 +417,8 @@ impl MemAlignSM { } } - // Compute and store the ROM row multiplicity - let rom_rows = MemAlignRomSM::::calculate_rom_rows(op, offset, width); - for &row in 
rom_rows.iter() { - multiplicity[row] += 1; - } + // Update the ROM row multiplicity + mem_align_rom_sm.update_multiplicity_by_input(op, offset, width); // Return successfully rows @@ -453,7 +451,6 @@ impl MemAlignSM { let pctx = wcm.get_pctx(); let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - let air_mem_align_rom = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); assert!(inputs.len() <= air_mem_align.num_rows()); info!( @@ -464,9 +461,7 @@ impl MemAlignSM { inputs.len() as f64 / air_mem_align.num_rows() as f64 * 100.0 ); - // let mut multiplicity_table = vec![0u8; air_mem_align_rom.num_rows()]; - let mut multiplicity_table = vec![0u8; air_mem_align_rom.num_rows()]; - let mut range_check: HashMap = HashMap::new(); + let mut reg_range_check: HashMap = HashMap::new(); let mut trace_buffer = MemAlign3Trace::::map_buffer( prover_buffer, air_mem_align.num_rows(), @@ -476,7 +471,7 @@ impl MemAlignSM { // Process the inputs while saving the multiplcities and range checks let mut rows_processed = 0; - let rows = Self::process_slice(&inputs, &mut multiplicity_table, &mut range_check); + let rows = Self::process_slice(&inputs, mem_align_rom_sm, &mut reg_range_check); for (i, &row) in rows.iter().enumerate() { trace_buffer[rows_processed + i] = row; } @@ -496,27 +491,19 @@ impl MemAlignSM { // let row = MemAlignRomSM::::calculate_rom_row( // op, offset, width // ); - // multiplicity_table[row as usize] += multiplicity; + // rom_multiplicity[row as usize] += multiplicity; // } - // Compute the ROM multiplicities - mem_align_rom_sm.process_slice(&multiplicity_table); - // Perform the range checks - let range_id = std.get_range(BigInt::from(0), BigInt::from(0xFF), None); - for (&value, &multiplicity) in range_check.iter() { - std.range_check( - F::from_canonical_u8(value), - F::from_canonical_u64(multiplicity), - range_id, - ); + let range_id = std.get_range(BigInt::from(0), BigInt::from((1 << CHUNK_BITS) - 1), None); + for 
(&value, &multiplicity) in reg_range_check.iter() { + std.range_check(value, F::from_canonical_u64(multiplicity), range_id); } - std::thread::spawn(move || { - drop(inputs); - drop(multiplicity_table); - drop(range_check); - }); + // std::thread::spawn(move || { + // drop(inputs); + // drop(reg_range_check); + // }); } } @@ -536,9 +523,7 @@ impl Provable for MemAlignSM { let mem_align_rom_sm = self.mem_align_rom_sm.clone(); let wcm = self.wcm.clone(); - let std = self.std.clone(); - let sctx = self.wcm.get_sctx().clone(); let (mut prover_buffer, offset) = create_prover_buffer( From 71ca095bda9148f834a325386e61912189bff86a Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:56:09 +0000 Subject: [PATCH 08/44] wip --- Cargo.lock | 52 ++++----- emulator/src/emu.rs | 28 ++--- emulator/src/emu_full_trace.rs | 4 +- pil/src/lib.rs | 3 - pil/src/pil_helpers/traces.rs | 22 ++-- state-machines/binary/src/binary_basic.rs | 13 +-- .../binary/src/binary_basic_table.rs | 2 +- state-machines/binary/src/binary_extension.rs | 37 +++--- .../binary/src/binary_extension_table.rs | 13 +-- state-machines/main/src/main_sm.rs | 18 +-- state-machines/mem/src/lib.rs | 4 +- state-machines/mem/src/mem_align_sm.rs | 4 +- state-machines/mem/src/mem_proxy.rs | 58 +++++++--- state-machines/mem/src/mem_sm.rs | 105 ++++++------------ state-machines/rom/src/rom.rs | 12 +- witness-computation/src/executor.rs | 7 +- 16 files changed, 179 insertions(+), 203 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ffe2fb46..26b57af5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,9 +47,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.6.17" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338" +checksum = 
"8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" dependencies = [ "anstyle", "anstyle-parse", @@ -96,9 +96,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" dependencies = [ "backtrace", ] @@ -198,9 +198,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.34" +version = "1.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9" +checksum = "baee610e9452a8f6f0a1b6194ec09ff9e2d85dea54432acdae41aa0761c95d70" dependencies = [ "jobserver", "libc", @@ -680,9 +680,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" [[package]] name = "heck" @@ -1474,7 +1474,7 @@ dependencies = [ [[package]] name = "pil-std-lib" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "log", "num-bigint", @@ -1492,7 +1492,7 @@ dependencies = [ [[package]] name = "pilout" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "bytes", "log", @@ 
-1612,7 +1612,7 @@ dependencies = [ [[package]] name = "proofman" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "colored", "env_logger", @@ -1633,7 +1633,7 @@ dependencies = [ [[package]] name = "proofman-common" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "env_logger", "log", @@ -1651,7 +1651,7 @@ dependencies = [ [[package]] name = "proofman-hints" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "p3-field", "proofman-common", @@ -1661,7 +1661,7 @@ dependencies = [ [[package]] name = "proofman-macros" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "proc-macro2", "quote", @@ -1671,7 +1671,7 @@ dependencies = [ [[package]] name = "proofman-starks-lib-c" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "log", ] @@ -1679,7 +1679,7 @@ dependencies = [ [[package]] name = "proofman-util" version = 
"0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "colored", "sysinfo 0.31.4", @@ -1784,9 +1784,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.6" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e346e016eacfff12233c243718197ca12f148c84e1e84268a896699b41c71780" +checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" dependencies = [ "cfg_aliases", "libc", @@ -2010,9 +2010,9 @@ checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" [[package]] name = "rustix" -version = "0.38.38" +version = "0.38.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" +checksum = "375116bee2be9ed569afe2154ea6a99dfdffd257f533f187498c2a8f5feaf4ee" dependencies = [ "bitflags 2.6.0", "errno", @@ -2314,7 +2314,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stark" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "log", "p3-field", @@ -2456,18 +2456,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.67" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3c6efbfc763e64eb85c11c25320f0737cb7364c4b6336db90aa9ebe27a0bbd" +checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.67" +version = "1.0.68" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b607164372e89797d78b8e23a6d67d5d1038c1c65efd52e1389ef8b77caba2a6" +checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" dependencies = [ "proc-macro2", "quote", @@ -2644,7 +2644,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#e87671a5ca63cd1f312554aeca4b1e4b0a3905a9" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#34427c192773b6e372b430ebd70912a3a7b2b4aa" dependencies = [ "proofman-starks-lib-c", ] diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index b5b8da69..cc24f555 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -501,9 +501,9 @@ impl<'a> Emu<'a> { } // Log emulation step, if requested - if options.print_step.is_some() - && (options.print_step.unwrap() != 0) - && ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) + if options.print_step.is_some() && + (options.print_step.unwrap() != 0) && + ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) { println!("step={}", self.ctx.inst_ctx.step); } @@ -706,9 +706,9 @@ impl<'a> Emu<'a> { // Increment step counter self.ctx.inst_ctx.step += 1; - if self.ctx.inst_ctx.end - || ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) - == self.ctx.callback_steps) + if self.ctx.inst_ctx.end || + ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) == + self.ctx.callback_steps) { // In run() we have checked the callback consistency with ctx.do_callback let callback = callback.as_ref().unwrap(); @@ -924,11 +924,11 @@ impl<'a> Emu<'a> { let mut current_box_id = 0; let mut current_step_idx = loop { - if current_box_id == vec_traces.len() - 1 - || vec_traces[current_box_id + 1].start_state.step >= emu_trace_start.step + if current_box_id == vec_traces.len() - 1 || + vec_traces[current_box_id + 1].start_state.step >= emu_trace_start.step { - break 
emu_trace_start.step as usize - - vec_traces[current_box_id].start_state.step as usize; + break emu_trace_start.step as usize - + vec_traces[current_box_id].start_state.step as usize; } current_box_id += 1; }; @@ -1039,8 +1039,8 @@ impl<'a> Emu<'a> { let b = [inst_ctx.b & 0xFFFFFFFF, (inst_ctx.b >> 32) & 0xFFFFFFFF]; let c = [inst_ctx.c & 0xFFFFFFFF, (inst_ctx.c >> 32) & 0xFFFFFFFF]; - let addr1 = (inst.b_offset_imm0 as i64 - + if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; + let addr1 = (inst.b_offset_imm0 as i64 + + if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; let jmp_offset1 = if inst.jmp_offset1 >= 0 { F::from_canonical_u64(inst.jmp_offset1 as u64) @@ -1118,8 +1118,8 @@ impl<'a> Emu<'a> { m32: F::from_bool(inst.m32), addr1: F::from_canonical_u64(addr1), __debug_operation_bus_enabled: F::from_bool( - inst.op_type == ZiskOperationType::Binary - || inst.op_type == ZiskOperationType::BinaryE, + inst.op_type == ZiskOperationType::Binary || + inst.op_type == ZiskOperationType::BinaryE, ), } } diff --git a/emulator/src/emu_full_trace.rs b/emulator/src/emu_full_trace.rs index 67b49842..343e0c05 100644 --- a/emulator/src/emu_full_trace.rs +++ b/emulator/src/emu_full_trace.rs @@ -1,3 +1,3 @@ -use zisk_pil::Main0Row; +use zisk_pil::MainRow; -pub type EmuFullTraceStep = Main0Row; +pub type EmuFullTraceStep = MainRow; diff --git a/pil/src/lib.rs b/pil/src/lib.rs index 5d31b15a..0e95ca27 100644 --- a/pil/src/lib.rs +++ b/pil/src/lib.rs @@ -7,8 +7,5 @@ pub const ARITH_AIRGROUP_ID: usize = 101; pub const ARITH32_AIR_IDS: &[usize] = &[4, 5]; pub const ARITH64_AIR_IDS: &[usize] = &[6]; pub const ARITH3264_AIR_IDS: &[usize] = &[7]; -pub const MEM_AIRGROUP_ID: usize = 105; -pub const MEM_ALIGN_AIR_IDS: &[usize] = &[1]; -pub const MEM_UNALIGNED_AIR_IDS: &[usize] = &[2, 3]; pub const QUICKOPS_AIRGROUP_ID: usize = 102; pub const QUICKOPS_AIR_IDS: &[usize] = &[10]; diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs 
index f1af3709..1545cdca 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -3,46 +3,46 @@ use proofman_common as common; pub use proofman_macros::trace; -trace!(Main0Row, Main0Trace { +trace!(MainRow, MainTrace { a: [F; 2], b: [F; 2], c: [F; 2], flag: F, pc: F, a_src_imm: F, a_src_mem: F, a_offset_imm0: F, a_imm1: F, a_src_step: F, b_src_imm: F, b_src_mem: F, b_offset_imm0: F, b_imm1: F, b_src_ind: F, ind_width: F, is_external_op: F, op: F, store_ra: F, store_mem: F, store_ind: F, store_offset: F, set_pc: F, jmp_offset1: F, jmp_offset2: F, m32: F, addr1: F, __debug_operation_bus_enabled: F, }); -trace!(Rom0Row, Rom0Trace { +trace!(RomRow, RomTrace { line: F, a_offset_imm0: F, a_imm1: F, b_offset_imm0: F, b_imm1: F, ind_width: F, op: F, store_offset: F, jmp_offset1: F, jmp_offset2: F, flags: F, multiplicity: F, }); -trace!(Mem0Row, Mem0Trace { +trace!(MemRow, MemTrace { addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, same_value: F, first_addr_access_is_read: F, }); -trace!(MemAlign0Row, MemAlign0Trace { +trace!(MemAlignRow, MemAlignTrace { addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], sel_prove: F, step: F, }); -trace!(MemAlignRom0Row, MemAlignRom0Trace { +trace!(MemAlignRomRow, MemAlignRomTrace { multiplicity: F, }); -trace!(Binary0Row, Binary0Trace { +trace!(BinaryRow, BinaryTrace { m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, multiplicity: F, main_step: F, }); -trace!(BinaryTable0Row, BinaryTable0Trace { +trace!(BinaryTableRow, BinaryTableTrace { multiplicity: F, }); -trace!(BinaryExtension0Row, BinaryExtension0Trace { +trace!(BinaryExtensionRow, BinaryExtensionTrace { op: F, in1: [F; 8], in2_low: F, out: [[F; 2]; 8], op_is_shift: F, in2: [F; 2], main_step: F, multiplicity: F, }); -trace!(BinaryExtensionTable0Row, BinaryExtensionTable0Trace { 
+trace!(BinaryExtensionTableRow, BinaryExtensionTableTrace { multiplicity: F, }); -trace!(SpecifiedRanges0Row, SpecifiedRanges0Trace { +trace!(SpecifiedRangesRow, SpecifiedRangesTrace { mul: [F; 2], }); -trace!(U8Air0Row, U8Air0Trace { +trace!(U8AirRow, U8AirTrace { mul: F, }); diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index 3755609b..d2dd1276 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -10,9 +10,9 @@ use proofman_common::AirInstance; use proofman_util::{timer_start_trace, timer_stop_and_log_trace}; use rayon::Scope; use sm_common::{create_prover_buffer, OpResult, Provable}; -use zisk_pil::{Binary0Row, Binary0Trace, BINARY_AIR_IDS, BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; use std::cmp::Ordering as CmpOrdering; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; +use zisk_pil::{BinaryRow, BinaryTrace, BINARY_AIR_IDS, BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::{BinaryBasicTableOp, BinaryBasicTableSM}; @@ -144,9 +144,9 @@ impl BinaryBasicSM { pub fn process_slice( operation: &ZiskRequiredOperation, multiplicity: &mut [u64], - ) -> Binary0Row { + ) -> BinaryRow { // Create an empty trace - let mut row: Binary0Row = Default::default(); + let mut row: BinaryRow = Default::default(); // Execute the opcode let c: u64; @@ -658,8 +658,7 @@ impl BinaryBasicSM { timer_start_trace!(BINARY_TRACE); let pctx = wcm.get_pctx(); let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]); - let air_binary_table = - pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0]); + let air_binary_table = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS[0]); assert!(operations.len() <= air.num_rows()); info!( @@ -672,7 +671,7 @@ impl BinaryBasicSM { let mut multiplicity_table = vec![0u64; air_binary_table.num_rows()]; let mut trace_buffer = - Binary0Trace::::map_buffer(prover_buffer, air.num_rows(), offset as usize).unwrap(); + 
BinaryTrace::::map_buffer(prover_buffer, air.num_rows(), offset as usize).unwrap(); for (i, operation) in operations.iter().enumerate() { let row = Self::process_slice(operation, &mut multiplicity_table); @@ -681,7 +680,7 @@ impl BinaryBasicSM { timer_stop_and_log_trace!(BINARY_TRACE); timer_start_trace!(BINARY_PADDING); - let padding_row = Binary0Row:: { + let padding_row = BinaryRow:: { m_op: F::from_canonical_u8(0x20), multiplicity: F::zero(), main_step: F::zero(), /* TODO: remove, since main_step is just for diff --git a/state-machines/binary/src/binary_basic_table.rs b/state-machines/binary/src/binary_basic_table.rs index 7f2070d9..028e8cbd 100644 --- a/state-machines/binary/src/binary_basic_table.rs +++ b/state-machines/binary/src/binary_basic_table.rs @@ -10,7 +10,7 @@ use proofman_common::AirInstance; use rayon::prelude::*; use sm_common::create_prover_buffer; use zisk_core::{zisk_ops::ZiskOp, P2_16, P2_17, P2_18, P2_19, P2_8}; -use zisk_pil::{ZISK_AIRGROUP_ID, BINARY_TABLE_AIR_IDS}; +use zisk_pil::{BINARY_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; #[derive(Debug, Clone, PartialEq, Copy)] #[repr(u8)] diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index d53414e9..4e66044a 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -17,7 +17,10 @@ use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use rayon::Scope; use sm_common::{create_prover_buffer, OpResult, Provable}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::{BinaryExtension0Row, BinaryExtension0Trace, BINARY_EXTENSION_AIR_IDS, BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{ + BinaryExtensionRow, BinaryExtensionTrace, BINARY_EXTENSION_AIR_IDS, + BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID, +}; const MASK_32: u64 = 0xFFFFFFFF; const MASK_64: u64 = 0xFFFFFFFFFFFFFFFF; @@ -117,12 +120,12 @@ impl BinaryExtensionSM { fn 
opcode_is_shift(opcode: ZiskOp) -> bool { match opcode { - ZiskOp::Sll - | ZiskOp::Srl - | ZiskOp::Sra - | ZiskOp::SllW - | ZiskOp::SrlW - | ZiskOp::SraW => true, + ZiskOp::Sll | + ZiskOp::Srl | + ZiskOp::Sra | + ZiskOp::SllW | + ZiskOp::SrlW | + ZiskOp::SraW => true, ZiskOp::SignExtendB | ZiskOp::SignExtendH | ZiskOp::SignExtendW => false, @@ -134,12 +137,12 @@ impl BinaryExtensionSM { match opcode { ZiskOp::SllW | ZiskOp::SrlW | ZiskOp::SraW => true, - ZiskOp::Sll - | ZiskOp::Srl - | ZiskOp::Sra - | ZiskOp::SignExtendB - | ZiskOp::SignExtendH - | ZiskOp::SignExtendW => false, + ZiskOp::Sll | + ZiskOp::Srl | + ZiskOp::Sra | + ZiskOp::SignExtendB | + ZiskOp::SignExtendH | + ZiskOp::SignExtendW => false, _ => panic!("BinaryExtensionSM::opcode_is_shift() got invalid opcode={:?}", opcode), } @@ -149,7 +152,7 @@ impl BinaryExtensionSM { operation: &ZiskRequiredOperation, multiplicity: &mut [u64], range_check: &mut HashMap, - ) -> BinaryExtension0Row { + ) -> BinaryExtensionRow { // Get the opcode let op = operation.opcode; @@ -158,7 +161,7 @@ impl BinaryExtensionSM { // Create an empty trace let mut row = - BinaryExtension0Row:: { op: F::from_canonical_u8(op), ..Default::default() }; + BinaryExtensionRow:: { op: F::from_canonical_u8(op), ..Default::default() }; // Set if the opcode is a shift operation let op_is_shift = Self::opcode_is_shift(opcode); @@ -405,7 +408,7 @@ impl BinaryExtensionSM { let mut multiplicity_table = vec![0u64; air_binary_extension_table.num_rows()]; let mut range_check: HashMap = HashMap::new(); let mut trace_buffer = - BinaryExtension0Trace::::map_buffer(prover_buffer, air.num_rows(), offset as usize) + BinaryExtensionTrace::::map_buffer(prover_buffer, air.num_rows(), offset as usize) .unwrap(); for (i, operation) in operations.iter().enumerate() { @@ -416,7 +419,7 @@ impl BinaryExtensionSM { timer_start_debug!(BINARY_EXTENSION_PADDING); let padding_row = - BinaryExtension0Row:: { op: F::from_canonical_u64(0x25), ..Default::default() }; + 
BinaryExtensionRow:: { op: F::from_canonical_u64(0x25), ..Default::default() }; for i in operations.len()..air.num_rows() { trace_buffer[i] = padding_row; diff --git a/state-machines/binary/src/binary_extension_table.rs b/state-machines/binary/src/binary_extension_table.rs index 8fbc1de3..e48a8fcf 100644 --- a/state-machines/binary/src/binary_extension_table.rs +++ b/state-machines/binary/src/binary_extension_table.rs @@ -10,7 +10,7 @@ use proofman_common::AirInstance; use rayon::prelude::*; use sm_common::create_prover_buffer; use zisk_core::{zisk_ops::ZiskOp, P2_11, P2_19, P2_8}; -use zisk_pil::{ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS}; +use zisk_pil::{BINARY_EXTENSION_TABLE_AIR_IDS, ZISK_AIRGROUP_ID}; #[derive(Debug, Clone, PartialEq, Copy)] #[repr(u8)] @@ -47,9 +47,7 @@ impl BinaryExtensionTableSM { pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { let pctx = wcm.get_pctx(); - let air = pctx - .pilout - .get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0]); let binary_extension_table = Self { wcm: wcm.clone(), @@ -130,11 +128,8 @@ impl BinaryExtensionTableSM { let mut multiplicity = self.multiplicity.lock().unwrap(); - let (is_myne, instance_global_idx) = dctx.add_instance( - ZISK_AIRGROUP_ID, - BINARY_EXTENSION_TABLE_AIR_IDS[0], - 1, - ); + let (is_myne, instance_global_idx) = + dctx.add_instance(ZISK_AIRGROUP_ID, BINARY_EXTENSION_TABLE_AIR_IDS[0], 1); let owner = dctx.owner(instance_global_idx); let mut multiplicity_ = std::mem::take(&mut *multiplicity); diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index 3283e69a..d9532d81 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -13,7 +13,7 @@ use proofman_common::{AirInstance, ProofCtx}; use proofman::WitnessComponent; use sm_arith::ArithSM; use zisk_pil::{ - Main0Row, Main0Trace, BINARY_AIR_IDS, 
BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, + MainRow, MainTrace, BINARY_AIR_IDS, BINARY_EXTENSION_AIR_IDS, MAIN_AIR_IDS, ZISK_AIRGROUP_ID, }; use ziskemu::{Emu, EmuTrace, ZiskEmulator}; @@ -87,12 +87,12 @@ impl MainSM { // Set Row 0 of the current segment let row0 = if segment_id == 0 { - Main0Row:: { + MainRow:: { pc: F::from_canonical_u64(ROM_ENTRY), op: F::from_canonical_u8(ZiskOp::CopyB.code()), a_src_imm: F::one(), b_src_imm: F::one(), - ..Main0Row::default() + ..MainRow::default() } } else { let emu_trace_previous = vec_traces[segment_id - 1].steps.last().unwrap(); @@ -100,7 +100,7 @@ impl MainSM { Emu::from_emu_trace_start(zisk_rom, &vec_traces[segment_id - 1].last_state); let row_previous = emu.step_slice_full_trace(emu_trace_previous); - Main0Row:: { + MainRow:: { set_pc: row_previous.set_pc, jmp_offset1: row_previous.jmp_offset1, jmp_offset2: if row_previous.flag == F::one() { @@ -120,21 +120,21 @@ impl MainSM { pc: row_previous.pc, a_src_imm: F::one(), b_src_imm: F::one(), - ..Main0Row::default() + ..MainRow::default() } }; let mut emu = Emu::from_emu_trace_start(zisk_rom, &segment_trace.start_state); - let rng = offset as usize..(offset as usize + Main0Row::::ROW_SIZE); + let rng = offset as usize..(offset as usize + MainRow::::ROW_SIZE); iectx.prover_buffer[rng].copy_from_slice(row0.as_slice()); // Set Rows 1 to N of the current segment (N = maximum number of air rows) let total_rows = segment_trace.steps.len(); const SLICE_ROWS: usize = 4096; - let mut partial_trace = Main0Trace::::new(SLICE_ROWS); + let mut partial_trace = MainTrace::::new(SLICE_ROWS); - let mut last_row = Main0Row::::default(); + let mut last_row = MainRow::::default(); for slice in (0..(air.num_rows())).step_by(SLICE_ROWS) { // process the steps of the chunk let slice_start = std::cmp::min(slice, total_rows); @@ -158,7 +158,7 @@ impl MainSM { //copy the chunk to the prover buffer let partial_buffer = partial_trace.buffer.as_ref().unwrap(); - let buffer_offset_slice = 
offset as usize + (slice + 1) * Main0Row::::ROW_SIZE; + let buffer_offset_slice = offset as usize + (slice + 1) * MainRow::::ROW_SIZE; let rng = buffer_offset_slice..buffer_offset_slice + partial_buffer.len(); iectx.prover_buffer[rng].copy_from_slice(partial_buffer); diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 47dd31fd..1f3dfb54 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,7 +1,7 @@ mod mem_align_sm; -mod mem_sm; mod mem_proxy; +mod mem_sm; pub use mem_align_sm::*; -pub use mem_sm::*; pub use mem_proxy::*; +pub use mem_sm::*; diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 0eeb4c38..5dfc7d2f 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -8,7 +8,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; use sm_common::{MemOp, OpResult, Provable}; -use zisk_pil::{MEM_AIRGROUP_ID, MEM_ALIGN_AIR_IDS}; +use zisk_pil::{MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; const PROVE_CHUNK_SIZE: usize = 1 << 12; @@ -29,7 +29,7 @@ impl MemAlignSM { wcm.register_component( mem_aligned_sm.clone(), - Some(MEM_AIRGROUP_ID), + Some(ZISK_AIRGROUP_ID), Some(MEM_ALIGN_AIR_IDS), ); diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 4a3466f8..d645126e 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -85,7 +85,8 @@ impl MemProxy { new_aligned.extend(aligned_accesses); }); - // Step 3. Concatenate the new aligned memory accesses with the original aligned memory accesses + // Step 3. Concatenate the new aligned memory accesses with the original aligned memory + // accesses aligned.extend(new_aligned); // Step 4. 
Prove the aligned memory accesses using mem state machine @@ -106,7 +107,8 @@ impl MemProxy { match mem_ops { MemOps::OneRead => { // Look for last write to the same address - let last_write_addr = Self::get_last_write(addr, aligned_accesses); + let last_write_addr = + Self::get_last_write(addr, unaligned_access.step, aligned_accesses); let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: false, @@ -118,7 +120,8 @@ impl MemProxy { } MemOps::OneWrite => { // Look for last write to the same address - let last_write_addr = Self::get_last_write(addr, aligned_accesses); + let last_write_addr = + Self::get_last_write(addr, unaligned_access.step, aligned_accesses); // Modify the value of the write to the same address let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { @@ -134,8 +137,10 @@ impl MemProxy { } MemOps::TwoReads => { // Look for last write to the same address and same address + 8 - let last_write_addr = Self::get_last_write(addr, aligned_accesses); - let last_write_addr_p = Self::get_last_write(addr + 8, aligned_accesses); + let last_write_addr = + Self::get_last_write(addr, unaligned_access.step, aligned_accesses); + let last_write_addr_p = + Self::get_last_write(addr + 8, unaligned_access.step, aligned_accesses); let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, @@ -157,8 +162,10 @@ impl MemProxy { } MemOps::TwoWrites => { // Look for last write to the same address and same address + 8 - let last_write_addr = Self::get_last_write(addr, aligned_accesses); - let last_write_addr_p = Self::get_last_write(addr + 8, aligned_accesses); + let last_write_addr = + Self::get_last_write(addr, unaligned_access.step, aligned_accesses); + let last_write_addr_p = + Self::get_last_write(addr + 8, unaligned_access.step, aligned_accesses); let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, @@ -185,16 
+192,37 @@ impl MemProxy { #[inline(always)] fn get_last_write( addr: u64, + step: u64, aligned_accesses: &[ZiskRequiredMemory], ) -> Option { - Some( - aligned_accesses - .iter() - .rev() - .find(|mem| mem.address == addr && mem.is_write) - .unwrap() - .clone(), - ) + // Step 1: Find the start of the range for `addr` + let start_index = + match aligned_accesses.binary_search_by_key(&addr, |access| access.address) { + Ok(mut index) => { + // Backtrack to find the first occurrence of `addr` + while index > 0 && aligned_accesses[index - 1].address == addr { + index -= 1; + } + index + } + Err(index) => index, // If no match, use the insertion point as before + }; + + // Step 2: Iterate from start_index forward, storing the last valid write + let mut last_write = None; + for access in &aligned_accesses[start_index..] { + if access.address != addr { + break; // Stop if we move past the given address + } + if access.step >= step { + break; // Stop if step is not less than the given step + } + if access.is_write { + last_write = Some(access.clone()); // Update last write if conditions are met + } + } + + last_write } #[inline(always)] diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index d16fa680..6e871945 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -11,7 +11,7 @@ use rayon::prelude::*; use sm_common::create_prover_buffer; use zisk_core::ZiskRequiredMemory; -use zisk_pil::{Mem0Trace, MEM_AIRGROUP_ID, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct MemSM { // Witness computation manager @@ -27,7 +27,7 @@ impl MemSM { let mem_sm = Self { wcm: wcm.clone(), registered_predecessors: AtomicU32::new(0) }; let mem_sm = Arc::new(mem_sm); - wcm.register_component(mem_sm.clone(), Some(MEM_AIRGROUP_ID), Some(MEM_AIR_IDS)); + wcm.register_component(mem_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MEM_AIR_IDS)); mem_sm } @@ -50,7 +50,7 @@ impl MemSM { let ectx 
= self.wcm.get_ectx(); let sctx = self.wcm.get_sctx(); - let air = pctx.pilout.get_air(MEM_AIRGROUP_ID, MEM_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); let num_chunks = (mem_accesses.len() as f64 / (air.num_rows() - 1) as f64).ceil() as usize; @@ -67,9 +67,9 @@ impl MemSM { let (buffer, offset) = create_prover_buffer::(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); - prover_buffers.lock().unwrap().push(buffer); - offsets.push(offset); - global_idxs.push(global_idx); + prover_buffers.lock().unwrap()[i] = buffer; + offsets[i] = offset; + global_idxs[i] = global_idx; } } @@ -82,6 +82,7 @@ impl MemSM { }; let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + self.prove_instance( mem_ops, mem_first_row, @@ -113,20 +114,22 @@ impl MemSM { global_idx: usize, ) -> Result<(), Box> { let pctx = self.wcm.get_pctx(); + let sctx = self.wcm.get_sctx(); // STEP2: Process the memory inputs and convert them to AIR instances - let air = pctx.pilout.get_air(MEM_AIRGROUP_ID, MEM_AIR_IDS[0]); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); let max_rows_per_segment = air.num_rows() - 1; assert!(mem_ops.len() > 0 && mem_ops.len() <= max_rows_per_segment); - // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR segments - // In a Memory AIR instance, the first row is reserved as a dummy row. - // This dummy row is used to facilitate the continuation state between different AIR segments. - // It ensures seamless transitions when multiple AIR segments are processed consecutively. - // This design avoids discontinuities in memory access patterns and ensures that the memory trace is continuous, - // For this reason we use AIR num_rows - 1 as the number of rows in each memory AIR instance + // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR + // segments In a Memory AIR instance, the first row is reserved as a dummy row. 
+ // This dummy row is used to facilitate the continuation state between different AIR + // segments. It ensures seamless transitions when multiple AIR segments are + // processed consecutively. This design avoids discontinuities in memory access + // patterns and ensures that the memory trace is continuous, For this reason we use + // AIR num_rows - 1 as the number of rows in each memory AIR instance // Create a vector of Mem0Row instances, one for each memory operation // Recall that first row is a dummy row used for the continuations between AIR segments @@ -134,14 +137,10 @@ impl MemSM { // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows let mut trace = - Mem0Trace::::map_buffer(&mut prover_buffer, air.num_rows(), offset as usize) - .unwrap(); + MemTrace::::map_buffer(&mut prover_buffer, air.num_rows(), offset as usize).unwrap(); - let segment_id_field = F::from_canonical_u64(segment_id as u64); - let is_last_segment_field = F::from_bool(is_last_segment); - - // STEP1. Add the first row to the output vector as equal to the last row of the previous segment - // CASE: last row of segment is read + // STEP1. Add the first row to the output vector as equal to the last row of the previous + // segment CASE: last row of segment is read // // S[n-1] wr = 0, sel = 1, addr, step, value // S+1[0] wr = 0, sel = 0, addr, step, value @@ -199,8 +198,8 @@ impl MemSM { let addr_changes = trace[i - 1].addr != trace[i].addr; trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; - let same_value = trace[i - 1].value[0] == trace[i].value[0] - && trace[i - 1].value[1] == trace[i].value[1]; + let same_value = trace[i - 1].value[0] == trace[i].value[0] && + trace[i - 1].value[1] == trace[i].value[1]; trace[i].same_value = if same_value { F::one() } else { F::zero() }; let first_addr_access_is_read = addr_changes && !mem_op.is_write; @@ -209,7 +208,8 @@ impl MemSM { } // STEP3. 
Add dummy rows to the output vector to fill the remaining rows - //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 + //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd + // = 1, wr = 0 let last_row_idx = mem_ops.len(); let addr = trace[last_row_idx].addr; let mut step = trace[last_row_idx].step; @@ -234,14 +234,21 @@ impl MemSM { trace[i].first_addr_access_is_read = F::zero(); } - let air_instance = AirInstance::new( + let mut air_instance = AirInstance::new( self.wcm.get_sctx(), - MEM_AIRGROUP_ID, + ZISK_AIRGROUP_ID, MEM_AIR_IDS[0], Some(segment_id), prover_buffer, ); + air_instance.set_airvalue( + &sctx, + "Mem.mem_segment", + F::from_canonical_u64(segment_id as u64), + ); + air_instance.set_airvalue(&sctx, "Mem.mem_last_segment", F::from_bool(is_last_segment)); + pctx.air_instance_repo.add_air_instance(air_instance, Some(global_idx)); Ok(()) @@ -253,51 +260,3 @@ impl MemSM { } impl WitnessComponent for MemSM {} - -#[cfg(test)] -mod tests { - // use super::*; - // use p3_field::AbstractField; - // use p3_goldilocks::Goldilocks; - // use zisk_core::ZiskRequiredMemory; - - // type GL = Goldilocks; - - // #[test] - // fn test_calculate_witness_rows() { - // let mem_ops = vec![ - // ZiskRequiredMemory::new(0, true, 0, 1, 0), - // ZiskRequiredMemory::new(1, false, 1, 1, 0), - // ZiskRequiredMemory::new(2, true, 2, 1, 0), - // ZiskRequiredMemory::new(3, false, 3, 1, 0), - // ZiskRequiredMemory::new(4, true, 4, 1, 0), - // ZiskRequiredMemory::new(5, false, 5, 1, 0), - // ZiskRequiredMemory::new(6, true, 6, 1, 0), - // ZiskRequiredMemory::new(7, false, 7, 1, 0), - // ZiskRequiredMemory::new(8, true, 8, 1, 0), - // ZiskRequiredMemory::new(9, false, 9, 1, 0), - // ]; - - // let witness_rows = MemWitness::calculate_witness_rows::(mem_ops, 10, 0, true); - - // assert_eq!(witness_rows.len(), 10); - - // // Check the dummy row - // assert_eq!(witness_rows[0].mem_segment, 
GL::from_canonical_u64(0)); - // assert_eq!(witness_rows[0].mem_last_segment, GL::from_bool(true)); - // assert_eq!(witness_rows[0].addr, GL::default()); - // assert_eq!(witness_rows[0].step, GL::default()); - // assert_eq!(witness_rows[0].sel, GL::default()); - // assert_eq!(witness_rows[0].wr, GL::default()); - // assert_eq!(witness_rows[0].value, [GL::default(), GL::default()]); - // assert_eq!(witness_rows[0].addr_changes, GL::default()); - // assert_eq!(witness_rows[0].same_value, GL::default()); - // assert_eq!(witness_rows[0].first_addr_access_is_read, GL::default()); - - // // Check the remaining rows - // for i in 1..10 { - // assert_eq!(witness_rows[i].mem_segment, GL::from_canonical_u64(0)); - // // ... - // } - // } -} diff --git a/state-machines/rom/src/rom.rs b/state-machines/rom/src/rom.rs index 9fe9aa4d..301a981f 100644 --- a/state-machines/rom/src/rom.rs +++ b/state-machines/rom/src/rom.rs @@ -6,9 +6,7 @@ use proofman_common::{AirInstance, BufferAllocator, SetupCtx}; use proofman_util::create_buffer_fast; use zisk_core::{Riscv2zisk, ZiskPcHistogram, ZiskRom, SRC_IMM}; -use zisk_pil::{ - Pilout, Rom0Row, Rom0Trace, ZISK_AIRGROUP_ID, MAIN_AIR_IDS, ROM_AIR_IDS, -}; +use zisk_pil::{Pilout, RomRow, RomTrace, ROM_AIR_IDS, ZISK_AIRGROUP_ID}; //use ziskemu::ZiskEmulatorErr; use std::error::Error; @@ -41,13 +39,13 @@ impl RomSM { let sctx = self.wcm.get_sctx(); let num_rows = - self.wcm.get_pctx().pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows(); + self.wcm.get_pctx().pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows(); let prover_buffer = Self::compute_trace_rom(rom, buffer_allocator, &sctx, pc_histogram, num_rows as u64)?; let air_instance = - AirInstance::new(sctx.clone(), ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0], None, prover_buffer); + AirInstance::new(sctx.clone(), ZISK_AIRGROUP_ID, ROM_AIR_IDS[0], None, prover_buffer); self.wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, Some(instance_gid)); @@ -100,7 +98,7 @@ impl RomSM 
{ // Create an empty ROM trace let mut rom_trace = - Rom0Trace::::map_buffer(&mut prover_buffer, num_rows, offsets[0] as usize) + RomTrace::::map_buffer(&mut prover_buffer, num_rows, offsets[0] as usize) .expect("RomSM::compute_trace() failed mapping buffer to ROMS0Trace"); // For every instruction in the rom, fill its corresponding ROM trace @@ -174,7 +172,7 @@ impl RomSM { // Padd with zeroes for i in number_of_instructions..num_rows { - rom_trace[i] = Rom0Row::default(); + rom_trace[i] = RomRow::default(); } Ok(prover_buffer) diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 9262ee62..6d928e75 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -130,8 +130,7 @@ impl ZiskExecutor { // across different operation types. Currently, we are only collecting data for // Binary and BinaryE operations. let air_binary = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]); - let air_binary_e = - pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); + let air_binary_e = pctx.pilout.get_air(ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]); let mut op_sizes = [0u64; ZISK_OPERATION_TYPE_VARIANTS]; // The starting points for the Main is allocated using None operation @@ -229,9 +228,7 @@ impl ZiskExecutor { let (airgroup_id, air_id) = match emu_slice.op_type { ZiskOperationType::None => (ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]), ZiskOperationType::Binary => (ZISK_AIRGROUP_ID, BINARY_AIR_IDS[0]), - ZiskOperationType::BinaryE => { - (ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]) - } + ZiskOperationType::BinaryE => (ZISK_AIRGROUP_ID, BINARY_EXTENSION_AIR_IDS[0]), _ => panic!("Invalid operation type"), }; let segment_id = match emu_slice.op_type { From 668c37d063097878c7cfb3f547bf331aa94a2ac3 Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Wed, 6 Nov 2024 11:11:55 +0000 Subject: [PATCH 09/44] wip --- state-machines/mem/src/mem_proxy.rs 
| 14 ++++++++++++-- state-machines/mem/src/mem_sm.rs | 3 --- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index d645126e..5eb6e7a2 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -6,7 +6,7 @@ use std::sync::{ use crate::{MemAlignSM, MemSM}; use p3_field::PrimeField; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; -use zisk_core::ZiskRequiredMemory; +use zisk_core::{ZiskRequiredMemory, RAM_ADDR, SYS_ADDR}; use proofman::{WitnessComponent, WitnessManager}; @@ -89,8 +89,18 @@ impl MemProxy { // accesses aligned.extend(new_aligned); + timer_start_debug!(MEM_SORT_2); + aligned.sort_by_key(|mem| mem.address); + timer_stop_and_log_debug!(MEM_SORT_2); + + let mut idx = 0; + while aligned[idx].address < RAM_ADDR && idx < aligned.len() { + idx += 1; + } + let (_input_aligned, aligned) = aligned.split_at_mut(idx); + // Step 4. Prove the aligned memory accesses using mem state machine - self.mem_sm.prove(&mut aligned); + self.mem_sm.prove(aligned); Ok(()) } diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index 6e871945..a45bb684 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -42,9 +42,6 @@ impl MemSM { pub fn prove(&self, mem_accesses: &mut [ZiskRequiredMemory]) { // Sort the (full) aligned memory accesses - timer_start_debug!(MEM_SORT_2); - mem_accesses.sort_by_key(|mem| mem.address); - timer_stop_and_log_debug!(MEM_SORT_2); let pctx = self.wcm.get_pctx(); let ectx = self.wcm.get_ectx(); From 59b5376dcf4745d1e58152aad795b1879581a38b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Wed, 6 Nov 2024 11:28:16 +0000 Subject: [PATCH 10/44] minor changes --- state-machines/mem/Cargo.toml | 5 ++- state-machines/mem/pil/mem_align_rom.pil | 50 +++++++++++++++------- state-machines/mem/src/mem_align_rom_sm.rs | 7 ++- 3 files changed, 42 
insertions(+), 20 deletions(-) diff --git a/state-machines/mem/Cargo.toml b/state-machines/mem/Cargo.toml index 39264a00..3da87e7b 100644 --- a/state-machines/mem/Cargo.toml +++ b/state-machines/mem/Cargo.toml @@ -8,13 +8,16 @@ sm-common = { path = "../common" } zisk-core = { path = "../../core" } zisk-pil = { path = "../../pil" } -p3-field = { workspace=true } proofman-common = { workspace = true } proofman-macros = { workspace = true } proofman-util = { workspace = true } proofman = { workspace = true } +pil-std-lib = { workspace = true } + +p3-field = { workspace=true } log = { workspace = true } rayon = { workspace = true } +num-bigint = { workspace = true } [features] default = [] diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil index 863959a5..322bcd2d 100644 --- a/state-machines/mem/pil/mem_align_rom.pil +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -13,7 +13,7 @@ const int MEM_ALIGN_ROM_SIZE = P2_8; // Note1: The offset and width are sufficient to group programs with the same number of operations. // Note2: The first instruction is always a read. -airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = 8, const int disable_fixed = 0) { +airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = 8, const int DEFAULT_OFFSET = 0, const int DEFAULT_WIDTH = 8, const int disable_fixed = 0) { if (N < MEM_ALIGN_ROM_SIZE) { error(`N must be at least ${MEM_ALIGN_ROM_SIZE}, but N=${N} was provided`); } @@ -28,12 +28,26 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = return; } - // Not all combinations of offset and width are valid for each program. - // Moreover, offset is set to 0 and width to 8 in aligned memory accesses. 
- // size - col fixed OFFSET = [[[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 40 - [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 100 - [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 133 + // Define the size of each program: RV, RWV, RVR, RWVWR + const int psize[4] = [2, 3, 3, 5]; + + // Not all combinations of offset and width are valid for each program: + const int one_word_combinations = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 + const int two_word_combinations = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 + + // table_size = combinations * program_size + const int tsize[4] = [one_word_combinations*psize[0], one_word_combinations*psize[1], two_word_combinations*psize[2], two_word_combinations*psize[3]]; + // size + // RV 6+6*4+4+4+2 = 40 | 40 + // RWV 9+9*4+6+6+3 = 60 | 100 + // RVR 3*4+6+6+9 = 33 | 133 + // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 + + // Moreover, offset is set to DEFAULT_OFFSET and width to DEFAULT_WIDTH in aligned memory accesses. 
+ // size + col fixed OFFSET = [[[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 40 + [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 100 + [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 133 [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3]]...; // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 col fixed WIDTH = [[[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV @@ -41,10 +55,16 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]]]...; // RWVWR - const int psize1 = 40; - const int psize2 = 60; - const int psize3 = 33; - const int psize4 = 55; + // TODO: Do a less-hardcoded version of the OFFSET and WIDTH computation + // col fixed OFFSET; + // col fixed WIDTH; + // for (int i = 0; i < N; i++) { + // int offset = 0; + // int width = 0; + + // OFFSET[i] = offset; + // WIDTH[i] = width; + // } // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | // 0 | 0 | 1 | 1 | X1 | 0 | // (RV) @@ -92,7 +112,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = const int line = i; const int next = i+1; - if (line < psize1) // RV + if (line < tsize[0]) // RV { if (line % 2 == 0) { // pc = 0; @@ -114,7 +134,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = // sel_down_to_up = 0; } } - else if (line < psize1+psize2) // RWV + else if (line < tsize[0]+tsize[1]) // RWV { if (line % 3 == 0) { // R // pc = 0; @@ -145,7 +165,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = // sel_down_to_up = 0; } } - else if (line < psize1+psize2+psize3) + else if (line < tsize[0]+tsize[1]+tsize[2]) { if (line % 3 == 0) { // R // pc = 0; @@ 
-176,7 +196,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = sel_down_to_up = 1; } } - else if (line < psize1+psize2+psize3+psize4) + else if (line < tsize[0]+tsize[1]+tsize[2]+tsize[3]) { if (next % 5 == 0) { // R // pc = 0; diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 2f7027aa..ee2cc678 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -17,8 +17,7 @@ use zisk_pil::{MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::MemOp; -const CHUNKS: usize = 8; -const MEM_WIDTHS: [u64; 4] = [1, 2, 4, 8]; +const CHUNK_NUM: usize = 8; const OP_SIZES: [usize; 4] = [2, 3, 3, 5]; pub struct MemAlignRomSM { @@ -77,7 +76,7 @@ impl MemAlignRomSM { match opcode { MemOp::OneRead | MemOp::OneWrite => { // Sanity check - assert!(offset + width <= CHUNKS); + assert!(offset + width <= CHUNK_NUM); let possible_widths = match offset { x if x <= 4 => vec![1, 2, 4], x if x <= 6 => vec![1, 2], @@ -88,7 +87,7 @@ impl MemAlignRomSM { } MemOp::TwoReads | MemOp::TwoWrites => { // Sanity check - assert!(offset + width > CHUNKS); + assert!(offset + width > CHUNK_NUM); let possible_widths = match offset { x if x == 0 => panic!("Invalid offset={}", offset), x if x <= 4 => vec![8], From b3869526868e4403fc094fbcbc4dfc8edc28aeed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Wed, 6 Nov 2024 16:44:21 +0000 Subject: [PATCH 11/44] Working --- state-machines/mem/src/mem_align_rom_sm.rs | 114 ++--- state-machines/mem/src/mem_align_sm.rs | 460 ++++++++++++--------- state-machines/mem/src/mem_proxy.rs | 59 +-- state-machines/mem/src/mem_sm.rs | 4 +- witness-computation/src/executor.rs | 2 +- 5 files changed, 339 insertions(+), 300 deletions(-) diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index ee2cc678..702eb724 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ 
b/state-machines/mem/src/mem_align_rom_sm.rs @@ -7,15 +7,20 @@ use std::{ }; use log::info; -use p3_field::Field; +use p3_field::PrimeField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; -use rayon::prelude::*; use sm_common::create_prover_buffer; -use zisk_pil::{MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; - -use crate::MemOp; +use zisk_pil::{MemAlignRomRow, MemAlignRomTrace, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; + +#[derive(Debug, Clone, Copy)] +pub enum MemOp { + OneRead, + OneWrite, + TwoReads, + TwoWrites, +} const CHUNK_NUM: usize = 8; const OP_SIZES: [usize; 4] = [2, 3, 3, 5]; @@ -29,7 +34,7 @@ pub struct MemAlignRomSM { // Rom data num_rows: usize, - multiplicity: Mutex>, // row_num -> multiplicity + multiplicity: Mutex>, // row_num -> multiplicity } #[derive(Debug)] @@ -37,10 +42,10 @@ pub enum ExtensionTableSMErr { InvalidOpcode, } -impl MemAlignRomSM { +impl MemAlignRomSM { const MY_NAME: &'static str = "MemAlignRom"; - pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { + pub fn new(wcm: Arc>) -> Arc { let pctx = wcm.get_pctx(); let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); let num_rows = air.num_rows(); @@ -52,7 +57,11 @@ impl MemAlignRomSM { multiplicity: Mutex::new(HashMap::with_capacity(num_rows)), }; let mem_align_rom = Arc::new(mem_align_rom); - wcm.register_component(mem_align_rom.clone(), Some(airgroup_id), Some(air_ids)); + wcm.register_component( + mem_align_rom.clone(), + Some(ZISK_AIRGROUP_ID), + Some(MEM_ALIGN_ROM_AIR_IDS), + ); mem_align_rom } @@ -153,7 +162,7 @@ impl MemAlignRomSM { let mut multiplicity = self.multiplicity.lock().unwrap(); for &i in idxs { - *multiplicity.entry(F::from_canonical_u64(i)).or_insert(0) += 1; + *multiplicity.entry(i).or_insert(0) += 1; } } @@ -161,56 +170,55 @@ impl MemAlignRomSM { let mut multiplicity = self.multiplicity.lock().unwrap(); for (idx, mul) in inputs.iter().enumerate() { - 
*multiplicity.entry(F::from_canonical_usize(idx)).or_insert(0) += *mul; + *multiplicity.entry(idx as u64).or_insert(0) += *mul; } } pub fn create_air_instance(&self) { - let ectx = self.wcm.get_ectx(); - let mut dctx: std::sync::RwLockWriteGuard<'_, proofman_common::DistributionCtx> = - ectx.dctx.write().unwrap(); + let pctx = self.wcm.get_pctx(); + + let air_mem_align_rom = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + + // Create the prover buffer + let (mut prover_buffer, offset) = create_prover_buffer( + &self.wcm.get_ectx(), + &self.wcm.get_sctx(), + ZISK_AIRGROUP_ID, + MEM_ALIGN_ROM_AIR_IDS[0], + ); + + let mut trace_buffer = MemAlignRomTrace::::map_buffer( + &mut prover_buffer, + air_mem_align_rom.num_rows(), + offset as usize, + ) + .unwrap(); let mut multiplicity = self.multiplicity.lock().unwrap(); - let (is_myne, instance_global_idx) = - dctx.add_instance(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0], 1); - let owner = dctx.owner(instance_global_idx); - - let mut multiplicity_ = std::mem::take(&mut *multiplicity); - dctx.distribute_multiplicity(&mut multiplicity_, owner); - - if is_myne { - // Create the prover buffer - let (mut prover_buffer, offset) = create_prover_buffer( - &self.wcm.get_ectx(), - &self.wcm.get_sctx(), - ZISK_AIRGROUP_ID, - MEM_ALIGN_ROM_AIR_IDS[0], - ); - - prover_buffer[offset as usize..offset as usize + self.num_rows] - .par_iter_mut() - .enumerate() - .for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i])); - - info!( - "{}: ··· Creating Mem Align ROM instance [{} rows filled 100%]", - Self::MY_NAME, - self.num_rows, - ); - - let air_instance = AirInstance::new( - self.wcm.get_sctx(), - ZISK_AIRGROUP_ID, - MEM_ALIGN_ROM_AIR_IDS[0], - None, - prover_buffer, - ); - self.wcm - .get_pctx() - .air_instance_repo - .add_air_instance(air_instance, Some(instance_global_idx)); - } + // for row_idx in multiplicity.keys() { + // trace_buffer[*row_idx as usize] = MemAlignRomRow { + // multiplicity: 
multiplicity + // }; + // } + + info!( + "{}: ··· Creating Mem Align ROM instance [{} rows filled 100%]", + Self::MY_NAME, + self.num_rows, + ); + + let air_instance = AirInstance::new( + self.wcm.get_sctx(), + ZISK_AIRGROUP_ID, + MEM_ALIGN_ROM_AIR_IDS[0], + None, + prover_buffer, + ); + self.wcm + .get_pctx() + .air_instance_repo + .add_air_instance(air_instance, None); } } diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index e2f0ae98..7cdb8465 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -12,21 +12,12 @@ use p3_field::PrimeField; use pil_std_lib::Std; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; -use rayon::Scope; -use sm_common::{create_prover_buffer, OpResult, Provable}; +use sm_common::create_prover_buffer; use zisk_core::ZiskRequiredMemory; -use zisk_pil::{MemAlign3Row, MemAlign3Trace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; -use crate::MemAlignRomSM; - -#[derive(Debug, Clone, Copy)] -pub enum MemOp { - OneRead, - OneWrite, - TwoReads, - TwoWrites, -} +use crate::{MemAlignRomSM, MemOp}; const PROVE_CHUNK_SIZE: usize = 1 << 12; @@ -45,7 +36,8 @@ pub struct MemAlignSM { registered_predecessors: AtomicU32, // Inputs - inputs: Mutex>, + inputs: Mutex)>>, + input_len: Mutex, // Secondary State machines mem_align_rom_sm: Arc>, @@ -64,6 +56,7 @@ impl MemAlignSM { std: std.clone(), registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()), + input_len: Mutex::new(0), mem_align_rom_sm, }; let mem_align_sm = Arc::new(mem_align_sm); @@ -87,6 +80,9 @@ impl MemAlignSM { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + // TODO: Fix this... 
+ self.prove_internal(&[], 0); + self.mem_align_rom_sm.unregister_predecessor(); self.std.unregister_predecessor(self.wcm.get_pctx(), None); } @@ -97,9 +93,9 @@ impl MemAlignSM { let addr = unaligned_input.address; let width = unaligned_input.width; - let offset = addr % 8; + let offset = addr & (CHUNK_NUM_U64 - 1); - match (unaligned_input.is_write, offset + width > 8) { + match (unaligned_input.is_write, offset + width > CHUNK_NUM_U64) { (false, false) => MemOp::OneRead, (true, false) => MemOp::OneWrite, (false, true) => MemOp::TwoReads, @@ -107,67 +103,196 @@ impl MemAlignSM { } } + pub fn prove( + &self, + unaligned_access: &ZiskRequiredMemory, + aligned_accesses: &[ZiskRequiredMemory], + ) { + if let (Ok(mut inputs), Ok(mut input_len)) = (self.inputs.lock(), self.input_len.lock()) { + inputs.push((unaligned_access.clone(), aligned_accesses.to_vec())); + *input_len += 1 + aligned_accesses.len(); + + let pctx = self.wcm.get_pctx(); + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + while *input_len >= air_mem_align.num_rows() { + let num_drained = std::cmp::min(air_mem_align.num_rows(), *input_len); + let drained_inputs = inputs.drain(..num_drained).collect::>(); + let drained_len = num_drained; + *input_len -= num_drained; + + self.prove_internal(&drained_inputs, drained_len); + } + } + } + + fn prove_internal( + &self, + inputs: &[(ZiskRequiredMemory, Vec)], + input_len: usize, + ) { + let mem_align_rom_sm = self.mem_align_rom_sm.clone(); + let wcm = self.wcm.clone(); + let std = self.std.clone(); + let sctx = self.wcm.get_sctx().clone(); + + let (mut prover_buffer, offset) = create_prover_buffer( + &wcm.get_ectx(), + &wcm.get_sctx(), + ZISK_AIRGROUP_ID, + MEM_ALIGN_AIR_IDS[0], + ); + + Self::prove_instance( + &wcm, + &mem_align_rom_sm, + &std, + inputs, + input_len, + &mut prover_buffer, + offset, + ); + + let air_instance = + AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0], None, prover_buffer); + 
wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, None); + } + + fn prove_instance( + wcm: &WitnessManager, + mem_align_rom_sm: &MemAlignRomSM, + std: &Std, + inputs: &[(ZiskRequiredMemory, Vec)], + input_len: usize, + prover_buffer: &mut [F], + offset: u64, + ) { + let pctx = wcm.get_pctx(); + + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + assert!(input_len <= air_mem_align.num_rows()); + + info!( + "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", + Self::MY_NAME, + input_len, + air_mem_align.num_rows(), + input_len as f64 / air_mem_align.num_rows() as f64 * 100.0 + ); + + let mut reg_range_check: HashMap = HashMap::new(); + let mut trace_buffer = MemAlignTrace::::map_buffer( + prover_buffer, + air_mem_align.num_rows(), + offset as usize, + ) + .unwrap(); + + // Process the inputs while saving the values to be range checked + let mut rows_processed = 0; + for (unaligned_input, aligned_inputs) in inputs.iter() { + let rows = Self::process_slice( + unaligned_input, + aligned_inputs, + mem_align_rom_sm, + &mut reg_range_check, + ); + for (j, &row) in rows.iter().enumerate() { + trace_buffer[rows_processed + j] = row; + } + rows_processed += rows.len(); + } + + // Pad the remaining rows with trivailly satisfying rows + let padding_row = MemAlignRow::::default(); + + for i in rows_processed..air_mem_align.num_rows() { + trace_buffer[i] = padding_row; + } + + // TODO: Store the padding multiplicity + let _padding_size = air_mem_align.num_rows() - rows_processed; + // for i in 0..8 { + // let multiplicity = padding_size as u64; + // let row = MemAlignRomSM::::calculate_rom_row( + // op, offset, width + // ); + // rom_multiplicity[row as usize] += multiplicity; + // } + + // Perform the range checks + let range_id = std.get_range(BigInt::from(0), BigInt::from((1 << CHUNK_BITS) - 1), None); + for (&value, &multiplicity) in reg_range_check.iter() { + std.range_check(value, 
F::from_canonical_u64(multiplicity), range_id); + } + + // std::thread::spawn(move || { + // drop(inputs); + // drop(reg_range_check); + // }); + } + #[inline(always)] pub fn process_slice( - input: &Vec, + unaligned_input: &ZiskRequiredMemory, + aligned_inputs: &[ZiskRequiredMemory], mem_align_rom_sm: &MemAlignRomSM, range_check: &mut HashMap, - ) -> Vec> { - // Is a write or a read operation - let _wr = input[0].is_write; - - // Get the address - let addr = input[0].address; - let addr_prior = input[1].address; // addr / CHUNK_NUM; - let addr_next = input[2].address; // addr / CHUNK_NUM + CHUNK_NUM; - - // Get the value - let value = input[0].value.to_be_bytes(); - let value_first_read = input[1].value.to_be_bytes(); - let value_first_write = input[2].value.to_be_bytes(); - let value_second_read = input[3].value.to_be_bytes(); - let value_second_write = input[4].value.to_be_bytes(); - - // Get the step - let step = input[0].step; - let step_first_read = input[1].step; - let step_first_write = input[2].step; - let step_second_read = input[3].step; - let step_second_write = input[4].step; - - // Get the offset - let offset = addr % CHUNK_NUM_U64; - let offset = if offset <= usize::MAX as u64 { - offset as usize - } else { - panic!("Invalid offset={}", offset); - }; + ) -> Vec> { + // Get the unaligned address + let addr = unaligned_input.address; + + // Get the unaligned value + let value = unaligned_input.value.to_be_bytes(); - // Get the width - let width = input[0].width; + // Get the unaligned step + let step = unaligned_input.step; + + // Get the unaligned width + let width = unaligned_input.width; let width = if width <= CHUNK_NUM_U64 { width as usize } else { panic!("Invalid width={}", width); }; + // Compute the offset + let offset = addr % CHUNK_NUM_U64; + let offset = if offset <= usize::MAX as u64 { + offset as usize + } else { + panic!("Invalid offset={}", offset); + }; + // Compute the shift let shift = (offset + width) % CHUNK_NUM; // Get the op to be 
executed, its size and the pc to jump to - let op = Self::get_mem_op(&input[0]); + let op = Self::get_mem_op(&unaligned_input); let op_size = MemAlignRomSM::::get_mem_align_op_size(op); let next_pc = MemAlignRomSM::::calculate_next_pc(op, offset, width); // Initialize and set the rows of the corresponding op - let mut rows: Vec> = Vec::with_capacity(op_size); + let mut rows: Vec> = Vec::with_capacity(op_size); // TODO: Can I detatch the "shape" of the program from the mem_align and do it in the mem_align_rom? match op { MemOp::OneRead => { // RV - let mut read_row = MemAlign3Row:: { - step: F::from_canonical_u64(step_first_read), - addr: F::from_canonical_u64(addr_prior), + // Sanity check + assert!(aligned_inputs.len() == 1); + + // Get the aligned address + let addr_read = aligned_inputs[0].address; // addr / CHUNK_NUM; + + // Get the aligned values + let value_read = aligned_inputs[0].value.to_be_bytes(); + + // Get the aligned step + let step_read = aligned_inputs[0].step; + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step_read), + addr: F::from_canonical_u64(addr_read), // offset: F::from_canonical_u64(0), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), @@ -176,7 +301,7 @@ impl MemAlignSM { ..Default::default() }; - let mut value_row = MemAlign3Row:: { + let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr), offset: F::from_canonical_usize(offset), @@ -189,7 +314,7 @@ impl MemAlignSM { }; for i in 0..CHUNK_NUM { - read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); + read_row.reg[i] = F::from_canonical_u8(value_read[i]); read_row.sel[i] = F::from_bool(true); value_row.reg[i] = F::from_canonical_u8(value[shift + i]); @@ -206,9 +331,24 @@ impl MemAlignSM { } MemOp::OneWrite => { // RWV - let mut read_row = MemAlign3Row:: { - step: F::from_canonical_u64(step_first_read), - addr: F::from_canonical_u64(addr_prior), + // Sanity check + assert!(aligned_inputs.len() == 2); 
+ + // Get the aligned address + let addr_read_write = aligned_inputs[0].address; // addr / CHUNK_NUM; + + // Get the aligned values + let value_read = aligned_inputs[0].value.to_be_bytes(); + let value_write = aligned_inputs[1].value.to_be_bytes(); + + // Get the aligned step + let step_read = aligned_inputs[0].step; + let step_write = aligned_inputs[1].step; + + // RWV + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step_read), + addr: F::from_canonical_u64(addr_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -218,9 +358,9 @@ impl MemAlignSM { ..Default::default() }; - let mut write_row = MemAlign3Row:: { - step: F::from_canonical_u64(step_first_write), - addr: F::from_canonical_u64(addr_prior), + let mut write_row = MemAlignRow:: { + step: F::from_canonical_u64(step_write), + addr: F::from_canonical_u64(addr_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), @@ -230,7 +370,7 @@ impl MemAlignSM { ..Default::default() }; - let mut value_row = MemAlign3Row:: { + let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr), offset: F::from_canonical_usize(offset), @@ -243,10 +383,10 @@ impl MemAlignSM { }; for i in 0..CHUNK_NUM { - read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); + read_row.reg[i] = F::from_canonical_u8(value_read[i]); read_row.sel[i] = F::from_bool(i < offset); - write_row.reg[i] = F::from_canonical_u8(value_first_write[i]); + write_row.reg[i] = F::from_canonical_u8(value_write[i]); write_row.sel[i] = F::from_bool(i >= offset); value_row.reg[i] = F::from_canonical_u8(value[shift + i]); @@ -265,9 +405,25 @@ impl MemAlignSM { } MemOp::TwoReads => { // RVR - let mut first_read_row = MemAlign3Row:: { + // Sanity check + assert!(aligned_inputs.len() == 2); + + // Get the aligned address + let addr_first_read = aligned_inputs[0].address; 
// addr / CHUNK_NUM; + let addr_second_read = aligned_inputs[1].address; // addr / CHUNK_NUM + CHUNK_NUM; + + // Get the aligned values + let value_first_read = aligned_inputs[0].value.to_be_bytes(); + let value_second_read = aligned_inputs[1].value.to_be_bytes(); + + // Get the aligned step + let step_first_read = aligned_inputs[0].step; + let step_second_read = aligned_inputs[1].step; + + // RVR + let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step_first_read), - addr: F::from_canonical_u64(addr_prior), + addr: F::from_canonical_u64(addr_first_read), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -277,7 +433,7 @@ impl MemAlignSM { ..Default::default() }; - let mut value_row = MemAlign3Row:: { + let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr), offset: F::from_canonical_usize(offset), @@ -289,9 +445,9 @@ impl MemAlignSM { ..Default::default() }; - let mut second_read_row = MemAlign3Row:: { + let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step_second_read), - addr: F::from_canonical_u64(addr_next), + addr: F::from_canonical_u64(addr_second_read), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -324,9 +480,29 @@ impl MemAlignSM { } MemOp::TwoWrites => { // RWVWR - let mut first_read_row = MemAlign3Row:: { + // Sanity check + assert!(aligned_inputs.len() == 4); + + // Get the aligned address + let addr_first_read_write = aligned_inputs[0].address; // addr / CHUNK_NUM; + let addr_second_read_write = aligned_inputs[2].address; // addr / CHUNK_NUM + CHUNK_NUM; + + // Get the aligned values + let value_first_read = aligned_inputs[0].value.to_be_bytes(); + let value_first_write = aligned_inputs[1].value.to_be_bytes(); + let value_second_read = aligned_inputs[2].value.to_be_bytes(); + let value_second_write = 
aligned_inputs[3].value.to_be_bytes(); + + // Get the aligned step + let step_first_read = aligned_inputs[0].step; + let step_first_write = aligned_inputs[1].step; + let step_second_read = aligned_inputs[2].step; + let step_second_write = aligned_inputs[3].step; + + // RWVWR + let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step_first_read), - addr: F::from_canonical_u64(addr_prior), + addr: F::from_canonical_u64(addr_first_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -336,9 +512,9 @@ impl MemAlignSM { ..Default::default() }; - let mut first_write_row = MemAlign3Row:: { + let mut first_write_row = MemAlignRow:: { step: F::from_canonical_u64(step_first_write), - addr: F::from_canonical_u64(addr_prior), + addr: F::from_canonical_u64(addr_first_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), @@ -348,7 +524,7 @@ impl MemAlignSM { ..Default::default() }; - let mut value_row = MemAlign3Row:: { + let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr), offset: F::from_canonical_usize(offset), @@ -360,9 +536,9 @@ impl MemAlignSM { ..Default::default() }; - let mut second_write_row = MemAlign3Row:: { + let mut second_write_row = MemAlignRow:: { step: F::from_canonical_u64(step_second_write), - addr: F::from_canonical_u64(addr_next), + addr: F::from_canonical_u64(addr_second_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), @@ -372,9 +548,9 @@ impl MemAlignSM { ..Default::default() }; - let mut second_read_row = MemAlign3Row:: { + let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step_second_read), - addr: F::from_canonical_u64(addr_next), + addr: F::from_canonical_u64(addr_second_read_write), // offset: F::from_canonical_u64(0), width: 
F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -423,134 +599,6 @@ impl MemAlignSM { // Return successfully rows } - - pub fn prove_instance( - &self, - inputs: Vec, - prover_buffer: &mut [F], - offset: u64, - ) { - Self::prove_internal( - &self.wcm, - &self.mem_align_rom_sm, - &self.std, - inputs, - prover_buffer, - offset, - ); - } - - fn prove_internal( - wcm: &WitnessManager, - mem_align_rom_sm: &MemAlignRomSM, - std: &Std, - inputs: Vec, - prover_buffer: &mut [F], - offset: u64, - ) { - let pctx = wcm.get_pctx(); - - let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - assert!(inputs.len() <= air_mem_align.num_rows()); - - info!( - "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", - Self::MY_NAME, - inputs.len(), - air_mem_align.num_rows(), - inputs.len() as f64 / air_mem_align.num_rows() as f64 * 100.0 - ); - - let mut reg_range_check: HashMap = HashMap::new(); - let mut trace_buffer = MemAlign3Trace::::map_buffer( - prover_buffer, - air_mem_align.num_rows(), - offset as usize, - ) - .unwrap(); - - // Process the inputs while saving the multiplcities and range checks - let mut rows_processed = 0; - let rows = Self::process_slice(&inputs, mem_align_rom_sm, &mut reg_range_check); - for (i, &row) in rows.iter().enumerate() { - trace_buffer[rows_processed + i] = row; - } - rows_processed += rows.len(); - - // Pad the remaining rows with trivailly satisfying rows - let padding_row = MemAlign3Row::::default(); - - for i in rows_processed..air_mem_align.num_rows() { - trace_buffer[i] = padding_row; - } - - // TODO: Store the padding multiplicity - let _padding_size = air_mem_align.num_rows() - rows_processed; - // for i in 0..8 { - // let multiplicity = padding_size as u64; - // let row = MemAlignRomSM::::calculate_rom_row( - // op, offset, width - // ); - // rom_multiplicity[row as usize] += multiplicity; - // } - - // Perform the range checks - let range_id = std.get_range(BigInt::from(0), 
BigInt::from((1 << CHUNK_BITS) - 1), None); - for (&value, &multiplicity) in reg_range_check.iter() { - std.range_check(value, F::from_canonical_u64(multiplicity), range_id); - } - - // std::thread::spawn(move || { - // drop(inputs); - // drop(reg_range_check); - // }); - } } impl WitnessComponent for MemAlignSM {} - -impl Provable for MemAlignSM { - fn prove(&self, operations: &[ZiskRequiredMemory], drain: bool, _scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - let pctx = self.wcm.get_pctx(); - let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - - while inputs.len() >= air_mem_align.num_rows() || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(air_mem_align.num_rows(), inputs.len()); - let drained_inputs = inputs.drain(..num_drained).collect::>(); - - let mem_align_rom_sm = self.mem_align_rom_sm.clone(); - let wcm = self.wcm.clone(); - let std = self.std.clone(); - let sctx = self.wcm.get_sctx().clone(); - - let (mut prover_buffer, offset) = create_prover_buffer( - &wcm.get_ectx(), - &wcm.get_sctx(), - ZISK_AIRGROUP_ID, - MEM_ALIGN_AIR_IDS[0], - ); - - Self::prove_internal( - &wcm, - &mem_align_rom_sm, - &std, - drained_inputs, - &mut prover_buffer, - offset, - ); - - let air_instance = AirInstance::new( - sctx, - ZISK_AIRGROUP_ID, - MEM_ALIGN_AIR_IDS[0], - None, - prover_buffer, - ); - wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, None); - } - } - } -} diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 5eb6e7a2..73c63176 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -3,33 +3,28 @@ use std::sync::{ Arc, }; -use crate::{MemAlignSM, MemSM}; +use crate::{MemAlignRomSM, MemAlignSM, MemOp, MemSM}; use p3_field::PrimeField; +use pil_std_lib::Std; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use zisk_core::{ZiskRequiredMemory, RAM_ADDR, 
SYS_ADDR}; use proofman::{WitnessComponent, WitnessManager}; -pub enum MemOps { - OneRead, - OneWrite, - TwoReads, - TwoWrites, -} - pub struct MemProxy { // Count of registered predecessors registered_predecessors: AtomicU32, // Secondary State machines mem_sm: Arc>, - mem_align_sm: Arc, + mem_align_sm: Arc>, } impl MemProxy { - pub fn new(wcm: Arc>) -> Arc { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { let mem_sm = MemSM::new(wcm.clone()); - let mem_align_sm = MemAlignSM::new(wcm.clone()); + let mem_align_rom_sm = MemAlignRomSM::new(wcm.clone()); + let mem_align_sm = MemAlignSM::new(wcm.clone(), std, mem_align_rom_sm); let mem_proxy = Self { registered_predecessors: AtomicU32::new(0), @@ -40,7 +35,7 @@ impl MemProxy { wcm.register_component(mem_proxy.clone(), None, None); - // For all the secondary state machines, register the main state machine as a predecessor + // For all the secondary state machines, register the mem_proxy as a predecessor mem_sm.register_predecessor(); mem_align_sm.register_predecessor(); @@ -54,7 +49,7 @@ impl MemProxy { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { self.mem_sm.unregister_predecessor(); - // self.mem_align_sm.unregister_predecessor::(); + self.mem_align_sm.unregister_predecessor(); } } @@ -73,15 +68,16 @@ impl MemProxy { // Step 2. 
For each non-aligned memory access non_aligned.iter().for_each(|unaligned_access| { - let mem_ops = Self::get_mem_ops(unaligned_access); + // Step 2.1 Ask to the Mem Align SM for the aligned memory accesses generated by the non-aligned one + let mem_op = MemAlignSM::::get_mem_op(unaligned_access); - // Step 2.1 Find the possible aligned memory access - let aligned_accesses = self.get_aligned_accesses(&unaligned_access, mem_ops, &aligned); + // Step 2.2 Ask to the Mem SM for the aligned memory accesses + let aligned_accesses = self.get_aligned_accesses(&unaligned_access, mem_op, &aligned); - // Step 2.2 Align memory access using mem_align state machine - // self.mem_aligned_sm.align_mem_accesses(potential_aligned_mem, mem, &mut new_aligned); + // Step 2.3 Carried with the aligned memory accesses, prove the non-aligned ones + self.mem_align_sm.prove(unaligned_access, &aligned_accesses); - // Step 2.3 Store the new aligned memory access(es) + // Step 2.4 Store the new aligned memory access(es) new_aligned.extend(aligned_accesses); }); @@ -109,13 +105,13 @@ impl MemProxy { fn get_aligned_accesses( &self, unaligned_access: &ZiskRequiredMemory, - mem_ops: MemOps, + mem_op: MemOp, aligned_accesses: &[ZiskRequiredMemory], ) -> Vec { // Align down to a 8 byte addres let addr = unaligned_access.address & !7; - match mem_ops { - MemOps::OneRead => { + match mem_op { + MemOp::OneRead => { // Look for last write to the same address let last_write_addr = Self::get_last_write(addr, unaligned_access.step, aligned_accesses); @@ -128,7 +124,7 @@ impl MemProxy { }); vec![last_write_addr] } - MemOps::OneWrite => { + MemOp::OneWrite => { // Look for last write to the same address let last_write_addr = Self::get_last_write(addr, unaligned_access.step, aligned_accesses); @@ -145,7 +141,7 @@ impl MemProxy { Self::write_value(&unaligned_access, &mut last_write_addr); vec![last_write_addr] } - MemOps::TwoReads => { + MemOp::TwoReads => { // Look for last write to the same address and same 
address + 8 let last_write_addr = Self::get_last_write(addr, unaligned_access.step, aligned_accesses); @@ -170,7 +166,7 @@ impl MemProxy { vec![last_write_addr, last_write_addr_p] } - MemOps::TwoWrites => { + MemOp::TwoWrites => { // Look for last write to the same address and same address + 8 let last_write_addr = Self::get_last_write(addr, unaligned_access.step, aligned_accesses); @@ -274,19 +270,6 @@ impl MemProxy { }; Self::write_value(&right_memory, aligned_next); } - - #[inline(always)] - pub fn get_mem_ops(input: &ZiskRequiredMemory) -> MemOps { - let addr = input.address; - let width = input.width; - let offset = addr & 7; - match (input.is_write, offset + width > 8) { - (false, false) => MemOps::OneRead, - (true, false) => MemOps::OneWrite, - (false, true) => MemOps::TwoReads, - (true, true) => MemOps::TwoWrites, - } - } } impl WitnessComponent for MemProxy {} diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index a45bb684..7561e08e 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -195,8 +195,8 @@ impl MemSM { let addr_changes = trace[i - 1].addr != trace[i].addr; trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; - let same_value = trace[i - 1].value[0] == trace[i].value[0] && - trace[i - 1].value[1] == trace[i].value[1]; + let same_value = trace[i - 1].value[0] == trace[i].value[0] + && trace[i - 1].value[1] == trace[i].value[1]; trace[i].same_value = if same_value { F::one() } else { F::zero() }; let first_addr_access_is_read = addr_changes && !mem_op.is_write; diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 6d928e75..cd289cfb 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -51,7 +51,7 @@ impl ZiskExecutor { let std = Std::new(wcm.clone()); let rom_sm = RomSM::new(wcm.clone()); - let mem_proxy = MemProxy::new(wcm.clone()); + let mem_proxy = MemProxy::new(wcm.clone(), 
std.clone()); let binary_sm = BinarySM::new(wcm.clone(), std.clone()); let arith_sm = ArithSM::new(wcm.clone()); From 6094b7978b9c2728885b0314f938a0efd251361e Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Thu, 7 Nov 2024 06:36:25 +0000 Subject: [PATCH 12/44] wip --- core/src/zisk_required_operation.rs | 2 +- state-machines/mem/pil/mem_align.pil | 32 ++-- state-machines/mem/pil/mem_align_rom.pil | 20 +-- state-machines/mem/src/mem_proxy.rs | 183 +++++++++++++++++++---- state-machines/mem/src/mem_sm.rs | 2 + witness-computation/src/executor.rs | 4 +- 6 files changed, 183 insertions(+), 60 deletions(-) diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index 59a7aee6..1ccef475 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -8,7 +8,7 @@ pub struct ZiskRequiredOperation { pub b: u64, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ZiskRequiredMemory { pub step: u64, pub is_write: bool, diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index 184b86a2..2293ba35 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -87,10 +87,10 @@ require "std_range_check.pil" Notice that it is enough with 8 combinations. 
*/ -airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES = 8, const int CHUNK_BITS = 8) { - const int MEM_HALF_BYTES = MEM_BYTES / 2; +airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM = 8, const int CHUNK_BITS = 8) { + const int CHUNK_NUM_HALF = CHUNK_NUM / 2; - col witness addr; // MEM_BYTES-byte address, real address = addr * MEM_BYTES + col witness addr; // CHUNK_NUM-byte address, real address = addr * CHUNK_NUM col witness offset; // 0..7, position at which the operation starts col witness width; // 1,2,4,8, width of the operation col witness wr; // 1 if the operation is a write, 0 otherwise @@ -98,8 +98,8 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES col witness reset; // 1 at the beginning of the operation (indicating an address reset), 0 otherwise col witness sel_up_to_down; // 1 if the next value is the current value (e.g. R -> W) col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R) - col witness reg[MEM_BYTES]; // Register values, 1 byte each - col witness sel[MEM_BYTES]; // Selectors, 1 if the value is used, 0 otherwise + col witness reg[CHUNK_NUM]; // Register values, 1 byte each + col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise // 1] Ensure the MemAlign follows the program @@ -107,7 +107,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES // - reg' == reg in transitions R -> V, R -> W, W -> V, // - 'reg == reg in transitions V <- W, W <- R, // in any case, sel_up_to_down,sel_down_to_up are 0 in [V] steps. 
- for (int i = 0; i < MEM_BYTES; i++) { + for (int i = 0; i < CHUNK_NUM; i++) { range_check(reg[i], 0, 2**CHUNK_BITS-1); (reg[i]' - reg[i]) * sel[i] * sel_up_to_down === 0; @@ -118,7 +118,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES L1 * pc === 0; // The program should start at the first line // We compress selectors, so we should ensure they are binary - for (int i = 0; i < MEM_BYTES; i++) { + for (int i = 0; i < CHUNK_NUM; i++) { sel[i] * (1 - sel[i]) === 0; } wr * (1 - wr) === 0; @@ -127,10 +127,10 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES sel_down_to_up * (1 - sel_down_to_up) === 0; expr flags = 0; - for (int i = 0; i < MEM_BYTES; i++) { + for (int i = 0; i < CHUNK_NUM; i++) { flags += sel[i] * 2**i; } - flags += wr * 2**MEM_BYTES + reset * 2**(MEM_BYTES + 1) + sel_up_to_down * 2**(MEM_BYTES + 2) + sel_down_to_up * 2**(MEM_BYTES + 3); + flags += wr * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); lookup_assumes(MEM_ALIGN_ROM_ID, [pc, pc'-pc, (addr-'addr)*(1-reset), offset, width, flags]); @@ -144,8 +144,8 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES expr assume_val[RC]; for (int i = 0; i < RC; i++) { assume_val[i] = 0; - for (int j = 0; j < MEM_HALF_BYTES; j++) { - assume_val[i] += reg[j + i * MEM_HALF_BYTES] * 2**j; + for (int j = 0; j < CHUNK_NUM_HALF; j++) { + assume_val[i] += reg[j + i * CHUNK_NUM_HALF] * 2**j; } } @@ -158,16 +158,16 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int MEM_BYTES expr prove_val[RC]; for (int i = 0; i < RC; i++) { prove_val[i] = 0; - for (int j = 0; j < MEM_HALF_BYTES; j++) { + for (int j = 0; j < CHUNK_NUM_HALF; j++) { expr _prove_val = 0; - for (int k = j; k < j + MEM_HALF_BYTES; k++) { - _prove_val += reg[(k + i * MEM_HALF_BYTES) % MEM_BYTES] * 2**(k-j); + for (int k = j; k < j + CHUNK_NUM_HALF; k++) { + _prove_val 
+= reg[(k + i * CHUNK_NUM_HALF) % CHUNK_NUM] * 2**(k-j); } - prove_val[i] += sel[j + i * MEM_HALF_BYTES] * _prove_val; + prove_val[i] += sel[j + i * CHUNK_NUM_HALF] * _prove_val; } } // We prove and assume with the same permutation check but with disjoint and different sign selectors col witness step; - permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * MEM_BYTES + offset, step, width, ...prove_val], sel: sel_prove - sel_assume); + permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val], sel: sel_prove - sel_assume); } \ No newline at end of file diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil index 211da1dd..863959a5 100644 --- a/state-machines/mem/pil/mem_align_rom.pil +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -13,7 +13,7 @@ const int MEM_ALIGN_ROM_SIZE = P2_8; // Note1: The offset and width are sufficient to group programs with the same number of operations. // Note2: The first instruction is always a read. 
-airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = 8, const int disable_fixed = 0) { +airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = 8, const int disable_fixed = 0) { if (N < MEM_ALIGN_ROM_SIZE) { error(`N must be at least ${MEM_ALIGN_ROM_SIZE}, but N=${N} was provided`); } @@ -37,7 +37,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3]]...; // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 col fixed WIDTH = [[[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV - [[8,8,1,8,8,2,8,8,4], [8,8,1,8,8,2,8,8,4]:4, [8,8,1,8,8,2]:2, [8,8,1]], // RWV + [[8,8,1,8,8,2,8,8,4]:5, [8,8,1,8,8,2]:2, [8,8,1]], // RWV [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]]]...; // RWVWR @@ -83,8 +83,8 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = int delta_addr = 0; int is_write = 0; int reset = 0; - int sel[MEM_BYTES]; - for (int j = 0; j < MEM_BYTES; j++) { + int sel[CHUNK_NUM]; + for (int j = 0; j < CHUNK_NUM; j++) { sel[j] = 0; } int sel_up_to_down = 0; @@ -109,7 +109,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = delta_addr = 1; // is_write = 0; // reset = 0; - // sel = [0:MEM_BYTES] + // sel = [0:CHUNK_NUM] // sel_up_to_down = 0; // sel_down_to_up = 0; } @@ -140,7 +140,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = // delta_addr = 0; // is_write = 0; // reset = 0; - // sel = [0:MEM_BYTES] + // sel = [0:CHUNK_NUM] // sel_up_to_down = 0; // sel_down_to_up = 0; } @@ -162,7 +162,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = delta_addr = 1; // is_write = 0; // reset = 0; - // sel = [0:MEM_BYTES] + // sel = [0:CHUNK_NUM] // sel_up_to_down = 1; // sel_down_to_up = 0; } else { // R @@ 
-202,7 +202,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = delta_addr = 1; // is_write = 0; // reset = 0; - // sel = [0:MEM_BYTES] + // sel = [0:CHUNK_NUM] // sel_up_to_down = 0; // sel_down_to_up = 0; } else if (next % 5 == 3) { // W @@ -230,10 +230,10 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int MEM_BYTES = DELTA_PC[i] = delta_pc; DELTA_ADDR[i] = delta_addr; FLAGS[i] = 0; - for (int j = 0; j < MEM_BYTES; j++) { + for (int j = 0; j < CHUNK_NUM; j++) { FLAGS[i] += sel[j] * 2**j; } - FLAGS[i] += is_write * 2**MEM_BYTES + reset * 2**(MEM_BYTES + 1) + sel_up_to_down * 2**(MEM_BYTES + 2) + sel_down_to_up * 2**(MEM_BYTES + 3); + FLAGS[i] += is_write * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); } lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 5eb6e7a2..31c0d401 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -10,6 +10,7 @@ use zisk_core::{ZiskRequiredMemory, RAM_ADDR, SYS_ADDR}; use proofman::{WitnessComponent, WitnessManager}; +#[derive(Debug, Clone)] pub enum MemOps { OneRead, OneWrite, @@ -60,29 +61,64 @@ impl MemProxy { pub fn prove( &self, - mut operations: [Vec; 2], + mut operations: &mut [Vec; 2], ) -> Result<(), Box> { let mut aligned = std::mem::take(&mut operations[0]); - let non_aligned = std::mem::take(&mut operations[1]); + let unaligned = std::mem::take(&mut operations[1]); let mut new_aligned = Vec::new(); + //trace[63927]: MemRow { addr: 2685533720, step: 5145, sel: 1, wr: 0, value: [2685534552, 0], addr_changes: 0, same_value: 0, first_addr_access_is_read: 0 } + println!("-----------------"); + println!("-- Aligned inputs:"); + for i in 0..aligned.len() { + if aligned[i].address == 2685534096 { + println!("aligned[{}]: {:?} value: 
{:x}", i, aligned[i], aligned[i].value); + } + } + println!("-- Unaligned inputs:"); + for i in 0..unaligned.len() { + if unaligned[i].address >= (2685534096 - 8) && unaligned[i].address <= (2685534096 + 8) { + println!("unaligned[{}]: {:?} value: {:x}", i, unaligned[i], unaligned[i].value); + } + } + println!("-----------------"); + // Step 1. Sort the aligned memory accesses timer_start_debug!(MEM_SORT); aligned.sort_by_key(|mem| mem.address); timer_stop_and_log_debug!(MEM_SORT); - // Step 2. For each non-aligned memory access - non_aligned.iter().for_each(|unaligned_access| { + // Step 2. For each unaligned memory access + unaligned.iter().for_each(|unaligned_access| { let mem_ops = Self::get_mem_ops(unaligned_access); // Step 2.1 Find the possible aligned memory access - let aligned_accesses = self.get_aligned_accesses(&unaligned_access, mem_ops, &aligned); + // TODO! Remove mem_ops.clone() + let aligned_accesses = self.get_aligned_accesses( + &unaligned_access, + mem_ops.clone(), + &aligned, + &new_aligned, + ); // Step 2.2 Align memory access using mem_align state machine - // self.mem_aligned_sm.align_mem_accesses(potential_aligned_mem, mem, &mut new_aligned); + // self.mem_align_sm.prove(&aligned_accesses, unaligned_access); + + for access in new_aligned.iter() { + if access.step == 4682 { + println!("new_aligned: {:?}", access); + } + } // Step 2.3 Store the new aligned memory access(es) + if unaligned_access.step == 5145 { + println!("*** mem_ops: {:?}", mem_ops); + println!("*** unaligned_access: {:?}", unaligned_access); + println!("*** aligned_accesses: {:?}", aligned_accesses); + } + new_aligned.extend(aligned_accesses); + new_aligned.sort_by_key(|mem| mem.address); }); // Step 3. 
Concatenate the new aligned memory accesses with the original aligned memory @@ -90,15 +126,26 @@ impl MemProxy { aligned.extend(new_aligned); timer_start_debug!(MEM_SORT_2); - aligned.sort_by_key(|mem| mem.address); + aligned.sort_by_key(|mem| (mem.address, mem.step)); timer_stop_and_log_debug!(MEM_SORT_2); let mut idx = 0; while aligned[idx].address < RAM_ADDR && idx < aligned.len() { idx += 1; } + + println!("Aligned len(): {:?}", aligned.len()); + let (_input_aligned, aligned) = aligned.split_at_mut(idx); + // Filter where address = 2684391184 + println!(""); + for i in 0.. aligned.len() { + if aligned[i].address == 2685534096 { + println!("OJO!!!! mem: {:?}", aligned[i]); + } + } + // Step 4. Prove the aligned memory accesses using mem state machine self.mem_sm.prove(aligned); @@ -111,30 +158,42 @@ impl MemProxy { unaligned_access: &ZiskRequiredMemory, mem_ops: MemOps, aligned_accesses: &[ZiskRequiredMemory], + new_aligned_accesses: &[ZiskRequiredMemory], ) -> Vec { // Align down to a 8 byte addres let addr = unaligned_access.address & !7; match mem_ops { MemOps::OneRead => { // Look for last write to the same address - let last_write_addr = - Self::get_last_write(addr, unaligned_access.step, aligned_accesses); - let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + let last_write_addr = Self::get_last_write( + addr, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); + let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: false, address: addr, width: 8, value: 0, }); + + last_write_addr.step = unaligned_access.step; + vec![last_write_addr] } MemOps::OneWrite => { // Look for last write to the same address - let last_write_addr = - Self::get_last_write(addr, unaligned_access.step, aligned_accesses); + let last_write_addr = Self::get_last_write( + addr, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); // Modify the value of the 
write to the same address - let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: true, address: addr, @@ -142,17 +201,32 @@ impl MemProxy { value: 0, }); - Self::write_value(&unaligned_access, &mut last_write_addr); - vec![last_write_addr] + let mut last_write_addr_r = last_write_addr.clone(); + last_write_addr_r.step = unaligned_access.step; + last_write_addr_r.is_write = false; + + let mut last_write_addr_w = last_write_addr; + last_write_addr_w.step = unaligned_access.step; + Self::write_value(&unaligned_access, &mut last_write_addr_w); + + vec![last_write_addr_r, last_write_addr_w] } MemOps::TwoReads => { // Look for last write to the same address and same address + 8 - let last_write_addr = - Self::get_last_write(addr, unaligned_access.step, aligned_accesses); - let last_write_addr_p = - Self::get_last_write(addr + 8, unaligned_access.step, aligned_accesses); + let last_write_addr = Self::get_last_write( + addr, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); + let last_write_addr_p = Self::get_last_write( + addr + 8, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); - let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: false, address: addr, @@ -160,7 +234,7 @@ impl MemProxy { value: 0, }); - let last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { + let mut last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: false, address: addr + 8, @@ -168,33 +242,61 @@ impl MemProxy { value: 0, }); + last_write_addr.step = unaligned_access.step; + last_write_addr_p.step = unaligned_access.step; + vec![last_write_addr, last_write_addr_p] } MemOps::TwoWrites => { // Look for last write 
to the same address and same address + 8 - let last_write_addr = - Self::get_last_write(addr, unaligned_access.step, aligned_accesses); - let last_write_addr_p = - Self::get_last_write(addr + 8, unaligned_access.step, aligned_accesses); + let last_write_addr = Self::get_last_write( + addr, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); + let last_write_addr_p = Self::get_last_write( + addr + 8, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); - let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + // Modify the value of the write to the same address + let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: true, address: addr, width: 8, - value: 1, + value: 0, }); - let mut last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { + let mut last_write_addr_r = last_write_addr.clone(); + last_write_addr_r.step = unaligned_access.step; + last_write_addr_r.is_write = false; + + let mut last_write_addr_w = last_write_addr; + last_write_addr_w.step = unaligned_access.step; + Self::write_value(&unaligned_access, &mut last_write_addr_w); + + let last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: true, address: addr + 8, width: 8, - value: 1, + value: 0, }); - Self::write_values(&unaligned_access, &mut last_write_addr, &mut last_write_addr_p); - vec![last_write_addr, last_write_addr_p] + let mut last_write_addr_p_r = last_write_addr_p.clone(); + last_write_addr_p_r.step = unaligned_access.step; + last_write_addr_p_r.is_write = false; + + let mut last_write_addr_p_w = last_write_addr_p; + last_write_addr_p_w.step = unaligned_access.step; + Self::write_value(&unaligned_access, &mut last_write_addr_p_w); + + Self::write_values(&unaligned_access, &mut last_write_addr_w, &mut last_write_addr_p_w); + vec![last_write_addr_r, last_write_addr_w, last_write_addr_p_r, 
last_write_addr_p_w] } } } @@ -204,6 +306,7 @@ impl MemProxy { addr: u64, step: u64, aligned_accesses: &[ZiskRequiredMemory], + new_aligned_accesses: Option<&[ZiskRequiredMemory]>, ) -> Option { // Step 1: Find the start of the range for `addr` let start_index = @@ -232,6 +335,24 @@ impl MemProxy { } } + // Step 3: If `new_aligned_accesses` exists, check for a more recent write + if let None = new_aligned_accesses { + return last_write; + } + + let new_aligned_accesses = new_aligned_accesses.unwrap(); + let last_new_write = Self::get_last_write(addr, step, new_aligned_accesses, None); + + if let None = last_write { + return last_new_write; + } + + if let Some(last_new_write) = last_new_write { + if last_new_write.step > last_write.as_ref().unwrap().step { + return Some(last_new_write); + } + } + last_write } diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index a45bb684..89621c6e 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -231,6 +231,8 @@ impl MemSM { trace[i].first_addr_access_is_read = F::zero(); } + println!("trace[66094]: {:?}", trace[66094]); + let mut air_instance = AirInstance::new( self.wcm.get_sctx(), ZISK_AIRGROUP_ID, diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 6d928e75..e50bf736 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -187,7 +187,7 @@ impl ZiskExecutor { // STEP 2. Wait until all inputs are generated // ============================================== // Join all the threads to synchronize the execution - let mem_required = mem_thread.join().expect("Error during Memory witness computation"); + let mut mem_required = mem_thread.join().expect("Error during Memory witness computation"); let rom_required = rom_thread.join().expect("Error during ROM witness computation"); // STEP 3. 
Generate AIRs and Prove @@ -197,7 +197,7 @@ impl ZiskExecutor { // ---------------------------------------------- let mem_thread = thread::spawn({ let mem_proxy = self.mem_proxy.clone(); - move || mem_proxy.prove(mem_required).expect("Error during Memory witness computation") + move || mem_proxy.prove(&mut mem_required).expect("Error during Memory witness computation") }); // ROM State Machine From 5551def4131c1469207c6bd3fce4ebe5382782f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Thu, 7 Nov 2024 07:34:59 +0000 Subject: [PATCH 13/44] Fixing bugs --- state-machines/main/src/main_sm.rs | 8 +- state-machines/mem/pil/mem_align.pil | 10 +- state-machines/mem/src/mem_align_rom_sm.rs | 5 +- state-machines/mem/src/mem_align_sm.rs | 35 ++-- state-machines/mem/src/mem_proxy.rs | 197 +++++++++++++++++---- witness-computation/src/executor.rs | 18 +- 6 files changed, 201 insertions(+), 72 deletions(-) diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index d9532d81..ca3d5c0a 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -1,5 +1,6 @@ use log::info; use p3_field::PrimeField; +use sm_mem::MemProxy; use crate::InstanceExtensionCtx; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; @@ -26,6 +27,9 @@ pub struct MainSM { /// Witness computation manager wcm: Arc>, + /// Memory state machine + mem_proxy_sm: Arc>, + /// Arithmetic state machine arith_sm: Arc, @@ -49,14 +53,16 @@ impl MainSM { /// * Arc to the MainSM state machine pub fn new( wcm: Arc>, + mem_proxy_sm: Arc>, arith_sm: Arc, binary_sm: Arc>, ) -> Arc { - let main_sm = Arc::new(Self { wcm: wcm.clone(), arith_sm, binary_sm }); + let main_sm = Arc::new(Self { wcm: wcm.clone(), mem_proxy_sm, arith_sm, binary_sm }); wcm.register_component(main_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MAIN_AIR_IDS)); // For all the secondary state machines, register the main state machine as a predecessor + 
main_sm.mem_proxy_sm.register_predecessor(); main_sm.binary_sm.register_predecessor(); main_sm.arith_sm.register_predecessor(); diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index 2293ba35..f5405f0d 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -62,23 +62,23 @@ require "std_range_check.pil" ========================================================== (offset = 6, width = 4) +----+----+----+----+----+----+----+----+ - | R7 | R6 | R5 | R4 | R3 | R2 | R1 | R0 | [R1] (assume, up_to_down) sel = [1,1,1,1,1,1,0,0] + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | [R1] (assume, up_to_down) sel = [1,1,1,1,1,1,0,0] +----+----+----+----+----+----+----+----+ ⇓ +----+----+----+----+----+----+====+====+ - | W7 | W6 | W5 | W4 | W3 | W2 | W1 | W0 | [W1] (assume, up_to_down) sel = [0,0,0,0,0,0,1,1] + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | [W1] (assume, up_to_down) sel = [0,0,0,0,0,0,1,1] +----+----+----+----+----+----+====+====+ ⇓ +====+====+----+----+----+----+====+====+ - | V1 | V0 | V7 | V6 | V5 | V4 | V3 | V2 | [V] (prove) (shift (offset + width) % 8) sel = [0,0,0,0,0,0,1,0] (*) + | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | [V] (prove) (shift (offset + width) % 8) sel = [0,0,0,0,0,0,1,0] (*) +====+====+----+----+----+----+====+====+ ⇓ +====+====+----+----+----+----+----+----+ - | W7 | W6 | W5 | W4 | W3 | W2 | W1 | W0 | [W2] (assume, down_to_up) sel = [1,1,0,0,0,0,0,0] + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | [W2] (assume, down_to_up) sel = [1,1,0,0,0,0,0,0] +====+====+----+----+----+----+----+----+ ⇓ +----+----+----+----+----+----+----+----+ - | R7 | R6 | R5 | R4 | R3 | R2 | R1 | R0 | [R2] (assume, down_to_up) sel = [0,0,1,1,1,1,1,1] + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | [R2] (assume, down_to_up) sel = [0,0,1,1,1,1,1,1] +----+----+----+----+----+----+----+----+ (*) In this step, we use the selectors to indicate the "scanning" needed to form the bus value: diff --git 
a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 702eb724..5a0e63b1 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -215,10 +215,7 @@ impl MemAlignRomSM { None, prover_buffer, ); - self.wcm - .get_pctx() - .air_instance_repo - .add_air_instance(air_instance, None); + self.wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, None); } } diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 7cdb8465..d82ae33e 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -1,3 +1,4 @@ +use core::panic; use std::{ collections::HashMap, sync::{ @@ -19,8 +20,6 @@ use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::{MemAlignRomSM, MemOp}; -const PROVE_CHUNK_SIZE: usize = 1 << 12; - const CHUNK_NUM: usize = 8; const CHUNK_NUM_U64: u64 = CHUNK_NUM as u64; const CHUNK_BITS: usize = 8; @@ -37,7 +36,6 @@ pub struct MemAlignSM { // Inputs inputs: Mutex)>>, - input_len: Mutex, // Secondary State machines mem_align_rom_sm: Arc>, @@ -56,7 +54,6 @@ impl MemAlignSM { std: std.clone(), registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()), - input_len: Mutex::new(0), mem_align_rom_sm, }; let mem_align_sm = Arc::new(mem_align_sm); @@ -81,7 +78,7 @@ impl MemAlignSM { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { // TODO: Fix this... 
- self.prove_internal(&[], 0); + self.prove_internal(&[]); self.mem_align_rom_sm.unregister_predecessor(); self.std.unregister_predecessor(self.wcm.get_pctx(), None); @@ -108,20 +105,18 @@ impl MemAlignSM { unaligned_access: &ZiskRequiredMemory, aligned_accesses: &[ZiskRequiredMemory], ) { - if let (Ok(mut inputs), Ok(mut input_len)) = (self.inputs.lock(), self.input_len.lock()) { + if let Ok(mut inputs) = self.inputs.lock() { inputs.push((unaligned_access.clone(), aligned_accesses.to_vec())); - *input_len += 1 + aligned_accesses.len(); let pctx = self.wcm.get_pctx(); let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - while *input_len >= air_mem_align.num_rows() { - let num_drained = std::cmp::min(air_mem_align.num_rows(), *input_len); + // TODO: Fix this, I am assuming the wc + while inputs.len() * 5 >= air_mem_align.num_rows() { + let num_drained = std::cmp::min(air_mem_align.num_rows(), inputs.len()); let drained_inputs = inputs.drain(..num_drained).collect::>(); - let drained_len = num_drained; - *input_len -= num_drained; - self.prove_internal(&drained_inputs, drained_len); + self.prove_internal(&drained_inputs); } } } @@ -129,7 +124,6 @@ impl MemAlignSM { fn prove_internal( &self, inputs: &[(ZiskRequiredMemory, Vec)], - input_len: usize, ) { let mem_align_rom_sm = self.mem_align_rom_sm.clone(); let wcm = self.wcm.clone(); @@ -148,7 +142,6 @@ impl MemAlignSM { &mem_align_rom_sm, &std, inputs, - input_len, &mut prover_buffer, offset, ); @@ -163,10 +156,11 @@ impl MemAlignSM { mem_align_rom_sm: &MemAlignRomSM, std: &Std, inputs: &[(ZiskRequiredMemory, Vec)], - input_len: usize, prover_buffer: &mut [F], offset: u64, ) { + let input_len = inputs.len(); + let pctx = wcm.get_pctx(); let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); @@ -191,7 +185,7 @@ impl MemAlignSM { // Process the inputs while saving the values to be range checked let mut rows_processed = 0; for (unaligned_input, aligned_inputs) in 
inputs.iter() { - let rows = Self::process_slice( + let rows = Self::process_input( unaligned_input, aligned_inputs, mem_align_rom_sm, @@ -201,6 +195,7 @@ impl MemAlignSM { trace_buffer[rows_processed + j] = row; } rows_processed += rows.len(); + println!("rows_processed: {}", rows_processed); } // Pad the remaining rows with trivailly satisfying rows @@ -233,7 +228,7 @@ impl MemAlignSM { } #[inline(always)] - pub fn process_slice( + pub fn process_input( unaligned_input: &ZiskRequiredMemory, aligned_inputs: &[ZiskRequiredMemory], mem_align_rom_sm: &MemAlignRomSM, @@ -481,6 +476,11 @@ impl MemAlignSM { MemOp::TwoWrites => { // RWVWR // Sanity check + if aligned_inputs.len() != 4 { + println!("opcode: {:?}", op); + println!("aligned_inputs: {:?}", aligned_inputs); + println!("unaligned_input: {:?}", unaligned_input); + } assert!(aligned_inputs.len() == 4); // Get the aligned address @@ -488,6 +488,7 @@ impl MemAlignSM { let addr_second_read_write = aligned_inputs[2].address; // addr / CHUNK_NUM + CHUNK_NUM; // Get the aligned values + // TODO: I do not need to establish an order, I can use the field is_write!!! 
let value_first_read = aligned_inputs[0].value.to_be_bytes(); let value_first_write = aligned_inputs[1].value.to_be_bytes(); let value_second_read = aligned_inputs[2].value.to_be_bytes(); diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 73c63176..03266910 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -28,16 +28,16 @@ impl MemProxy { let mem_proxy = Self { registered_predecessors: AtomicU32::new(0), - mem_sm: mem_sm.clone(), - mem_align_sm: mem_align_sm.clone(), + mem_sm, + mem_align_sm, }; let mem_proxy = Arc::new(mem_proxy); wcm.register_component(mem_proxy.clone(), None, None); - // For all the secondary state machines, register the mem_proxy as a predecessor - mem_sm.register_predecessor(); - mem_align_sm.register_predecessor(); + // For all the secondary state machines, register the main state machine as a predecessor + mem_proxy.mem_sm.register_predecessor(); + mem_proxy.mem_align_sm.register_predecessor(); mem_proxy } @@ -55,30 +55,66 @@ impl MemProxy { pub fn prove( &self, - mut operations: [Vec; 2], + mut operations: &mut [Vec; 2], ) -> Result<(), Box> { let mut aligned = std::mem::take(&mut operations[0]); - let non_aligned = std::mem::take(&mut operations[1]); + let unaligned = std::mem::take(&mut operations[1]); let mut new_aligned = Vec::new(); + //trace[63927]: MemRow { addr: 2685533720, step: 5145, sel: 1, wr: 0, value: [2685534552, 0], addr_changes: 0, same_value: 0, first_addr_access_is_read: 0 } + // println!("-----------------"); + // println!("-- Aligned inputs:"); + // for i in 0..aligned.len() { + // if aligned[i].address == 2685534096 { + // println!("aligned[{}]: {:?} value: {:x}", i, aligned[i], aligned[i].value); + // } + // } + // println!("-- Unaligned inputs:"); + // for i in 0..unaligned.len() { + // if unaligned[i].address >= (2685534096 - 8) && unaligned[i].address <= (2685534096 + 8) + // { + // println!("unaligned[{}]: {:?} value: {:x}", i, 
unaligned[i], unaligned[i].value); + // } + // } + // println!("-----------------"); + // Step 1. Sort the aligned memory accesses timer_start_debug!(MEM_SORT); aligned.sort_by_key(|mem| mem.address); timer_stop_and_log_debug!(MEM_SORT); - // Step 2. For each non-aligned memory access - non_aligned.iter().for_each(|unaligned_access| { - // Step 2.1 Ask to the Mem Align SM for the aligned memory accesses generated by the non-aligned one + // Step 2. For each unaligned memory access + unaligned.iter().for_each(|unaligned_access| { + // Step 2.1 Ask to the Mem Align SM for the aligned memory accesses generated by the non-aligned one let mem_op = MemAlignSM::::get_mem_op(unaligned_access); // Step 2.2 Ask to the Mem SM for the aligned memory accesses - let aligned_accesses = self.get_aligned_accesses(&unaligned_access, mem_op, &aligned); + // TODO! Remove mem_op.clone() + let aligned_accesses = self.get_aligned_accesses( + &unaligned_access, + mem_op.clone(), + &aligned, + &new_aligned, + ); // Step 2.3 Carried with the aligned memory accesses, prove the non-aligned ones self.mem_align_sm.prove(unaligned_access, &aligned_accesses); + for access in new_aligned.iter() { + if access.step == 4682 { + println!("new_aligned: {:?}", access); + } + } + // Step 2.4 Store the new aligned memory access(es) + if unaligned_access.step == 5145 { + println!("*** mem_op: {:?}", mem_op); + println!("*** unaligned_access: {:?}", unaligned_access); + println!("*** aligned_accesses: {:?}", aligned_accesses); + } + new_aligned.extend(aligned_accesses); + new_aligned.sort_by_key(|mem| mem.address); }); // Step 3. 
Concatenate the new aligned memory accesses with the original aligned memory @@ -86,15 +122,26 @@ impl MemProxy { aligned.extend(new_aligned); timer_start_debug!(MEM_SORT_2); - aligned.sort_by_key(|mem| mem.address); + aligned.sort_by_key(|mem| (mem.address, mem.step)); timer_stop_and_log_debug!(MEM_SORT_2); let mut idx = 0; while aligned[idx].address < RAM_ADDR && idx < aligned.len() { idx += 1; } + + println!("Aligned len(): {:?}", aligned.len()); + let (_input_aligned, aligned) = aligned.split_at_mut(idx); + // Filter where address = 2684391184 + println!(""); + for i in 0..aligned.len() { + if aligned[i].address == 2685534096 { + println!("OJO!!!! mem: {:?}", aligned[i]); + } + } + // Step 4. Prove the aligned memory accesses using mem state machine self.mem_sm.prove(aligned); @@ -107,30 +154,42 @@ impl MemProxy { unaligned_access: &ZiskRequiredMemory, mem_op: MemOp, aligned_accesses: &[ZiskRequiredMemory], + new_aligned_accesses: &[ZiskRequiredMemory], ) -> Vec { // Align down to a 8 byte addres let addr = unaligned_access.address & !7; match mem_op { MemOp::OneRead => { // Look for last write to the same address - let last_write_addr = - Self::get_last_write(addr, unaligned_access.step, aligned_accesses); - let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + let last_write_addr = Self::get_last_write( + addr, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); + let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: false, address: addr, width: 8, value: 0, }); + + last_write_addr.step = unaligned_access.step; + vec![last_write_addr] } MemOp::OneWrite => { // Look for last write to the same address - let last_write_addr = - Self::get_last_write(addr, unaligned_access.step, aligned_accesses); + let last_write_addr = Self::get_last_write( + addr, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); // Modify the value of the write 
to the same address - let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: true, address: addr, @@ -138,17 +197,32 @@ impl MemProxy { value: 0, }); - Self::write_value(&unaligned_access, &mut last_write_addr); - vec![last_write_addr] + let mut last_write_addr_r = last_write_addr.clone(); + last_write_addr_r.step = unaligned_access.step; + last_write_addr_r.is_write = false; + + let mut last_write_addr_w = last_write_addr; + last_write_addr_w.step = unaligned_access.step; + Self::write_value(&unaligned_access, &mut last_write_addr_w); + + vec![last_write_addr_r, last_write_addr_w] } MemOp::TwoReads => { // Look for last write to the same address and same address + 8 - let last_write_addr = - Self::get_last_write(addr, unaligned_access.step, aligned_accesses); - let last_write_addr_p = - Self::get_last_write(addr + 8, unaligned_access.step, aligned_accesses); + let last_write_addr = Self::get_last_write( + addr, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); + let last_write_addr_p = Self::get_last_write( + addr + 8, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); - let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: false, address: addr, @@ -156,7 +230,7 @@ impl MemProxy { value: 0, }); - let last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { + let mut last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: false, address: addr + 8, @@ -164,33 +238,65 @@ impl MemProxy { value: 0, }); + last_write_addr.step = unaligned_access.step; + last_write_addr_p.step = unaligned_access.step; + vec![last_write_addr, last_write_addr_p] } MemOp::TwoWrites => { // Look for last write to the 
same address and same address + 8 - let last_write_addr = - Self::get_last_write(addr, unaligned_access.step, aligned_accesses); - let last_write_addr_p = - Self::get_last_write(addr + 8, unaligned_access.step, aligned_accesses); + let last_write_addr = Self::get_last_write( + addr, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); + let last_write_addr_p = Self::get_last_write( + addr + 8, + unaligned_access.step, + aligned_accesses, + Some(new_aligned_accesses), + ); - let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { + // Modify the value of the write to the same address + let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: true, address: addr, width: 8, - value: 1, + value: 0, }); - let mut last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { + let mut last_write_addr_r = last_write_addr.clone(); + last_write_addr_r.step = unaligned_access.step; + last_write_addr_r.is_write = false; + + let mut last_write_addr_w = last_write_addr; + last_write_addr_w.step = unaligned_access.step; + Self::write_value(&unaligned_access, &mut last_write_addr_w); + + let last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { step: unaligned_access.step, is_write: true, address: addr + 8, width: 8, - value: 1, + value: 0, }); - Self::write_values(&unaligned_access, &mut last_write_addr, &mut last_write_addr_p); - vec![last_write_addr, last_write_addr_p] + let mut last_write_addr_p_r = last_write_addr_p.clone(); + last_write_addr_p_r.step = unaligned_access.step; + last_write_addr_p_r.is_write = false; + + let mut last_write_addr_p_w = last_write_addr_p; + last_write_addr_p_w.step = unaligned_access.step; + Self::write_value(&unaligned_access, &mut last_write_addr_p_w); + + Self::write_values( + &unaligned_access, + &mut last_write_addr_w, + &mut last_write_addr_p_w, + ); + vec![last_write_addr_r, last_write_addr_w, last_write_addr_p_r, 
last_write_addr_p_w] } } } @@ -200,6 +306,7 @@ impl MemProxy { addr: u64, step: u64, aligned_accesses: &[ZiskRequiredMemory], + new_aligned_accesses: Option<&[ZiskRequiredMemory]>, ) -> Option { // Step 1: Find the start of the range for `addr` let start_index = @@ -228,6 +335,24 @@ impl MemProxy { } } + // Step 3: If `new_aligned_accesses` exists, check for a more recent write + if let None = new_aligned_accesses { + return last_write; + } + + let new_aligned_accesses = new_aligned_accesses.unwrap(); + let last_new_write = Self::get_last_write(addr, step, new_aligned_accesses, None); + + if let None = last_write { + return last_new_write; + } + + if let Some(last_new_write) = last_new_write { + if last_new_write.step > last_write.as_ref().unwrap().step { + return Some(last_new_write); + } + } + last_write } diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index cd289cfb..853a5d09 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -35,7 +35,7 @@ pub struct ZiskExecutor { pub rom_sm: Arc>, /// Memory State Machine - pub mem_proxy: Arc>, + pub mem_proxy_sm: Arc>, /// Binary State Machine pub binary_sm: Arc>, @@ -51,7 +51,7 @@ impl ZiskExecutor { let std = Std::new(wcm.clone()); let rom_sm = RomSM::new(wcm.clone()); - let mem_proxy = MemProxy::new(wcm.clone(), std.clone()); + let mem_proxy_sm = MemProxy::new(wcm.clone(), std.clone()); let binary_sm = BinarySM::new(wcm.clone(), std.clone()); let arith_sm = ArithSM::new(wcm.clone()); @@ -83,9 +83,9 @@ impl ZiskExecutor { // TODO - If there is more than one Main AIR available, the MAX_ACCUMULATED will be the one // with the highest num_rows. It has to be a power of 2. 
- let main_sm = MainSM::new(wcm.clone(), arith_sm.clone(), binary_sm.clone()); + let main_sm = MainSM::new(wcm.clone(), mem_proxy_sm.clone(), arith_sm.clone(), binary_sm.clone()); - Self { zisk_rom, main_sm, rom_sm, mem_proxy, binary_sm, arith_sm } + Self { zisk_rom, main_sm, rom_sm, mem_proxy_sm, binary_sm, arith_sm } } /// Executes the MainSM state machine and processes the inputs in batches when the maximum @@ -187,7 +187,7 @@ impl ZiskExecutor { // STEP 2. Wait until all inputs are generated // ============================================== // Join all the threads to synchronize the execution - let mem_required = mem_thread.join().expect("Error during Memory witness computation"); + let mut mem_required = mem_thread.join().expect("Error during Memory witness computation"); let rom_required = rom_thread.join().expect("Error during ROM witness computation"); // STEP 3. Generate AIRs and Prove @@ -196,8 +196,8 @@ impl ZiskExecutor { // Memory State Machine // ---------------------------------------------- let mem_thread = thread::spawn({ - let mem_proxy = self.mem_proxy.clone(); - move || mem_proxy.prove(mem_required).expect("Error during Memory witness computation") + let mem_proxy_sm = self.mem_proxy_sm.clone(); + move || mem_proxy_sm.prove(&mut mem_required).expect("Error during Memory witness computation") }); // ROM State Machine @@ -286,8 +286,8 @@ impl ZiskExecutor { let _ = thread.join().expect("Error during ROM witness computation"); } - self.mem_proxy.unregister_predecessor(); + self.mem_proxy_sm.unregister_predecessor(); self.binary_sm.unregister_predecessor(); // self.arith_sm.register_predecessor(scope); } -} +} \ No newline at end of file From 1f7786e426a98df48f8b943f64df2c614ce71436 Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Thu, 7 Nov 2024 08:20:07 +0000 Subject: [PATCH 14/44] mem_proxy to little endian --- state-machines/mem/src/mem_proxy.rs | 20 ++++++++++++-------- 1 file changed, 12 
insertions(+), 8 deletions(-) diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 31c0d401..9e8489d9 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -77,7 +77,8 @@ impl MemProxy { } println!("-- Unaligned inputs:"); for i in 0..unaligned.len() { - if unaligned[i].address >= (2685534096 - 8) && unaligned[i].address <= (2685534096 + 8) { + if unaligned[i].address >= (2685534096 - 8) && unaligned[i].address <= (2685534096 + 8) + { println!("unaligned[{}]: {:?} value: {:x}", i, unaligned[i], unaligned[i].value); } } @@ -140,7 +141,7 @@ impl MemProxy { // Filter where address = 2684391184 println!(""); - for i in 0.. aligned.len() { + for i in 0..aligned.len() { if aligned[i].address == 2685534096 { println!("OJO!!!! mem: {:?}", aligned[i]); } @@ -295,7 +296,11 @@ impl MemProxy { last_write_addr_p_w.step = unaligned_access.step; Self::write_value(&unaligned_access, &mut last_write_addr_p_w); - Self::write_values(&unaligned_access, &mut last_write_addr_w, &mut last_write_addr_p_w); + Self::write_values( + &unaligned_access, + &mut last_write_addr_w, + &mut last_write_addr_p_w, + ); vec![last_write_addr_r, last_write_addr_w, last_write_addr_p_r, last_write_addr_p_w] } } @@ -358,15 +363,14 @@ impl MemProxy { #[inline(always)] fn write_value(unaligned: &ZiskRequiredMemory, aligned: &mut ZiskRequiredMemory) { - let offset = 8 - (unaligned.address & 7); + let offset = unaligned.address & 7; let width_in_bits = unaligned.width * 8; - let mask = !(((1u64 << width_in_bits) - 1) << ((offset - unaligned.width) * 8)); + let mask = !(((1u64 << width_in_bits) - 1) << (offset * 8)); - aligned.value = - (aligned.value & mask) | (unaligned.value << ((offset - unaligned.width) * 8)); + aligned.value = (aligned.value & mask) + | ((unaligned.value & ((1u64 << width_in_bits) - 1)) << (offset * 8)); } - #[inline(always)] fn write_values( unaligned: &ZiskRequiredMemory, From 
89a59ada8e5995a7fa62ff5b19a42b5ddf918733 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Thu, 7 Nov 2024 11:00:47 +0000 Subject: [PATCH 15/44] Fix mem bugs --- state-machines/main/pil/main.pil | 10 +++--- state-machines/mem/pil/mem.pil | 7 ++-- state-machines/mem/src/mem_proxy.rs | 52 ++++------------------------- 3 files changed, 15 insertions(+), 54 deletions(-) diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index 2d6c3d24..095a6ac6 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -112,8 +112,6 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope col witness jmp_offset1, jmp_offset2; // if flag, goto2, else goto 1 col witness m32; - const expr addr_step = STEP * 3; - const expr sel_mem_b; sel_mem_b = b_src_mem + b_src_ind; @@ -136,13 +134,14 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope // Mem.load mem_load(sel: a_src_mem, - step: addr_step, + step: STEP, addr: addr0, value: a); // Mem.load mem_load(sel: sel_mem_b, - step: addr_step + 1, + step: STEP, + step_offset: 1, bytes: ind_width, addr: addr1, value: b); @@ -154,7 +153,8 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope // Mem.store mem_store(sel: store_mem + store_ind, - step: addr_step + 2, + step: STEP, + step_offset: 2, bytes: ind_width, addr: addr2, value: store_value); diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index 50da226d..92035d3c 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -8,7 +8,8 @@ const int MEMORY_LOAD_OP = 1; const int MEMORY_STORE_OP = 2; const int MEMORY_MAX_DIFF = 2**22; -const int MAX_MEM_STEP_OFFSET = 3; +const int MAX_MEM_STEP_OFFSET = 2; +const int MAX_MEM_OPS_PER_MAIN_STEP = (MAX_MEM_STEP_OFFSET + 1) * 2; airtemplate Mem(const int N = 2**21, const int RC = 2, const int id = MEMORY_ID, const int MAX_STEP = 2 ** 23, 
const int MEM_BYTES = 8) { col fixed SEGMENT_L1 = [1,0...]; @@ -106,9 +107,9 @@ function mem_store(int id = MEMORY_ID, expr addr, expr step, expr step_offset = private function mem_assumes(int id, int mem_op, expr addr, expr step, expr step_offset, expr bytes, expr value[], expr sel) { if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); + error("step_offset ${step_offset} is greater than max value allowed ${MAX_MEM_STEP_OFFSET}"); } // adding 1 at step for first continuation - permutation_assumes(id, [mem_op, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step) + step_offset, bytes, ...value], sel: sel); + permutation_assumes(id, [mem_op, addr, 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset, bytes, ...value], sel: sel); } \ No newline at end of file diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 03266910..5b63cdce 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -61,23 +61,6 @@ impl MemProxy { let unaligned = std::mem::take(&mut operations[1]); let mut new_aligned = Vec::new(); - //trace[63927]: MemRow { addr: 2685533720, step: 5145, sel: 1, wr: 0, value: [2685534552, 0], addr_changes: 0, same_value: 0, first_addr_access_is_read: 0 } - // println!("-----------------"); - // println!("-- Aligned inputs:"); - // for i in 0..aligned.len() { - // if aligned[i].address == 2685534096 { - // println!("aligned[{}]: {:?} value: {:x}", i, aligned[i], aligned[i].value); - // } - // } - // println!("-- Unaligned inputs:"); - // for i in 0..unaligned.len() { - // if unaligned[i].address >= (2685534096 - 8) && unaligned[i].address <= (2685534096 + 8) - // { - // println!("unaligned[{}]: {:?} value: {:x}", i, unaligned[i], unaligned[i].value); - // } - // } - // println!("-----------------"); - // Step 1. 
Sort the aligned memory accesses timer_start_debug!(MEM_SORT); aligned.sort_by_key(|mem| mem.address); @@ -100,19 +83,7 @@ impl MemProxy { // Step 2.3 Carried with the aligned memory accesses, prove the non-aligned ones self.mem_align_sm.prove(unaligned_access, &aligned_accesses); - for access in new_aligned.iter() { - if access.step == 4682 { - println!("new_aligned: {:?}", access); - } - } - // Step 2.4 Store the new aligned memory access(es) - if unaligned_access.step == 5145 { - println!("*** mem_op: {:?}", mem_op); - println!("*** unaligned_access: {:?}", unaligned_access); - println!("*** aligned_accesses: {:?}", aligned_accesses); - } - new_aligned.extend(aligned_accesses); new_aligned.sort_by_key(|mem| mem.address); }); @@ -130,18 +101,8 @@ impl MemProxy { idx += 1; } - println!("Aligned len(): {:?}", aligned.len()); - let (_input_aligned, aligned) = aligned.split_at_mut(idx); - // Filter where address = 2684391184 - println!(""); - for i in 0..aligned.len() { - if aligned[i].address == 2685534096 { - println!("OJO!!!! mem: {:?}", aligned[i]); - } - } - // Step 4. 
Prove the aligned memory accesses using mem state machine self.mem_sm.prove(aligned); @@ -358,13 +319,13 @@ impl MemProxy { #[inline(always)] fn write_value(unaligned: &ZiskRequiredMemory, aligned: &mut ZiskRequiredMemory) { - let offset = 8 - (unaligned.address & 7); + let offset = unaligned.address & 7; let width_in_bits = unaligned.width * 8; - let mask = !(((1u64 << width_in_bits) - 1) << ((offset - unaligned.width) * 8)); + let mask = !(((1u64 << width_in_bits) - 1) << (offset * 8)); - aligned.value = - (aligned.value & mask) | (unaligned.value << ((offset - unaligned.width) * 8)); + aligned.value = (aligned.value & mask) + | ((unaligned.value & ((1u64 << width_in_bits) - 1)) << (offset * 8)); } #[inline(always)] @@ -378,14 +339,13 @@ impl MemProxy { let right_bits = (unaligned.width - bytes_to_write) * 8; // Left write - let left_value = unaligned.value >> right_bits; + let left_value = unaligned.value << right_bits; let left_memory = ZiskRequiredMemory { width: bytes_to_write, value: left_value, ..*unaligned }; Self::write_value(&left_memory, aligned); // Right write - let mask = (1u64 << right_bits) - 1; - let right_value = unaligned.value & mask; + let right_value = unaligned.value >> (bytes_to_write * 8); let right_memory = ZiskRequiredMemory { address: 0, From 968a53929e174f46e0ce5523d4eab5d853f221ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Thu, 7 Nov 2024 11:00:47 +0000 Subject: [PATCH 16/44] Fix mem bugs --- state-machines/main/pil/main.pil | 32 +++++++++---------- state-machines/mem/pil/mem.pil | 7 +++-- state-machines/mem/src/mem_proxy.rs | 48 +++-------------------------- state-machines/mem/src/mem_sm.rs | 1 - 4 files changed, 24 insertions(+), 64 deletions(-) diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index 027bee94..b7a9e81e 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -112,8 +112,6 @@ airtemplate Main(int N = 2**21, int RC = 2, int 
stack_enabled = 0, const int ope col witness jmp_offset1, jmp_offset2; // if flag, goto2, else goto 1 col witness m32; - const expr addr_step = STEP * 3; - const expr sel_mem_b; sel_mem_b = b_src_mem + b_src_ind; @@ -135,17 +133,18 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope } // Mem.load - //mem_load(sel: a_src_mem, - // step: addr_step, - // addr: addr0, - // value: a); + mem_load(sel: a_src_mem, + step: STEP, + addr: addr0, + value: a); // Mem.load - //mem_load(sel: sel_mem_b, - // step: addr_step + 1, - // bytes: ind_width, - // addr: addr1, - // value: b); + mem_load(sel: sel_mem_b, + step: STEP, + step_offset: 1, + bytes: ind_width, + addr: addr1, + value: b); const expr store_value[2]; @@ -153,11 +152,12 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope store_value[1] = (1 - store_ra) * c[1]; // Mem.store - //mem_store(sel: store_mem + store_ind, - // step: addr_step + 2, - // bytes: ind_width, - // addr: addr2, - // value: store_value); + mem_store(sel: store_mem + store_ind, + step: STEP, + step_offset: 2, + bytes: ind_width, + addr: addr2, + value: store_value); // Operation.assume => how organize software col witness __debug_operation_bus_enabled; diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index 50da226d..92035d3c 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -8,7 +8,8 @@ const int MEMORY_LOAD_OP = 1; const int MEMORY_STORE_OP = 2; const int MEMORY_MAX_DIFF = 2**22; -const int MAX_MEM_STEP_OFFSET = 3; +const int MAX_MEM_STEP_OFFSET = 2; +const int MAX_MEM_OPS_PER_MAIN_STEP = (MAX_MEM_STEP_OFFSET + 1) * 2; airtemplate Mem(const int N = 2**21, const int RC = 2, const int id = MEMORY_ID, const int MAX_STEP = 2 ** 23, const int MEM_BYTES = 8) { col fixed SEGMENT_L1 = [1,0...]; @@ -106,9 +107,9 @@ function mem_store(int id = MEMORY_ID, expr addr, expr step, expr step_offset = private function mem_assumes(int id, int 
mem_op, expr addr, expr step, expr step_offset, expr bytes, expr value[], expr sel) { if (step_offset > MAX_MEM_STEP_OFFSET) { - error("max step_offset ${step_offset} is greater than max value ${MAX_MEM_STEP_OFFSET}"); + error("step_offset ${step_offset} is greater than max value allowed ${MAX_MEM_STEP_OFFSET}"); } // adding 1 at step for first continuation - permutation_assumes(id, [mem_op, addr, 1 + ((MAX_MEM_STEP_OFFSET + 1) * step) + step_offset, bytes, ...value], sel: sel); + permutation_assumes(id, [mem_op, addr, 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset, bytes, ...value], sel: sel); } \ No newline at end of file diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 9e8489d9..549b83cd 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -6,7 +6,7 @@ use std::sync::{ use crate::{MemAlignSM, MemSM}; use p3_field::PrimeField; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; -use zisk_core::{ZiskRequiredMemory, RAM_ADDR, SYS_ADDR}; +use zisk_core::{ZiskRequiredMemory, RAM_ADDR}; use proofman::{WitnessComponent, WitnessManager}; @@ -61,29 +61,12 @@ impl MemProxy { pub fn prove( &self, - mut operations: &mut [Vec; 2], + operations: &mut [Vec; 2], ) -> Result<(), Box> { let mut aligned = std::mem::take(&mut operations[0]); let unaligned = std::mem::take(&mut operations[1]); let mut new_aligned = Vec::new(); - //trace[63927]: MemRow { addr: 2685533720, step: 5145, sel: 1, wr: 0, value: [2685534552, 0], addr_changes: 0, same_value: 0, first_addr_access_is_read: 0 } - println!("-----------------"); - println!("-- Aligned inputs:"); - for i in 0..aligned.len() { - if aligned[i].address == 2685534096 { - println!("aligned[{}]: {:?} value: {:x}", i, aligned[i], aligned[i].value); - } - } - println!("-- Unaligned inputs:"); - for i in 0..unaligned.len() { - if unaligned[i].address >= (2685534096 - 8) && unaligned[i].address <= (2685534096 + 8) - { - 
println!("unaligned[{}]: {:?} value: {:x}", i, unaligned[i], unaligned[i].value); - } - } - println!("-----------------"); - // Step 1. Sort the aligned memory accesses timer_start_debug!(MEM_SORT); aligned.sort_by_key(|mem| mem.address); @@ -105,19 +88,7 @@ impl MemProxy { // Step 2.2 Align memory access using mem_align state machine // self.mem_align_sm.prove(&aligned_accesses, unaligned_access); - for access in new_aligned.iter() { - if access.step == 4682 { - println!("new_aligned: {:?}", access); - } - } - // Step 2.3 Store the new aligned memory access(es) - if unaligned_access.step == 5145 { - println!("*** mem_ops: {:?}", mem_ops); - println!("*** unaligned_access: {:?}", unaligned_access); - println!("*** aligned_accesses: {:?}", aligned_accesses); - } - new_aligned.extend(aligned_accesses); new_aligned.sort_by_key(|mem| mem.address); }); @@ -135,18 +106,8 @@ impl MemProxy { idx += 1; } - println!("Aligned len(): {:?}", aligned.len()); - let (_input_aligned, aligned) = aligned.split_at_mut(idx); - // Filter where address = 2684391184 - println!(""); - for i in 0..aligned.len() { - if aligned[i].address == 2685534096 { - println!("OJO!!!! mem: {:?}", aligned[i]); - } - } - // Step 4. 
Prove the aligned memory accesses using mem state machine self.mem_sm.prove(aligned); @@ -382,14 +343,13 @@ impl MemProxy { let right_bits = (unaligned.width - bytes_to_write) * 8; // Left write - let left_value = unaligned.value >> right_bits; + let left_value = unaligned.value << right_bits; let left_memory = ZiskRequiredMemory { width: bytes_to_write, value: left_value, ..*unaligned }; Self::write_value(&left_memory, aligned); // Right write - let mask = (1u64 << right_bits) - 1; - let right_value = unaligned.value & mask; + let right_value = unaligned.value >> (bytes_to_write * 8); let right_memory = ZiskRequiredMemory { address: 0, diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index 89621c6e..c78d89fb 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -6,7 +6,6 @@ use std::sync::{ use p3_field::PrimeField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; -use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use rayon::prelude::*; use sm_common::create_prover_buffer; From 96131a839ab0feba2e0882881c9b61c9af0d23ca Mon Sep 17 00:00:00 2001 From: Xavier Pinsach <10213118+xavi-pinsach@users.noreply.github.com> Date: Thu, 7 Nov 2024 13:01:35 +0000 Subject: [PATCH 17/44] remove println --- state-machines/mem/src/mem_sm.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index c78d89fb..b6055b0c 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -230,8 +230,6 @@ impl MemSM { trace[i].first_addr_access_is_read = F::zero(); } - println!("trace[66094]: {:?}", trace[66094]); - let mut air_instance = AirInstance::new( self.wcm.get_sctx(), ZISK_AIRGROUP_ID, From 7ad568fae1df06ae4e02862fbbbdc017de8b045b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Thu, 7 Nov 2024 13:59:37 +0000 Subject: [PATCH 18/44] Minor fixes --- 
state-machines/mem/src/mem_align_sm.rs | 43 +++++++++++++------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index d82ae33e..1acd5bef 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -78,7 +78,14 @@ impl MemAlignSM { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { // TODO: Fix this... - self.prove_internal(&[]); + if let Ok(mut inputs) = self.inputs.lock() { + let pctx = self.wcm.get_pctx(); + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + let num_drained = std::cmp::min(air_mem_align.num_rows(), inputs.len()); + let drained_inputs = inputs.drain(..num_drained).collect::>(); + + self.prove_internal(&drained_inputs); + } self.mem_align_rom_sm.unregister_predecessor(); self.std.unregister_predecessor(self.wcm.get_pctx(), None); @@ -195,7 +202,6 @@ impl MemAlignSM { trace_buffer[rows_processed + j] = row; } rows_processed += rows.len(); - println!("rows_processed: {}", rows_processed); } // Pad the remaining rows with trivailly satisfying rows @@ -238,7 +244,7 @@ impl MemAlignSM { let addr = unaligned_input.address; // Get the unaligned value - let value = unaligned_input.value.to_be_bytes(); + let value = unaligned_input.value.to_le_bytes(); // Get the unaligned step let step = unaligned_input.step; @@ -280,7 +286,7 @@ impl MemAlignSM { let addr_read = aligned_inputs[0].address; // addr / CHUNK_NUM; // Get the aligned values - let value_read = aligned_inputs[0].value.to_be_bytes(); + let value_read = aligned_inputs[0].value.to_le_bytes(); // Get the aligned step let step_read = aligned_inputs[0].step; @@ -312,7 +318,7 @@ impl MemAlignSM { read_row.reg[i] = F::from_canonical_u8(value_read[i]); read_row.sel[i] = F::from_bool(true); - value_row.reg[i] = F::from_canonical_u8(value[shift + i]); + value_row.reg[i] = 
F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); value_row.sel[i] = F::from_bool(i == offset); // Store the range check @@ -333,8 +339,8 @@ impl MemAlignSM { let addr_read_write = aligned_inputs[0].address; // addr / CHUNK_NUM; // Get the aligned values - let value_read = aligned_inputs[0].value.to_be_bytes(); - let value_write = aligned_inputs[1].value.to_be_bytes(); + let value_read = aligned_inputs[0].value.to_le_bytes(); + let value_write = aligned_inputs[1].value.to_le_bytes(); // Get the aligned step let step_read = aligned_inputs[0].step; @@ -384,7 +390,7 @@ impl MemAlignSM { write_row.reg[i] = F::from_canonical_u8(value_write[i]); write_row.sel[i] = F::from_bool(i >= offset); - value_row.reg[i] = F::from_canonical_u8(value[shift + i]); + value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); value_row.sel[i] = F::from_bool(i == offset); // Store the range check @@ -408,8 +414,8 @@ impl MemAlignSM { let addr_second_read = aligned_inputs[1].address; // addr / CHUNK_NUM + CHUNK_NUM; // Get the aligned values - let value_first_read = aligned_inputs[0].value.to_be_bytes(); - let value_second_read = aligned_inputs[1].value.to_be_bytes(); + let value_first_read = aligned_inputs[0].value.to_le_bytes(); + let value_second_read = aligned_inputs[1].value.to_le_bytes(); // Get the aligned step let step_first_read = aligned_inputs[0].step; @@ -456,7 +462,7 @@ impl MemAlignSM { first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); first_read_row.sel[i] = F::from_bool(true); - value_row.reg[i] = F::from_canonical_u8(value[shift + i]); + value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); value_row.sel[i] = F::from_bool(i == offset); second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); @@ -476,11 +482,6 @@ impl MemAlignSM { MemOp::TwoWrites => { // RWVWR // Sanity check - if aligned_inputs.len() != 4 { - println!("opcode: {:?}", op); - println!("aligned_inputs: {:?}", aligned_inputs); - 
println!("unaligned_input: {:?}", unaligned_input); - } assert!(aligned_inputs.len() == 4); // Get the aligned address @@ -489,10 +490,10 @@ impl MemAlignSM { // Get the aligned values // TODO: I do not need to establish an order, I can use the field is_write!!! - let value_first_read = aligned_inputs[0].value.to_be_bytes(); - let value_first_write = aligned_inputs[1].value.to_be_bytes(); - let value_second_read = aligned_inputs[2].value.to_be_bytes(); - let value_second_write = aligned_inputs[3].value.to_be_bytes(); + let value_first_read = aligned_inputs[0].value.to_le_bytes(); + let value_first_write = aligned_inputs[1].value.to_le_bytes(); + let value_second_read = aligned_inputs[2].value.to_le_bytes(); + let value_second_write = aligned_inputs[3].value.to_le_bytes(); // Get the aligned step let step_first_read = aligned_inputs[0].step; @@ -568,7 +569,7 @@ impl MemAlignSM { first_write_row.reg[i] = F::from_canonical_u8(value_first_write[i]); first_write_row.sel[i] = F::from_bool(i >= offset); - value_row.reg[i] = F::from_canonical_u8(value[shift + i]); + value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); value_row.sel[i] = F::from_bool(i == offset); second_write_row.reg[i] = F::from_canonical_u8(value_second_write[i]); From 5380f5b5e7181f0ee4e862a68cd5a826766f1259 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Thu, 7 Nov 2024 20:35:36 +0000 Subject: [PATCH 19/44] wip --- core/src/zisk_required_operation.rs | 7 ++++++- pil/src/pil_helpers/pilout.rs | 2 +- pil/src/pil_helpers/traces.rs | 2 +- state-machines/mem/pil/mem.pil | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index 1ccef475..3187861a 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -8,7 +8,7 @@ pub struct ZiskRequiredOperation { pub b: u64, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct ZiskRequiredMemory { pub 
step: u64, pub is_write: bool, @@ -17,6 +17,11 @@ pub struct ZiskRequiredMemory { pub value: u64, } +pub struct ZiskRequiredMemoryAlign { + pub mem_op: ZiskRequiredMemory, + pub mem_value: [u64; 2], +} + #[derive(Clone, Default)] pub struct ZiskRequired { pub arith: Vec, diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index cedc45ba..bae5828f 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -54,4 +54,4 @@ impl Pilout { pilout } -} \ No newline at end of file +} diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index aac7ff7c..1545cdca 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -45,4 +45,4 @@ trace!(SpecifiedRangesRow, SpecifiedRangesTrace { trace!(U8AirRow, U8AirTrace { mul: F, -}); \ No newline at end of file +}); diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index 92035d3c..d5084847 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -53,7 +53,7 @@ airtemplate Mem(const int N = 2**21, const int RC = 2, const int id = MEMORY_ID, col witness first_addr_access_is_read; first_addr_access_is_read * (1 - first_addr_access_is_read) === 0; - (1 - first_addr_access_is_read) * rd * same_addr === 0; + (1 - first_addr_access_is_read) * rd * addr_changes === 0; for (int index = 0; index < length(value); index++) { same_value * (value[index] - 'value[index]) === 0; From f77db6b6210a9ce5e0a5984b44500773e9ad85cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Fri, 8 Nov 2024 11:20:07 +0000 Subject: [PATCH 20/44] wip --- core/src/zisk_required_operation.rs | 5 - state-machines/mem/src/mem_align_sm.rs | 551 ++++++++++++++++++++++++- 2 files changed, 534 insertions(+), 22 deletions(-) diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index 3187861a..59a7aee6 100644 --- a/core/src/zisk_required_operation.rs +++ 
b/core/src/zisk_required_operation.rs @@ -17,11 +17,6 @@ pub struct ZiskRequiredMemory { pub value: u64, } -pub struct ZiskRequiredMemoryAlign { - pub mem_op: ZiskRequiredMemory, - pub mem_value: [u64; 2], -} - #[derive(Clone, Default)] pub struct ZiskRequired { pub arith: Vec, diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 1acd5bef..ddecab3b 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -23,6 +23,16 @@ use crate::{MemAlignRomSM, MemOp}; const CHUNK_NUM: usize = 8; const CHUNK_NUM_U64: u64 = CHUNK_NUM as u64; const CHUNK_BITS: usize = 8; +const CHUNK_BITS_U64: u64 = CHUNK_BITS as u64; +const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; + +const ALLOWED_WIDTHS: [u64; 4] = [1, 2, 4, 8]; + +pub struct MemAlignResponse { + pub more_address: bool, + pub step: u64, + pub value: Option, +} pub struct MemAlignSM { // Witness computation manager @@ -93,18 +103,514 @@ impl MemAlignSM { } #[inline(always)] - pub fn get_mem_op(unaligned_input: &ZiskRequiredMemory) -> MemOp { - let addr = unaligned_input.address; - let width = unaligned_input.width; + pub fn get_mem_op(&self, input: &ZiskRequiredMemory, mem_values: Vec, phase: usize) -> MemAlignResponse { + // Sanity check + assert!(mem_values.len() == phase + 1); // TODO: Debug mode + + let addr = input.address; + let width = input.width; + let width = if ALLOWED_WIDTHS.contains(&width) { + width as usize + } else { + panic!("Width={} is not allowed. 
Allowed widths are {:?}", width, ALLOWED_WIDTHS); + }; + + // Compute the offset + let offset = addr & CHUNK_BITS_MASK; + let offset = if offset <= usize::MAX as u64 { + offset as usize + } else { + panic!("Offset={} is too large", offset); + }; + + // main: [mem_op, addr, 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset, bytes, ...value] + // mem: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * MEM_BYTES, step, MEM_BYTES, ...value] + // mem_align: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val] + let mut result: MemAlignResponse; + match (input.is_write, offset + width > CHUNK_NUM) { + (false, false) => { // RV + assert!(phase == 0); // TODO: Debug mode + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the shift + let shift = ((offset + width) % CHUNK_NUM) as u64; + + // Get the aligned address + let addr_read = addr >> CHUNK_BITS; + + // Get the aligned value + let value_read = mem_values[phase]; + + // Get the next pc + let next_pc = MemAlignRomSM::::calculate_next_pc(MemOp::OneRead, offset, width); + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr_read), + // offset: F::from_canonical_u64(0), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; - let offset = addr & (CHUNK_NUM_U64 - 1); + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + let pos = i as u64; + + read_row.reg[i] = { + 
F::from_canonical_u64(value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + read_row.sel[i] = F::from_bool(true); + + value_row.reg[i] = { + F::from_canonical_u64(value & (CHUNK_BITS_MASK << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64))) + }; + value_row.sel[i] = F::from_bool(i == offset as usize); - match (unaligned_input.is_write, offset + width > CHUNK_NUM_U64) { - (false, false) => MemOp::OneRead, - (true, false) => MemOp::OneWrite, - (false, true) => MemOp::TwoReads, - (true, true) => MemOp::TwoWrites, + // Store the range check + // *range_check.entry(read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + } + + result = MemAlignResponse { + more_address: false, + step, + value: None, + }; + }, + (true, false) => { // RWV + assert!(phase == 0); // TODO: Debug mode + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the shift + let shift = ((offset + width) % CHUNK_NUM) as u64; + + // Get the aligned address + let addr_read = addr >> CHUNK_BITS; + + // Get the aligned value + let value_read = mem_values[phase]; + + // Get the next pc + let next_pc = MemAlignRomSM::::calculate_next_pc(MemOp::OneWrite, offset, width); + + // Compute the write value + let value_write = { + let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; + + // Get the first width bytes of the unaligned value + let value_to_write = value & width_bytes; + + // Write zeroes to value_read from offset to offset + width + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + + // Add the value to write to the value read + (value_read & !mask) | value_to_write + }; + + let mut read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr_read), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNK_NUM_U64), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + 
sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u64(addr_read), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNK_NUM_U64), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + let pos = i as u64; + + read_row.reg[i] = { + F::from_canonical_u64(value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + read_row.sel[i] = F::from_bool(i >= width); + + write_row.reg[i] = { + F::from_canonical_u64(value_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + write_row.sel[i] = F::from_bool(i < width); + + value_row.reg[i] = { + F::from_canonical_u64(value & (CHUNK_BITS_MASK << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64))) + }; + value_row.sel[i] = F::from_bool(i == offset as usize); + + // Store the range check + // *range_check.entry(read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(write_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + } + + result = MemAlignResponse { + more_address: false, + step, + value: Some(value_write), + }; + }, + (false, true) => { // RVR + assert!(phase == 0 || phase == 1); // TODO: Debug mode + + match phase { + // If phase == 0, do nothing, just ask for more + 0 => { + result = MemAlignResponse { + more_address: true, + step: input.step, + value: None, + }; + }, + + // Otherwise, do the RVR + 1 => { + assert!(mem_values.len() 
== 2); // TODO: Debug mode + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the shift + let shift = ((offset + width) % CHUNK_NUM) as u64; + + // Get the aligned address + let addr_first_read = addr >> CHUNK_BITS; + let addr_second_read = addr >> CHUNK_BITS + CHUNK_BITS; + + // Get the aligned value + let value_first_read = mem_values[0]; + let value_second_read = mem_values[1]; + + // Get the next pc + let next_pc = MemAlignRomSM::::calculate_next_pc(MemOp::TwoReads, offset, width); + + let mut first_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr_first_read), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNK_NUM_U64), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr_second_read), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNK_NUM_U64), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + let pos = i as u64; + + first_read_row.reg[i] = { + F::from_canonical_u64(value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + first_read_row.sel[i] = F::from_bool(true); + + value_row.reg[i] = { + F::from_canonical_u64(value & (CHUNK_BITS_MASK << (((shift + 
pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64))) + }; + value_row.sel[i] = F::from_bool(i == offset); + + second_read_row.reg[i] = { + F::from_canonical_u64(value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + second_read_row.sel[i] = F::from_bool(true); + + // Store the range check + // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + } + + result = MemAlignResponse { + more_address: false, + step, + value: None, + }; + }, + _ => panic!("Invalid phase={}", phase), + } + }, + (true, true) => { // RWVWR + assert!(phase == 0 || phase == 1); // TODO: Debug mode + + match phase { + // If phase == 0, compute the resulting write value and ask for more + 0 => { + assert!(mem_values.len() == 1); // TODO: Debug mode + + // Unaligned memory op information thrown into the bus + let value = input.value; + let step = input.step; + + // Get the aligned value + let value_first_read = mem_values[0]; + + // Compute the write value + let value_first_write = { + let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; + + // Get the first width bytes of the unaligned value + let value_to_write = value & width_bytes; + + // Write zeroes to value_read from offset to offset + width + let mask = width_bytes << (offset * CHUNK_BITS); + + // Add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; + + result = MemAlignResponse { + more_address: true, + step, + value: Some(value_first_write), + }; + }, + // Otherwise, do the RWVRW + 1 => { + assert!(mem_values.len() == 2); // TODO: Debug mode + + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the shift + let shift = ((offset + width) % CHUNK_NUM) as u64; + + // Get the aligned address + let addr_first_read_write = addr >> CHUNK_BITS; + let addr_second_read_write = addr >> CHUNK_BITS 
+ CHUNK_BITS; + + // Get the first aligned value + let value_first_read = mem_values[0]; + + // Recompute the first write value + let value_first_write = { + let width_bytes = (1 << (width * CHUNK_BITS)) - 1; + + // Get the first width bytes of the unaligned value + let value_to_write = value & width_bytes; + + // Write zeroes to value_read from offset to offset + width + let mask = width_bytes << (offset * CHUNK_BITS); + + // Add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; + + // Get the second aligned value + let value_second_read = mem_values[1]; + + // Compute the second write value + let value_second_write = { + let width_bytes = (1 << (width * CHUNK_BITS)) - 1; + + // Get the first width bytes of the unaligned value + let value_to_write = value & width_bytes; + + // Write zeroes to value_read from offset to offset + width + let mask = width_bytes << (offset * CHUNK_BITS); + + // Add the value to write to the value read + (value_second_read & !mask) | value_to_write + }; + + // Get the next pc + let next_pc = MemAlignRomSM::::calculate_next_pc(MemOp::TwoWrites, offset, width); + + // RWVWR + let mut first_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr_first_read_write), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNK_NUM_U64), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut first_write_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u64(addr_first_read_write), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNK_NUM_U64), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; + + let mut value_row = MemAlignRow:: { + step: 
F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_write_row = MemAlignRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u64(addr_second_read_write), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNK_NUM_U64), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 2), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u64(addr_second_read_write), + // offset: F::from_canonical_u64(0), + width: F::from_canonical_u64(CHUNK_NUM_U64), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 3), + reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + let pos = i as u64; + + first_read_row.reg[i] = { + F::from_canonical_u64(value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + first_read_row.sel[i] = F::from_bool(i < offset); + + first_write_row.reg[i] = { + F::from_canonical_u64(value_first_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + first_write_row.sel[i] = F::from_bool(i >= offset); + + value_row.reg[i] = { + F::from_canonical_u64(value & (CHUNK_BITS_MASK << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64))) + }; + value_row.sel[i] = F::from_bool(i == offset); + + second_write_row.reg[i] = { + F::from_canonical_u64(value_second_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + second_write_row.sel[i] = F::from_bool(pos < shift); + + second_read_row.reg[i] = { + F::from_canonical_u64(value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + }; + 
second_read_row.sel[i] = F::from_bool(pos >= shift); + + // Store the range check + // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + + result = MemAlignResponse { + more_address: false, + step, + value: Some(value_second_write), + }; + } + }, + _ => panic!("Invalid phase={}", phase), + } + }, } + + if let Ok(mut inputs) = self.inputs.lock() { + inputs.push((unaligned_access.clone(), aligned_accesses.to_vec())); + + let pctx = self.wcm.get_pctx(); + let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + // TODO: Fix this, I am assuming the wc + while inputs.len() * 5 >= air_mem_align.num_rows() { + let num_drained = std::cmp::min(air_mem_align.num_rows(), inputs.len()); + let drained_inputs = inputs.drain(..num_drained).collect::>(); + + self.prove_internal(&drained_inputs); + } + } + + // { + // is_sufficient, + // if write => (value, step), // the step should be the same for the two words! 
(the write is step + 1) + // } } pub fn prove( @@ -273,6 +779,17 @@ impl MemAlignSM { let op_size = MemAlignRomSM::::get_mem_align_op_size(op); let next_pc = MemAlignRomSM::::calculate_next_pc(op, offset, width); + println!("OP: {:?}", op); + println!("UNALIGNED INPUT:\n {:?}", unaligned_input); + println!(" OFFSET: {:?}", offset); + println!(" value: {:?}", unaligned_input.value.to_le_bytes()); + println!("ALIGNED INPUTS:"); + for aligned_input in aligned_inputs { + println!(" {:?}", aligned_input); + println!(" value: {:?}", aligned_input.value.to_le_bytes()); + } + println!(""); + // Initialize and set the rows of the corresponding op let mut rows: Vec> = Vec::with_capacity(op_size); // TODO: Can I detatch the "shape" of the program from the mem_align and do it in the mem_align_rom? @@ -283,7 +800,7 @@ impl MemAlignSM { assert!(aligned_inputs.len() == 1); // Get the aligned address - let addr_read = aligned_inputs[0].address; // addr / CHUNK_NUM; + let addr_read = aligned_inputs[0].address; // Get the aligned values let value_read = aligned_inputs[0].value.to_le_bytes(); @@ -336,7 +853,7 @@ impl MemAlignSM { assert!(aligned_inputs.len() == 2); // Get the aligned address - let addr_read_write = aligned_inputs[0].address; // addr / CHUNK_NUM; + let addr_read_write = aligned_inputs[0].address; // Get the aligned values let value_read = aligned_inputs[0].value.to_le_bytes(); @@ -385,10 +902,10 @@ impl MemAlignSM { for i in 0..CHUNK_NUM { read_row.reg[i] = F::from_canonical_u8(value_read[i]); - read_row.sel[i] = F::from_bool(i < offset); + read_row.sel[i] = F::from_bool(i >= width); write_row.reg[i] = F::from_canonical_u8(value_write[i]); - write_row.sel[i] = F::from_bool(i >= offset); + write_row.sel[i] = F::from_bool(i < width); value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); value_row.sel[i] = F::from_bool(i == offset); @@ -410,8 +927,8 @@ impl MemAlignSM { assert!(aligned_inputs.len() == 2); // Get the aligned address - let 
addr_first_read = aligned_inputs[0].address; // addr / CHUNK_NUM; - let addr_second_read = aligned_inputs[1].address; // addr / CHUNK_NUM + CHUNK_NUM; + let addr_first_read = aligned_inputs[0].address; + let addr_second_read = aligned_inputs[1].address; // Get the aligned values let value_first_read = aligned_inputs[0].value.to_le_bytes(); @@ -485,8 +1002,8 @@ impl MemAlignSM { assert!(aligned_inputs.len() == 4); // Get the aligned address - let addr_first_read_write = aligned_inputs[0].address; // addr / CHUNK_NUM; - let addr_second_read_write = aligned_inputs[2].address; // addr / CHUNK_NUM + CHUNK_NUM; + let addr_first_read_write = aligned_inputs[0].address; + let addr_second_read_write = aligned_inputs[2].address; // Get the aligned values // TODO: I do not need to establish an order, I can use the field is_write!!! From 0036a50f97c2af998f36b8caacb155d7d4dc88e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Fri, 8 Nov 2024 16:24:40 +0000 Subject: [PATCH 21/44] Mem align compiling --- state-machines/mem/src/mem_align_sm.rs | 1131 ++++++++++++------------ 1 file changed, 561 insertions(+), 570 deletions(-) diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index ddecab3b..fb243e06 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -44,8 +44,8 @@ pub struct MemAlignSM { // Count of registered predecessors registered_predecessors: AtomicU32, - // Inputs - inputs: Mutex)>>, + // Computed rows + rows: Mutex>>, // Secondary State machines mem_align_rom_sm: Arc>, @@ -63,7 +63,7 @@ impl MemAlignSM { wcm: wcm.clone(), std: std.clone(), registered_predecessors: AtomicU32::new(0), - inputs: Mutex::new(Vec::new()), + rows: Mutex::new(Vec::new()), mem_align_rom_sm, }; let mem_align_sm = Arc::new(mem_align_sm); @@ -87,26 +87,37 @@ impl MemAlignSM { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let 
pctx = self.wcm.get_pctx(); + // TODO: Fix this... - if let Ok(mut inputs) = self.inputs.lock() { - let pctx = self.wcm.get_pctx(); + // If there are remaining rows, generate the last instance + if let Ok(mut rows) = self.rows.lock() { + // Get the Mem Align AIR let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - let num_drained = std::cmp::min(air_mem_align.num_rows(), inputs.len()); - let drained_inputs = inputs.drain(..num_drained).collect::>(); - self.prove_internal(&drained_inputs); + let rows_len = rows.len(); + assert!(rows_len <= air_mem_align.num_rows()); + + let drained_rows = rows.drain(..rows_len).collect::>(); + + self.fill_new_air_instance(&drained_rows); } self.mem_align_rom_sm.unregister_predecessor(); - self.std.unregister_predecessor(self.wcm.get_pctx(), None); + self.std.unregister_predecessor(pctx, None); } } #[inline(always)] - pub fn get_mem_op(&self, input: &ZiskRequiredMemory, mem_values: Vec, phase: usize) -> MemAlignResponse { + pub fn get_mem_op( + &self, + input: &ZiskRequiredMemory, + mem_values: Vec, + phase: usize, + ) -> MemAlignResponse { // Sanity check assert!(mem_values.len() == phase + 1); // TODO: Debug mode - + let addr = input.address; let width = input.width; let width = if ALLOWED_WIDTHS.contains(&width) { @@ -126,10 +137,10 @@ impl MemAlignSM { // main: [mem_op, addr, 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset, bytes, ...value] // mem: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * MEM_BYTES, step, MEM_BYTES, ...value] // mem_align: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val] - let mut result: MemAlignResponse; match (input.is_write, offset + width > CHUNK_NUM) { - (false, false) => { // RV - assert!(phase == 0); // TODO: Debug mode + (false, false) => { + // RV + assert!(phase == 0); // TODO: Debug mode // Unaligned memory op information thrown into the bus let step = input.step; @@ -174,12 
+185,18 @@ impl MemAlignSM { let pos = i as u64; read_row.reg[i] = { - F::from_canonical_u64(value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + F::from_canonical_u64( + value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; read_row.sel[i] = F::from_bool(true); value_row.reg[i] = { - F::from_canonical_u64(value & (CHUNK_BITS_MASK << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64))) + F::from_canonical_u64( + value + & (CHUNK_BITS_MASK + << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), + ) }; value_row.sel[i] = F::from_bool(i == offset as usize); @@ -188,14 +205,14 @@ impl MemAlignSM { // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; } - result = MemAlignResponse { - more_address: false, - step, - value: None, - }; - }, - (true, false) => { // RWV - assert!(phase == 0); // TODO: Debug mode + // Prove the generated rows + self.prove(&[read_row, value_row]); + + MemAlignResponse { more_address: false, step, value: None } + } + (true, false) => { + // RWV + assert!(phase == 0); // TODO: Debug mode // Unaligned memory op information thrown into the bus let step = input.step; @@ -219,7 +236,7 @@ impl MemAlignSM { // Get the first width bytes of the unaligned value let value_to_write = value & width_bytes; - + // Write zeroes to value_read from offset to offset + width let mask: u64 = width_bytes << (offset * CHUNK_BITS); @@ -267,17 +284,25 @@ impl MemAlignSM { let pos = i as u64; read_row.reg[i] = { - F::from_canonical_u64(value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + F::from_canonical_u64( + value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; read_row.sel[i] = F::from_bool(i >= width); write_row.reg[i] = { - F::from_canonical_u64(value_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + F::from_canonical_u64( + value_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; write_row.sel[i] = F::from_bool(i < width); value_row.reg[i] = { - F::from_canonical_u64(value & (CHUNK_BITS_MASK << (((shift + 
pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64))) + F::from_canonical_u64( + value + & (CHUNK_BITS_MASK + << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), + ) }; value_row.sel[i] = F::from_bool(i == offset as usize); @@ -287,28 +312,22 @@ impl MemAlignSM { // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; } - result = MemAlignResponse { - more_address: false, - step, - value: Some(value_write), - }; - }, - (false, true) => { // RVR - assert!(phase == 0 || phase == 1); // TODO: Debug mode + // Prove the generated rows + self.prove(&[read_row, value_row]); + + MemAlignResponse { more_address: false, step, value: Some(value_write) } + } + (false, true) => { + // RVR + assert!(phase == 0 || phase == 1); // TODO: Debug mode match phase { // If phase == 0, do nothing, just ask for more - 0 => { - result = MemAlignResponse { - more_address: true, - step: input.step, - value: None, - }; - }, + 0 => MemAlignResponse { more_address: true, step: input.step, value: None }, // Otherwise, do the RVR 1 => { - assert!(mem_values.len() == 2); // TODO: Debug mode + assert!(mem_values.len() == 2); // TODO: Debug mode // Unaligned memory op information thrown into the bus let step = input.step; @@ -326,7 +345,8 @@ impl MemAlignSM { let value_second_read = mem_values[1]; // Get the next pc - let next_pc = MemAlignRomSM::::calculate_next_pc(MemOp::TwoReads, offset, width); + let next_pc = + MemAlignRomSM::::calculate_next_pc(MemOp::TwoReads, offset, width); let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), @@ -368,17 +388,25 @@ impl MemAlignSM { let pos = i as u64; first_read_row.reg[i] = { - F::from_canonical_u64(value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + F::from_canonical_u64( + value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; first_read_row.sel[i] = F::from_bool(true); value_row.reg[i] = { - F::from_canonical_u64(value & (CHUNK_BITS_MASK << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64))) + 
F::from_canonical_u64( + value + & (CHUNK_BITS_MASK + << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), + ) }; value_row.sel[i] = F::from_bool(i == offset); second_read_row.reg[i] = { - F::from_canonical_u64(value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + F::from_canonical_u64( + value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; second_read_row.sel[i] = F::from_bool(true); @@ -388,22 +416,22 @@ impl MemAlignSM { // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; } - result = MemAlignResponse { - more_address: false, - step, - value: None, - }; - }, + // Prove the generated rows + self.prove(&[read_row, value_row]); + + MemAlignResponse { more_address: false, step, value: None } + } _ => panic!("Invalid phase={}", phase), } - }, - (true, true) => { // RWVWR - assert!(phase == 0 || phase == 1); // TODO: Debug mode + } + (true, true) => { + // RWVWR + assert!(phase == 0 || phase == 1); // TODO: Debug mode match phase { // If phase == 0, compute the resulting write value and ask for more 0 => { - assert!(mem_values.len() == 1); // TODO: Debug mode + assert!(mem_values.len() == 1); // TODO: Debug mode // Unaligned memory op information thrown into the bus let value = input.value; @@ -418,7 +446,7 @@ impl MemAlignSM { // Get the first width bytes of the unaligned value let value_to_write = value & width_bytes; - + // Write zeroes to value_read from offset to offset + width let mask = width_bytes << (offset * CHUNK_BITS); @@ -426,15 +454,15 @@ impl MemAlignSM { (value_first_read & !mask) | value_to_write }; - result = MemAlignResponse { + MemAlignResponse { more_address: true, step, value: Some(value_first_write), - }; - }, + } + } // Otherwise, do the RWVRW 1 => { - assert!(mem_values.len() == 2); // TODO: Debug mode + assert!(mem_values.len() == 2); // TODO: Debug mode // Unaligned memory op information thrown into the bus let step = input.step; @@ -456,7 +484,7 @@ impl MemAlignSM { // Get the first width 
bytes of the unaligned value let value_to_write = value & width_bytes; - + // Write zeroes to value_read from offset to offset + width let mask = width_bytes << (offset * CHUNK_BITS); @@ -473,7 +501,7 @@ impl MemAlignSM { // Get the first width bytes of the unaligned value let value_to_write = value & width_bytes; - + // Write zeroes to value_read from offset to offset + width let mask = width_bytes << (offset * CHUNK_BITS); @@ -482,7 +510,8 @@ impl MemAlignSM { }; // Get the next pc - let next_pc = MemAlignRomSM::::calculate_next_pc(MemOp::TwoWrites, offset, width); + let next_pc = + MemAlignRomSM::::calculate_next_pc(MemOp::TwoWrites, offset, width); // RWVWR let mut first_read_row = MemAlignRow:: { @@ -549,176 +578,129 @@ impl MemAlignSM { let pos = i as u64; first_read_row.reg[i] = { - F::from_canonical_u64(value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + F::from_canonical_u64( + value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; first_read_row.sel[i] = F::from_bool(i < offset); first_write_row.reg[i] = { - F::from_canonical_u64(value_first_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + F::from_canonical_u64( + value_first_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; first_write_row.sel[i] = F::from_bool(i >= offset); value_row.reg[i] = { - F::from_canonical_u64(value & (CHUNK_BITS_MASK << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64))) + F::from_canonical_u64( + value + & (CHUNK_BITS_MASK + << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), + ) }; value_row.sel[i] = F::from_bool(i == offset); second_write_row.reg[i] = { - F::from_canonical_u64(value_second_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + F::from_canonical_u64( + value_second_write + & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; second_write_row.sel[i] = F::from_bool(pos < shift); second_read_row.reg[i] = { - F::from_canonical_u64(value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64))) + 
F::from_canonical_u64( + value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), + ) }; second_read_row.sel[i] = F::from_bool(pos >= shift); + } - // Store the range check - // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + // Store the range check + // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; - result = MemAlignResponse { - more_address: false, - step, - value: Some(value_second_write), - }; + // Prove the generated rows + self.prove(&[read_row, value_row]); + + MemAlignResponse { + more_address: false, + step, + value: Some(value_second_write), } - }, + } _ => panic!("Invalid phase={}", phase), } - }, - } - - if let Ok(mut inputs) = self.inputs.lock() { - inputs.push((unaligned_access.clone(), aligned_accesses.to_vec())); - - let pctx = self.wcm.get_pctx(); - let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - - // TODO: Fix this, I am assuming the wc - while inputs.len() * 5 >= air_mem_align.num_rows() { - let num_drained = std::cmp::min(air_mem_align.num_rows(), inputs.len()); - let drained_inputs = inputs.drain(..num_drained).collect::>(); - - self.prove_internal(&drained_inputs); } } - - // { - // is_sufficient, - // if write => (value, step), // the step should be the same for the two words! 
(the write is step + 1) - // } } - pub fn prove( - &self, - unaligned_access: &ZiskRequiredMemory, - aligned_accesses: &[ZiskRequiredMemory], - ) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.push((unaligned_access.clone(), aligned_accesses.to_vec())); + pub fn prove(&self, computed_rows: &[MemAlignRow]) { + if let Ok(mut rows) = self.rows.lock() { + rows.extend_from_slice(computed_rows); let pctx = self.wcm.get_pctx(); let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - // TODO: Fix this, I am assuming the wc - while inputs.len() * 5 >= air_mem_align.num_rows() { - let num_drained = std::cmp::min(air_mem_align.num_rows(), inputs.len()); - let drained_inputs = inputs.drain(..num_drained).collect::>(); + while rows.len() >= air_mem_align.num_rows() { + let num_drained = std::cmp::min(air_mem_align.num_rows(), rows.len()); + let drained_rows = rows.drain(..num_drained).collect::>(); - self.prove_internal(&drained_inputs); + self.fill_new_air_instance(&drained_rows); } } } - fn prove_internal( - &self, - inputs: &[(ZiskRequiredMemory, Vec)], - ) { - let mem_align_rom_sm = self.mem_align_rom_sm.clone(); + fn fill_new_air_instance(&self, rows: &[MemAlignRow]) { + // Get the proof context let wcm = self.wcm.clone(); - let std = self.std.clone(); - let sctx = self.wcm.get_sctx().clone(); - - let (mut prover_buffer, offset) = create_prover_buffer( - &wcm.get_ectx(), - &wcm.get_sctx(), - ZISK_AIRGROUP_ID, - MEM_ALIGN_AIR_IDS[0], - ); - - Self::prove_instance( - &wcm, - &mem_align_rom_sm, - &std, - inputs, - &mut prover_buffer, - offset, - ); - - let air_instance = - AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0], None, prover_buffer); - wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, None); - } - - fn prove_instance( - wcm: &WitnessManager, - mem_align_rom_sm: &MemAlignRomSM, - std: &Std, - inputs: &[(ZiskRequiredMemory, Vec)], - prover_buffer: &mut [F], - offset: u64, - ) { - let input_len = 
inputs.len(); - let pctx = wcm.get_pctx(); + // Get the Mem Align AIR let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); - assert!(input_len <= air_mem_align.num_rows()); + let air_mem_align_rows = air_mem_align.num_rows(); + let rows_len = rows.len(); - info!( - "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", - Self::MY_NAME, - input_len, - air_mem_align.num_rows(), - input_len as f64 / air_mem_align.num_rows() as f64 * 100.0 - ); + // You cannot feed to the AIR more rows than it has + assert!(rows_len <= air_mem_align_rows); + + // Get the execution and setup context + let ectx = self.wcm.get_ectx(); + let sctx = self.wcm.get_sctx(); - let mut reg_range_check: HashMap = HashMap::new(); + // Create a prover buffer + let (mut prover_buffer, offset) = + create_prover_buffer(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); + + // Create a Mem Align trace buffer let mut trace_buffer = MemAlignTrace::::map_buffer( - prover_buffer, - air_mem_align.num_rows(), + &mut prover_buffer, + air_mem_align_rows, offset as usize, ) .unwrap(); - // Process the inputs while saving the values to be range checked - let mut rows_processed = 0; - for (unaligned_input, aligned_inputs) in inputs.iter() { - let rows = Self::process_input( - unaligned_input, - aligned_inputs, - mem_align_rom_sm, - &mut reg_range_check, - ); - for (j, &row) in rows.iter().enumerate() { - trace_buffer[rows_processed + j] = row; - } - rows_processed += rows.len(); + // Add the input rows to the trace + for (i, &row) in rows.iter().enumerate() { + trace_buffer[i] = row; } - // Pad the remaining rows with trivailly satisfying rows + // Pad the remaining rows with trivially satisfying rows let padding_row = MemAlignRow::::default(); - - for i in rows_processed..air_mem_align.num_rows() { + for i in rows_len..air_mem_align_rows { trace_buffer[i] = padding_row; } + // TODO: Treat the range check here of both standard and padding rows!! 
+ + // TODO: Treate the ROM multiplicity + // TODO: Store the padding multiplicity - let _padding_size = air_mem_align.num_rows() - rows_processed; + // let mem_align_rom_sm = self.mem_align_rom_sm.clone(); + // let _padding_size = air_mem_align.num_rows() - rows_processed; // for i in 0..8 { // let multiplicity = padding_size as u64; // let row = MemAlignRomSM::::calculate_rom_row( @@ -727,397 +709,406 @@ impl MemAlignSM { // rom_multiplicity[row as usize] += multiplicity; // } - // Perform the range checks - let range_id = std.get_range(BigInt::from(0), BigInt::from((1 << CHUNK_BITS) - 1), None); - for (&value, &multiplicity) in reg_range_check.iter() { - std.range_check(value, F::from_canonical_u64(multiplicity), range_id); - } - - // std::thread::spawn(move || { - // drop(inputs); - // drop(reg_range_check); - // }); - } - - #[inline(always)] - pub fn process_input( - unaligned_input: &ZiskRequiredMemory, - aligned_inputs: &[ZiskRequiredMemory], - mem_align_rom_sm: &MemAlignRomSM, - range_check: &mut HashMap, - ) -> Vec> { - // Get the unaligned address - let addr = unaligned_input.address; - - // Get the unaligned value - let value = unaligned_input.value.to_le_bytes(); - - // Get the unaligned step - let step = unaligned_input.step; - - // Get the unaligned width - let width = unaligned_input.width; - let width = if width <= CHUNK_NUM_U64 { - width as usize - } else { - panic!("Invalid width={}", width); - }; - - // Compute the offset - let offset = addr % CHUNK_NUM_U64; - let offset = if offset <= usize::MAX as u64 { - offset as usize - } else { - panic!("Invalid offset={}", offset); - }; - - // Compute the shift - let shift = (offset + width) % CHUNK_NUM; - - // Get the op to be executed, its size and the pc to jump to - let op = Self::get_mem_op(&unaligned_input); - let op_size = MemAlignRomSM::::get_mem_align_op_size(op); - let next_pc = MemAlignRomSM::::calculate_next_pc(op, offset, width); - - println!("OP: {:?}", op); - println!("UNALIGNED INPUT:\n 
{:?}", unaligned_input); - println!(" OFFSET: {:?}", offset); - println!(" value: {:?}", unaligned_input.value.to_le_bytes()); - println!("ALIGNED INPUTS:"); - for aligned_input in aligned_inputs { - println!(" {:?}", aligned_input); - println!(" value: {:?}", aligned_input.value.to_le_bytes()); - } - println!(""); - - // Initialize and set the rows of the corresponding op - let mut rows: Vec> = Vec::with_capacity(op_size); - // TODO: Can I detatch the "shape" of the program from the mem_align and do it in the mem_align_rom? - match op { - MemOp::OneRead => { - // RV - // Sanity check - assert!(aligned_inputs.len() == 1); - - // Get the aligned address - let addr_read = aligned_inputs[0].address; - - // Get the aligned values - let value_read = aligned_inputs[0].value.to_le_bytes(); - - // Get the aligned step - let step_read = aligned_inputs[0].step; - - let mut read_row = MemAlignRow:: { - step: F::from_canonical_u64(step_read), - addr: F::from_canonical_u64(addr_read), - // offset: F::from_canonical_u64(0), - // wr: F::from_bool(false), - // pc: F::from_canonical_u64(0), - reset: F::from_bool(true), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut value_row = MemAlignRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), - offset: F::from_canonical_usize(offset), - width: F::from_canonical_usize(width), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc), - // reset: F::from_bool(false), - sel_prove: F::from_bool(true), - ..Default::default() - }; - - for i in 0..CHUNK_NUM { - read_row.reg[i] = F::from_canonical_u8(value_read[i]); - read_row.sel[i] = F::from_bool(true); - - value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); - value_row.sel[i] = F::from_bool(i == offset); - - // Store the range check - *range_check.entry(read_row.reg[i]).or_insert(0) += 1; - *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - } - - // Store the rows - rows.push(read_row); - 
rows.push(value_row); - } - MemOp::OneWrite => { - // RWV - // Sanity check - assert!(aligned_inputs.len() == 2); - - // Get the aligned address - let addr_read_write = aligned_inputs[0].address; - - // Get the aligned values - let value_read = aligned_inputs[0].value.to_le_bytes(); - let value_write = aligned_inputs[1].value.to_le_bytes(); - - // Get the aligned step - let step_read = aligned_inputs[0].step; - let step_write = aligned_inputs[1].step; - - // RWV - let mut read_row = MemAlignRow:: { - step: F::from_canonical_u64(step_read), - addr: F::from_canonical_u64(addr_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), - // wr: F::from_bool(false), - // pc: F::from_canonical_u64(0), - reset: F::from_bool(true), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut write_row = MemAlignRow:: { - step: F::from_canonical_u64(step_write), - addr: F::from_canonical_u64(addr_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), - wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc), - // reset: F::from_bool(false), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut value_row = MemAlignRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), - offset: F::from_canonical_usize(offset), - width: F::from_canonical_usize(width), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 1), - // reset: F::from_bool(false), - sel_prove: F::from_bool(true), - ..Default::default() - }; - - for i in 0..CHUNK_NUM { - read_row.reg[i] = F::from_canonical_u8(value_read[i]); - read_row.sel[i] = F::from_bool(i >= width); - - write_row.reg[i] = F::from_canonical_u8(value_write[i]); - write_row.sel[i] = F::from_bool(i < width); - - value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); - value_row.sel[i] = F::from_bool(i == offset); - - // Store the range check - 
*range_check.entry(read_row.reg[i]).or_insert(0) += 1; - *range_check.entry(write_row.reg[i]).or_insert(0) += 1; - *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - } - - // Store the rows - rows.push(read_row); - rows.push(write_row); - rows.push(value_row); - } - MemOp::TwoReads => { - // RVR - // Sanity check - assert!(aligned_inputs.len() == 2); - - // Get the aligned address - let addr_first_read = aligned_inputs[0].address; - let addr_second_read = aligned_inputs[1].address; - - // Get the aligned values - let value_first_read = aligned_inputs[0].value.to_le_bytes(); - let value_second_read = aligned_inputs[1].value.to_le_bytes(); - - // Get the aligned step - let step_first_read = aligned_inputs[0].step; - let step_second_read = aligned_inputs[1].step; - - // RVR - let mut first_read_row = MemAlignRow:: { - step: F::from_canonical_u64(step_first_read), - addr: F::from_canonical_u64(addr_first_read), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), - // wr: F::from_bool(false), - // pc: F::from_canonical_u64(0), - reset: F::from_bool(true), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut value_row = MemAlignRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), - offset: F::from_canonical_usize(offset), - width: F::from_canonical_usize(width), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc), - // reset: F::from_bool(false), - sel_prove: F::from_bool(true), - ..Default::default() - }; - - let mut second_read_row = MemAlignRow:: { - step: F::from_canonical_u64(step_second_read), - addr: F::from_canonical_u64(addr_second_read), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 1), - // reset: F::from_bool(false), - sel_down_to_up: F::from_bool(true), - ..Default::default() - }; - - for i in 0..CHUNK_NUM { - first_read_row.reg[i] = 
F::from_canonical_u8(value_first_read[i]); - first_read_row.sel[i] = F::from_bool(true); - - value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); - value_row.sel[i] = F::from_bool(i == offset); - - second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); - second_read_row.sel[i] = F::from_bool(true); - - // Store the range check - *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; - } - - // Store the rows - rows.push(first_read_row); - rows.push(value_row); - rows.push(second_read_row); - } - MemOp::TwoWrites => { - // RWVWR - // Sanity check - assert!(aligned_inputs.len() == 4); - - // Get the aligned address - let addr_first_read_write = aligned_inputs[0].address; - let addr_second_read_write = aligned_inputs[2].address; - - // Get the aligned values - // TODO: I do not need to establish an order, I can use the field is_write!!! 
- let value_first_read = aligned_inputs[0].value.to_le_bytes(); - let value_first_write = aligned_inputs[1].value.to_le_bytes(); - let value_second_read = aligned_inputs[2].value.to_le_bytes(); - let value_second_write = aligned_inputs[3].value.to_le_bytes(); - - // Get the aligned step - let step_first_read = aligned_inputs[0].step; - let step_first_write = aligned_inputs[1].step; - let step_second_read = aligned_inputs[2].step; - let step_second_write = aligned_inputs[3].step; - - // RWVWR - let mut first_read_row = MemAlignRow:: { - step: F::from_canonical_u64(step_first_read), - addr: F::from_canonical_u64(addr_first_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), - // wr: F::from_bool(false), - // pc: F::from_canonical_u64(0), - reset: F::from_bool(true), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut first_write_row = MemAlignRow:: { - step: F::from_canonical_u64(step_first_write), - addr: F::from_canonical_u64(addr_first_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), - wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc), - // reset: F::from_bool(false), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut value_row = MemAlignRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), - offset: F::from_canonical_usize(offset), - width: F::from_canonical_usize(width), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 1), - // reset: F::from_bool(false), - sel_prove: F::from_bool(true), - ..Default::default() - }; - - let mut second_write_row = MemAlignRow:: { - step: F::from_canonical_u64(step_second_write), - addr: F::from_canonical_u64(addr_second_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), - wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc + 2), - // reset: F::from_bool(false), - 
sel_down_to_up: F::from_bool(true), - ..Default::default() - }; - - let mut second_read_row = MemAlignRow:: { - step: F::from_canonical_u64(step_second_read), - addr: F::from_canonical_u64(addr_second_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 3), - reset: F::from_bool(false), - sel_down_to_up: F::from_bool(true), - ..Default::default() - }; - - for i in 0..CHUNK_NUM { - first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); - first_read_row.sel[i] = F::from_bool(i < offset); - - first_write_row.reg[i] = F::from_canonical_u8(value_first_write[i]); - first_write_row.sel[i] = F::from_bool(i >= offset); - - value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); - value_row.sel[i] = F::from_bool(i == offset); - - second_write_row.reg[i] = F::from_canonical_u8(value_second_write[i]); - second_write_row.sel[i] = F::from_bool(i < shift); - - second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); - second_read_row.sel[i] = F::from_bool(i >= shift); - - // Store the range check - *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; - *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; - *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; - } - - // Store the rows - rows.push(first_read_row); - rows.push(first_write_row); - rows.push(value_row); - rows.push(second_write_row); - rows.push(second_read_row); - } - } + // TODO: Perform the range checks + // let std = self.std.clone(); + // let range_id = std.get_range(BigInt::from(0), BigInt::from((1 << CHUNK_BITS) - 1), None); + // for (&value, &multiplicity) in reg_range_check.iter() { + // std.range_check(value, F::from_canonical_u64(multiplicity), range_id); + // } - // Update the ROM row multiplicity - 
mem_align_rom_sm.update_multiplicity_by_input(op, offset, width); + info!( + "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", + Self::MY_NAME, + rows_len, + air_mem_align.num_rows(), + rows_len as f64 / air_mem_align.num_rows() as f64 * 100.0 + ); - // Return successfully - rows + // Add a new Mem Align instance + let air_instance = + AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0], None, prover_buffer); + wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, None); } + + // #[inline(always)] + // pub fn process_input( + // unaligned_input: &ZiskRequiredMemory, + // aligned_inputs: &[ZiskRequiredMemory], + // mem_align_rom_sm: &MemAlignRomSM, + // range_check: &mut HashMap, + // ) -> Vec> { + // // Get the unaligned address + // let addr = unaligned_input.address; + + // // Get the unaligned value + // let value = unaligned_input.value.to_le_bytes(); + + // // Get the unaligned step + // let step = unaligned_input.step; + + // // Get the unaligned width + // let width = unaligned_input.width; + // let width = if width <= CHUNK_NUM_U64 { + // width as usize + // } else { + // panic!("Invalid width={}", width); + // }; + + // // Compute the offset + // let offset = addr % CHUNK_NUM_U64; + // let offset = if offset <= usize::MAX as u64 { + // offset as usize + // } else { + // panic!("Invalid offset={}", offset); + // }; + + // // Compute the shift + // let shift = (offset + width) % CHUNK_NUM; + + // // Get the op to be executed, its size and the pc to jump to + // let op = Self::get_mem_op(&unaligned_input); + // let op_size = MemAlignRomSM::::get_mem_align_op_size(op); + // let next_pc = MemAlignRomSM::::calculate_next_pc(op, offset, width); + + // println!("OP: {:?}", op); + // println!("UNALIGNED INPUT:\n {:?}", unaligned_input); + // println!(" OFFSET: {:?}", offset); + // println!(" value: {:?}", unaligned_input.value.to_le_bytes()); + // println!("ALIGNED INPUTS:"); + // for aligned_input in aligned_inputs { + // 
println!(" {:?}", aligned_input); + // println!(" value: {:?}", aligned_input.value.to_le_bytes()); + // } + // println!(""); + + // // Initialize and set the rows of the corresponding op + // let mut rows: Vec> = Vec::with_capacity(op_size); + // // TODO: Can I detatch the "shape" of the program from the mem_align and do it in the mem_align_rom? + // match op { + // MemOp::OneRead => { + // // RV + // // Sanity check + // assert!(aligned_inputs.len() == 1); + + // // Get the aligned address + // let addr_read = aligned_inputs[0].address; + + // // Get the aligned values + // let value_read = aligned_inputs[0].value.to_le_bytes(); + + // // Get the aligned step + // let step_read = aligned_inputs[0].step; + + // let mut read_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_read), + // addr: F::from_canonical_u64(addr_read), + // // offset: F::from_canonical_u64(0), + // // wr: F::from_bool(false), + // // pc: F::from_canonical_u64(0), + // reset: F::from_bool(true), + // sel_up_to_down: F::from_bool(true), + // ..Default::default() + // }; + + // let mut value_row = MemAlignRow:: { + // step: F::from_canonical_u64(step), + // addr: F::from_canonical_u64(addr), + // offset: F::from_canonical_usize(offset), + // width: F::from_canonical_usize(width), + // // wr: F::from_bool(false), + // pc: F::from_canonical_u64(next_pc), + // // reset: F::from_bool(false), + // sel_prove: F::from_bool(true), + // ..Default::default() + // }; + + // for i in 0..CHUNK_NUM { + // read_row.reg[i] = F::from_canonical_u8(value_read[i]); + // read_row.sel[i] = F::from_bool(true); + + // value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); + // value_row.sel[i] = F::from_bool(i == offset); + + // // Store the range check + // *range_check.entry(read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + // } + + // // Store the rows + // rows.push(read_row); + // rows.push(value_row); + // } + // MemOp::OneWrite => { + 
// // RWV + // // Sanity check + // assert!(aligned_inputs.len() == 2); + + // // Get the aligned address + // let addr_read_write = aligned_inputs[0].address; + + // // Get the aligned values + // let value_read = aligned_inputs[0].value.to_le_bytes(); + // let value_write = aligned_inputs[1].value.to_le_bytes(); + + // // Get the aligned step + // let step_read = aligned_inputs[0].step; + // let step_write = aligned_inputs[1].step; + + // // RWV + // let mut read_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_read), + // addr: F::from_canonical_u64(addr_read_write), + // // offset: F::from_canonical_u64(0), + // width: F::from_canonical_u64(CHUNK_NUM_U64), + // // wr: F::from_bool(false), + // // pc: F::from_canonical_u64(0), + // reset: F::from_bool(true), + // sel_up_to_down: F::from_bool(true), + // ..Default::default() + // }; + + // let mut write_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_write), + // addr: F::from_canonical_u64(addr_read_write), + // // offset: F::from_canonical_u64(0), + // width: F::from_canonical_u64(CHUNK_NUM_U64), + // wr: F::from_bool(true), + // pc: F::from_canonical_u64(next_pc), + // // reset: F::from_bool(false), + // sel_up_to_down: F::from_bool(true), + // ..Default::default() + // }; + + // let mut value_row = MemAlignRow:: { + // step: F::from_canonical_u64(step), + // addr: F::from_canonical_u64(addr), + // offset: F::from_canonical_usize(offset), + // width: F::from_canonical_usize(width), + // // wr: F::from_bool(false), + // pc: F::from_canonical_u64(next_pc + 1), + // // reset: F::from_bool(false), + // sel_prove: F::from_bool(true), + // ..Default::default() + // }; + + // for i in 0..CHUNK_NUM { + // read_row.reg[i] = F::from_canonical_u8(value_read[i]); + // read_row.sel[i] = F::from_bool(i >= width); + + // write_row.reg[i] = F::from_canonical_u8(value_write[i]); + // write_row.sel[i] = F::from_bool(i < width); + + // value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % 
CHUNK_NUM]); + // value_row.sel[i] = F::from_bool(i == offset); + + // // Store the range check + // *range_check.entry(read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(write_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + // } + + // // Store the rows + // rows.push(read_row); + // rows.push(write_row); + // rows.push(value_row); + // } + // MemOp::TwoReads => { + // // RVR + // // Sanity check + // assert!(aligned_inputs.len() == 2); + + // // Get the aligned address + // let addr_first_read = aligned_inputs[0].address; + // let addr_second_read = aligned_inputs[1].address; + + // // Get the aligned values + // let value_first_read = aligned_inputs[0].value.to_le_bytes(); + // let value_second_read = aligned_inputs[1].value.to_le_bytes(); + + // // Get the aligned step + // let step_first_read = aligned_inputs[0].step; + // let step_second_read = aligned_inputs[1].step; + + // // RVR + // let mut first_read_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_first_read), + // addr: F::from_canonical_u64(addr_first_read), + // // offset: F::from_canonical_u64(0), + // width: F::from_canonical_u64(CHUNK_NUM_U64), + // // wr: F::from_bool(false), + // // pc: F::from_canonical_u64(0), + // reset: F::from_bool(true), + // sel_up_to_down: F::from_bool(true), + // ..Default::default() + // }; + + // let mut value_row = MemAlignRow:: { + // step: F::from_canonical_u64(step), + // addr: F::from_canonical_u64(addr), + // offset: F::from_canonical_usize(offset), + // width: F::from_canonical_usize(width), + // // wr: F::from_bool(false), + // pc: F::from_canonical_u64(next_pc), + // // reset: F::from_bool(false), + // sel_prove: F::from_bool(true), + // ..Default::default() + // }; + + // let mut second_read_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_second_read), + // addr: F::from_canonical_u64(addr_second_read), + // // offset: F::from_canonical_u64(0), + // width: 
F::from_canonical_u64(CHUNK_NUM_U64), + // // wr: F::from_bool(false), + // pc: F::from_canonical_u64(next_pc + 1), + // // reset: F::from_bool(false), + // sel_down_to_up: F::from_bool(true), + // ..Default::default() + // }; + + // for i in 0..CHUNK_NUM { + // first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); + // first_read_row.sel[i] = F::from_bool(true); + + // value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); + // value_row.sel[i] = F::from_bool(i == offset); + + // second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); + // second_read_row.sel[i] = F::from_bool(true); + + // // Store the range check + // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + // } + + // // Store the rows + // rows.push(first_read_row); + // rows.push(value_row); + // rows.push(second_read_row); + // } + // MemOp::TwoWrites => { + // // RWVWR + // // Sanity check + // assert!(aligned_inputs.len() == 4); + + // // Get the aligned address + // let addr_first_read_write = aligned_inputs[0].address; + // let addr_second_read_write = aligned_inputs[2].address; + + // // Get the aligned values + // // TODO: I do not need to establish an order, I can use the field is_write!!! 
+ // let value_first_read = aligned_inputs[0].value.to_le_bytes(); + // let value_first_write = aligned_inputs[1].value.to_le_bytes(); + // let value_second_read = aligned_inputs[2].value.to_le_bytes(); + // let value_second_write = aligned_inputs[3].value.to_le_bytes(); + + // // Get the aligned step + // let step_first_read = aligned_inputs[0].step; + // let step_first_write = aligned_inputs[1].step; + // let step_second_read = aligned_inputs[2].step; + // let step_second_write = aligned_inputs[3].step; + + // // RWVWR + // let mut first_read_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_first_read), + // addr: F::from_canonical_u64(addr_first_read_write), + // // offset: F::from_canonical_u64(0), + // width: F::from_canonical_u64(CHUNK_NUM_U64), + // // wr: F::from_bool(false), + // // pc: F::from_canonical_u64(0), + // reset: F::from_bool(true), + // sel_up_to_down: F::from_bool(true), + // ..Default::default() + // }; + + // let mut first_write_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_first_write), + // addr: F::from_canonical_u64(addr_first_read_write), + // // offset: F::from_canonical_u64(0), + // width: F::from_canonical_u64(CHUNK_NUM_U64), + // wr: F::from_bool(true), + // pc: F::from_canonical_u64(next_pc), + // // reset: F::from_bool(false), + // sel_up_to_down: F::from_bool(true), + // ..Default::default() + // }; + + // let mut value_row = MemAlignRow:: { + // step: F::from_canonical_u64(step), + // addr: F::from_canonical_u64(addr), + // offset: F::from_canonical_usize(offset), + // width: F::from_canonical_usize(width), + // // wr: F::from_bool(false), + // pc: F::from_canonical_u64(next_pc + 1), + // // reset: F::from_bool(false), + // sel_prove: F::from_bool(true), + // ..Default::default() + // }; + + // let mut second_write_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_second_write), + // addr: F::from_canonical_u64(addr_second_read_write), + // // offset: F::from_canonical_u64(0), + // width: 
F::from_canonical_u64(CHUNK_NUM_U64), + // wr: F::from_bool(true), + // pc: F::from_canonical_u64(next_pc + 2), + // // reset: F::from_bool(false), + // sel_down_to_up: F::from_bool(true), + // ..Default::default() + // }; + + // let mut second_read_row = MemAlignRow:: { + // step: F::from_canonical_u64(step_second_read), + // addr: F::from_canonical_u64(addr_second_read_write), + // // offset: F::from_canonical_u64(0), + // width: F::from_canonical_u64(CHUNK_NUM_U64), + // // wr: F::from_bool(false), + // pc: F::from_canonical_u64(next_pc + 3), + // reset: F::from_bool(false), + // sel_down_to_up: F::from_bool(true), + // ..Default::default() + // }; + + // for i in 0..CHUNK_NUM { + // first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); + // first_read_row.sel[i] = F::from_bool(i < offset); + + // first_write_row.reg[i] = F::from_canonical_u8(value_first_write[i]); + // first_write_row.sel[i] = F::from_bool(i >= offset); + + // value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); + // value_row.sel[i] = F::from_bool(i == offset); + + // second_write_row.reg[i] = F::from_canonical_u8(value_second_write[i]); + // second_write_row.sel[i] = F::from_bool(i < shift); + + // second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); + // second_read_row.sel[i] = F::from_bool(i >= shift); + + // // Store the range check + // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; + // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + // } + + // // Store the rows + // rows.push(first_read_row); + // rows.push(first_write_row); + // rows.push(value_row); + // rows.push(second_write_row); + // rows.push(second_read_row); + // } + // } + + // // Update the ROM row multiplicity + // 
mem_align_rom_sm.update_multiplicity_by_input(op, offset, width); + + // // Return successfully + // rows + // } } impl WitnessComponent for MemAlignSM {} From d67b280e8e99ffe5fa8bfb89b58c9a77eadacd37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Fri, 8 Nov 2024 16:42:05 +0000 Subject: [PATCH 22/44] fixes --- state-machines/mem/src/mem_align_sm.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index fb243e06..34cc6eae 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -313,7 +313,7 @@ impl MemAlignSM { } // Prove the generated rows - self.prove(&[read_row, value_row]); + self.prove(&[read_row, write_row, value_row]); MemAlignResponse { more_address: false, step, value: Some(value_write) } } @@ -417,7 +417,7 @@ impl MemAlignSM { } // Prove the generated rows - self.prove(&[read_row, value_row]); + self.prove(&[first_read_row, value_row, second_read_row]); MemAlignResponse { more_address: false, step, value: None } } @@ -624,7 +624,13 @@ impl MemAlignSM { // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; // Prove the generated rows - self.prove(&[read_row, value_row]); + self.prove(&[ + first_read_row, + first_write_row, + value_row, + second_write_row, + second_read_row, + ]); MemAlignResponse { more_address: false, From 451cdfb16f7e70e3e9ad2ba93b47d6e5f5caf475 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Fri, 8 Nov 2024 17:10:37 +0000 Subject: [PATCH 23/44] WIP mem proxy, rebase --- emulator/src/emu.rs | 141 +++---- emulator/src/emulator.rs | 23 +- state-machines/mem/src/mem_align_sm.rs | 14 +- state-machines/mem/src/mem_proxy.rs | 494 +++++++++++-------------- state-machines/mem/src/mem_sm.rs | 15 +- 5 files changed, 298 insertions(+), 389 deletions(-) diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index cc24f555..3e3bfc2e 100644 --- a/emulator/src/emu.rs +++ 
b/emulator/src/emu.rs @@ -98,7 +98,6 @@ impl<'a> Emu<'a> { &mut self, instruction: &ZiskInst, emu_mem: &mut Vec, - is_aligned: bool, ) { match instruction.a_src { SRC_C => self.ctx.inst_ctx.a = self.ctx.inst_ctx.c, @@ -109,16 +108,14 @@ impl<'a> Emu<'a> { } self.ctx.inst_ctx.a = self.ctx.inst_ctx.mem.read(addr, 8); - if is_aligned == Self::is_8_aligned(addr, 8) { - let required_memory = ZiskRequiredMemory { - step: self.ctx.inst_ctx.step, - is_write: false, - address: addr, - width: 8, - value: self.ctx.inst_ctx.a, - }; - emu_mem.push(required_memory); - } + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + is_write: false, + address: addr, + width: 8, + value: self.ctx.inst_ctx.a, + }; + emu_mem.push(required_memory); } SRC_IMM => { self.ctx.inst_ctx.a = instruction.a_offset_imm0 | (instruction.a_use_sp_imm1 << 32) @@ -175,7 +172,6 @@ impl<'a> Emu<'a> { &mut self, instruction: &ZiskInst, emu_mem: &mut Vec, - is_aligned: bool, ) { match instruction.b_src { SRC_C => self.ctx.inst_ctx.b = self.ctx.inst_ctx.c, @@ -186,16 +182,14 @@ impl<'a> Emu<'a> { } self.ctx.inst_ctx.b = self.ctx.inst_ctx.mem.read(addr, 8); - if is_aligned == Self::is_8_aligned(addr, 8) { - let required_memory = ZiskRequiredMemory { - step: self.ctx.inst_ctx.step, - is_write: false, - address: addr, - width: 8, - value: self.ctx.inst_ctx.b, - }; - emu_mem.push(required_memory); - } + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + is_write: false, + address: addr, + width: 8, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); } SRC_IMM => { self.ctx.inst_ctx.b = instruction.b_offset_imm0 | (instruction.b_use_sp_imm1 << 32) @@ -207,16 +201,14 @@ impl<'a> Emu<'a> { addr += self.ctx.inst_ctx.sp; } self.ctx.inst_ctx.b = self.ctx.inst_ctx.mem.read(addr, instruction.ind_width); - if is_aligned == Self::is_8_aligned(addr, instruction.ind_width) { - let required_memory = ZiskRequiredMemory { - step: self.ctx.inst_ctx.step, - is_write: 
false, - address: addr, - width: instruction.ind_width, - value: self.ctx.inst_ctx.b, - }; - emu_mem.push(required_memory); - } + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + is_write: false, + address: addr, + width: instruction.ind_width, + value: self.ctx.inst_ctx.b, + }; + emu_mem.push(required_memory); } _ => panic!( "Emu::source_b() Invalid b_src={} pc={}", @@ -274,7 +266,6 @@ impl<'a> Emu<'a> { &mut self, instruction: &ZiskInst, emu_mem: &mut Vec, - is_aligned: bool, ) { match instruction.store { STORE_NONE => {} @@ -290,16 +281,14 @@ impl<'a> Emu<'a> { } self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, 8); - if is_aligned == Self::is_8_aligned(addr as u64, 8) { - let required_memory = ZiskRequiredMemory { - step: self.ctx.inst_ctx.step, - is_write: true, - address: addr as u64, - width: 8, - value: val as u64, - }; - emu_mem.push(required_memory); - } + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + is_write: true, + address: addr as u64, + width: 8, + value: val as u64, + }; + emu_mem.push(required_memory); } STORE_IND => { let val: i64 = if instruction.store_ra { @@ -314,16 +303,14 @@ impl<'a> Emu<'a> { addr += self.ctx.inst_ctx.a as i64; self.ctx.inst_ctx.mem.write_silent(addr as u64, val as u64, instruction.ind_width); - if is_aligned == Self::is_8_aligned(addr as u64, instruction.ind_width) { - let required_memory = ZiskRequiredMemory { - step: self.ctx.inst_ctx.step, - is_write: true, - address: addr as u64, - width: instruction.ind_width, - value: val as u64, - }; - emu_mem.push(required_memory); - } + let required_memory = ZiskRequiredMemory { + step: self.ctx.inst_ctx.step, + is_write: true, + address: addr as u64, + width: instruction.ind_width, + value: val as u64, + }; + emu_mem.push(required_memory); } _ => panic!( "Emu::store_c() Invalid store={} pc={}", @@ -501,9 +488,9 @@ impl<'a> Emu<'a> { } // Log emulation step, if requested - if options.print_step.is_some() && - 
(options.print_step.unwrap() != 0) && - ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) + if options.print_step.is_some() + && (options.print_step.unwrap() != 0) + && ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) { println!("step={}", self.ctx.inst_ctx.step); } @@ -615,11 +602,7 @@ impl<'a> Emu<'a> { (emu_traces, emu_segments) } - pub fn par_run_memory( - &mut self, - inputs: Vec, - is_aligned: bool, - ) -> Vec { + pub fn par_run_memory(&mut self, inputs: Vec) -> Vec { // Context, where the state of the execution is stored and modified at every execution step self.ctx = self.create_emu_context(inputs); @@ -629,7 +612,7 @@ impl<'a> Emu<'a> { let mut emu_mem = Vec::new(); while !self.ctx.inst_ctx.end { - self.par_step_memory::(&mut emu_mem, is_aligned); + self.par_step_memory::(&mut emu_mem); } emu_mem @@ -706,9 +689,9 @@ impl<'a> Emu<'a> { // Increment step counter self.ctx.inst_ctx.step += 1; - if self.ctx.inst_ctx.end || - ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) == - self.ctx.callback_steps) + if self.ctx.inst_ctx.end + || ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) + == self.ctx.callback_steps) { // In run() we have checked the callback consistency with ctx.do_callback let callback = callback.as_ref().unwrap(); @@ -811,27 +794,23 @@ impl<'a> Emu<'a> { /// Performs one single step of the emulation #[inline(always)] #[allow(unused_variables)] - pub fn par_step_memory( - &mut self, - emu_mem: &mut Vec, - is_aligned: bool, - ) { + pub fn par_step_memory(&mut self, emu_mem: &mut Vec) { let last_pc = self.ctx.inst_ctx.pc; let last_c = self.ctx.inst_ctx.c; let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); // Build the 'a' register value based on the source specified by the current instruction - self.source_a_memory(instruction, emu_mem, is_aligned); + self.source_a_memory(instruction, emu_mem); // Build the 'b' register value based on the source specified by the current instruction - 
self.source_b_memory(instruction, emu_mem, is_aligned); + self.source_b_memory(instruction, emu_mem); // Call the operation (instruction.func)(&mut self.ctx.inst_ctx); // Store the 'c' register value based on the storage specified by the current instruction - self.store_c_memory(instruction, emu_mem, is_aligned); + self.store_c_memory(instruction, emu_mem); // Set SP, if specified by the current instruction // #[cfg(feature = "sp")] @@ -924,11 +903,11 @@ impl<'a> Emu<'a> { let mut current_box_id = 0; let mut current_step_idx = loop { - if current_box_id == vec_traces.len() - 1 || - vec_traces[current_box_id + 1].start_state.step >= emu_trace_start.step + if current_box_id == vec_traces.len() - 1 + || vec_traces[current_box_id + 1].start_state.step >= emu_trace_start.step { - break emu_trace_start.step as usize - - vec_traces[current_box_id].start_state.step as usize; + break emu_trace_start.step as usize + - vec_traces[current_box_id].start_state.step as usize; } current_box_id += 1; }; @@ -1039,8 +1018,8 @@ impl<'a> Emu<'a> { let b = [inst_ctx.b & 0xFFFFFFFF, (inst_ctx.b >> 32) & 0xFFFFFFFF]; let c = [inst_ctx.c & 0xFFFFFFFF, (inst_ctx.c >> 32) & 0xFFFFFFFF]; - let addr1 = (inst.b_offset_imm0 as i64 + - if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; + let addr1 = (inst.b_offset_imm0 as i64 + + if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; let jmp_offset1 = if inst.jmp_offset1 >= 0 { F::from_canonical_u64(inst.jmp_offset1 as u64) @@ -1118,8 +1097,8 @@ impl<'a> Emu<'a> { m32: F::from_bool(inst.m32), addr1: F::from_canonical_u64(addr1), __debug_operation_bus_enabled: F::from_bool( - inst.op_type == ZiskOperationType::Binary || - inst.op_type == ZiskOperationType::BinaryE, + inst.op_type == ZiskOperationType::Binary + || inst.op_type == ZiskOperationType::BinaryE, ), } } diff --git a/emulator/src/emulator.rs b/emulator/src/emulator.rs index eb862fe0..f85ec309 100644 --- a/emulator/src/emulator.rs +++ b/emulator/src/emulator.rs 
@@ -246,22 +246,15 @@ impl ZiskEmulator { pub fn par_process_rom_memory( rom: &ZiskRom, inputs: &[u8], - ) -> Result<[Vec; 2], ZiskEmulatorErr> { - let mut result: [Vec; 2] = [Vec::new(), Vec::new()]; - - result.par_iter_mut().enumerate().for_each(|(is_aligned, result)| { - let is_aligned = is_aligned == 0; - let mut emu = Emu::new(rom); - let required = emu.par_run_memory::(inputs.to_owned(), is_aligned); - - if !emu.terminated() { - panic!("Emulation did not complete"); - // TODO! - // return Err(ZiskEmulatorErr::EmulationNoCompleted); - } + ) -> Result, ZiskEmulatorErr> { + let mut emu = Emu::new(rom); + let result = emu.par_run_memory::(inputs.to_owned()); - *result = required; - }); + if !emu.terminated() { + panic!("Emulation did not complete"); + // TODO! + // return Err(ZiskEmulatorErr::EmulationNoCompleted); + } Ok(result) } diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 34cc6eae..312e2d02 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -34,6 +34,11 @@ pub struct MemAlignResponse { pub value: Option, } +pub struct MemAlignResponse { + pub more_address: bool, + pub step: u64, + pub mem_value: u64, +} pub struct MemAlignSM { // Witness computation manager wcm: Arc>, @@ -682,12 +687,9 @@ impl MemAlignSM { create_prover_buffer(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); // Create a Mem Align trace buffer - let mut trace_buffer = MemAlignTrace::::map_buffer( - &mut prover_buffer, - air_mem_align_rows, - offset as usize, - ) - .unwrap(); + let mut trace_buffer = + MemAlignTrace::::map_buffer(&mut prover_buffer, air_mem_align_rows, offset as usize) + .unwrap(); // Add the input rows to the trace for (i, &row) in rows.iter().enumerate() { diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 30d4689f..33c4ce9a 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,358 
+1,280 @@ +use std::collections::VecDeque; +use std::default; use std::sync::{ atomic::{AtomicU32, Ordering}, Arc, }; -use crate::{MemAlignRomSM, MemAlignSM, MemOp, MemSM}; +use crate::{MemAlignResponse, MemAlignRomSM, MemAlignSM, MemOp, MemSM}; use p3_field::PrimeField; use pil_std_lib::Std; +use proofman_common::StepsParams; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; -use zisk_core::{ZiskRequiredMemory, RAM_ADDR}; +use zisk_core::{elf2rom, ZiskRequiredMemory, ZiskRequiredOperation, RAM_ADDR}; use proofman::{WitnessComponent, WitnessManager}; +const MEM_ADDR_MASK: u64 = 0xFFFF_FFFF_FFFF_FFF8; +const MEM_BYTES: u64 = 8; + +const MAX_MEM_STEP_OFFSET: u64 = 2; +const MAX_MEM_OPS_PER_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 2; + +pub trait MemModule: Send + Sync { + fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]); + fn get_addr_ranges(&self) -> Vec<(u64, u64)>; + fn get_flush_input_size(&self) -> u64; + fn unregister_predecessor(&self); + fn register_predecessor(&self); +} + +struct MemModuleData { + pub inputs: Vec, + pub addr_ranges: Vec<(u64, u64)>, + pub flush_input_size: u64, +} + pub struct MemProxy { // Count of registered predecessors registered_predecessors: AtomicU32, // Secondary State machines - mem_sm: Arc>, + // mem_sm: Arc>, mem_align_sm: Arc>, + modules: Vec>>, + modules_data: Vec, +} + +pub struct MemOperation { + pub step: u64, + pub is_write: bool, + pub address: u64, + pub width: u64, + pub value: u64, +} + +pub struct MemAlignOperation { + pub address: u64, + pub mem_op: ZiskRequiredMemory, + pub mem_value: [u64; 2], } impl MemProxy { pub fn new(wcm: Arc>, std: Arc>) -> Arc { - let mem_sm = MemSM::new(wcm.clone()); let mem_align_rom_sm = MemAlignRomSM::new(wcm.clone()); let mem_align_sm = MemAlignSM::new(wcm.clone(), std, mem_align_rom_sm); + let mut modules: Vec>> = Vec::new(); + + modules.push(MemSM::new(wcm.clone()).clone()); + let mut modules_data: Vec = Vec::new(); + + for module in modules.iter_mut() { + 
modules_data.push(Self::init_module(module)); + } let mem_proxy = Self { registered_predecessors: AtomicU32::new(0), - mem_sm, mem_align_sm, + modules, + modules_data, }; let mem_proxy = Arc::new(mem_proxy); wcm.register_component(mem_proxy.clone(), None, None); // For all the secondary state machines, register the main state machine as a predecessor - mem_proxy.mem_sm.register_predecessor(); mem_proxy.mem_align_sm.register_predecessor(); - mem_proxy } - + pub fn main_step_to_mem_step(step: u64, step_offset: u8) -> u64 { + 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 + } pub fn register_predecessor(&self) { self.registered_predecessors.fetch_add(1, Ordering::SeqCst); } pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - self.mem_sm.unregister_predecessor(); + // self.mem_sm.unregister_predecessor(); self.mem_align_sm.unregister_predecessor(); } } + pub fn init_module(module: &Arc>) -> MemModuleData { + module.register_predecessor(); + let ranges = module.get_addr_ranges(); + let flush_input_size = module.get_flush_input_size(); + MemModuleData { inputs: Vec::new(), addr_ranges: ranges, flush_input_size } + } + + /// Static method to decide it the memory operation needs to be processed by + /// memAlign, because it isn't a 8-byte and 8-byte aligned memory access. + fn is_aligned(mem_op: &ZiskRequiredMemory) -> bool { + let aligned_mem_address = mem_op.address & MEM_ADDR_MASK; + aligned_mem_address == mem_op.address && mem_op.width == MEM_BYTES + } + /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible situations: + /// 1) read, only on single mem_op is pushed + /// 2) read+write, two mem_op are pushed, one read and one write. + /// + /// This process is used for each aligned memory address, means that the "second part" of non aligned memory + /// operation is processed on addr + MEM_BYTES. 
+ fn push_mem_align_op( + &self, + mem_addr: u64, + mem_value: u64, + mem_op: &ZiskRequiredMemory, + mem_align_op: &MemAlignResponse, + input: &mut Vec, + ) -> u64 { + // Prepare aligned memory access + input.push(ZiskRequiredMemory { + step: mem_align_op.step, + is_write: false, + address: mem_addr, + width: MEM_BYTES, + value: mem_value, + }); + if mem_op.is_write { + input.push(ZiskRequiredMemory { + step: mem_align_op.step + 1, + is_write: true, + address: mem_addr, + width: MEM_BYTES, + value: mem_align_op.mem_value, + }); + mem_align_op.mem_value + } else { + mem_value + } + } + fn mem_align_call( + mem_op: &ZiskRequiredMemory, + mem_values: [u64; 2], + phase: u8, + ) -> MemAlignResponse { + let mem_align_res = MemAlignResponse { more_address: false, step: 0, mem_value: 0 }; + mem_align_res + } + fn get_mem_module_id(&self, address: u64) -> (usize, u64) { + let mem_module_id = 0; + let next_addr_to_reevaluate = 0; + (mem_module_id, next_addr_to_reevaluate) + } pub fn prove( &self, - operations: &mut [Vec; 2], + mem_operations: &mut Vec, ) -> Result<(), Box> { - let mut aligned = std::mem::take(&mut operations[0]); - let unaligned = std::mem::take(&mut operations[1]); - let mut new_aligned = Vec::new(); + let mut open_mem_align_ops: VecDeque = VecDeque::new(); + let mut mem_module_inputs: [Vec; 2] = Default::default(); // Step 1. Sort the aligned memory accesses + // original vector is sorted by step, sort_by_key is stable, no reordering of elements with + // the same key. timer_start_debug!(MEM_SORT); - aligned.sort_by_key(|mem| mem.address); + mem_operations.sort_by_key(|mem| (mem.address & 0xFFFF_FFFF_FFFF_FFF8)); timer_stop_and_log_debug!(MEM_SORT); - // Step 2. 
For each unaligned memory access - unaligned.iter().for_each(|unaligned_access| { - // Step 2.1 Ask to the Mem Align SM for the aligned memory accesses generated by the non-aligned one - let mem_op = MemAlignSM::::get_mem_op(unaligned_access); - - // Step 2.2 Ask to the Mem SM for the aligned memory accesses - // TODO! Remove mem_op.clone() - let aligned_accesses = self.get_aligned_accesses( - &unaligned_access, - mem_op.clone(), - &aligned, - &new_aligned, - ); - - // Step 2.3 Carried with the aligned memory accesses, prove the non-aligned ones - self.mem_align_sm.prove(unaligned_access, &aligned_accesses); - - // Step 2.4 Store the new aligned memory access(es) - new_aligned.extend(aligned_accesses); - new_aligned.sort_by_key(|mem| mem.address); + // Initialize the last values of address and value on the sorted memory operations + let mut last_addr = 0xFFFF_FFFF_FFFF_FFFFu64; + let mut last_value = 0u64; + + // Add a final fake mem_op to force flush of open_mem_align_ops + mem_operations.push(ZiskRequiredMemory { + step: 0, + is_write: false, + address: MEM_ADDR_MASK, + width: 8, + value: 0, }); - // Step 3. Concatenate the new aligned memory accesses with the original aligned memory - // accesses - aligned.extend(new_aligned); - - timer_start_debug!(MEM_SORT_2); - aligned.sort_by_key(|mem| (mem.address, mem.step)); - timer_stop_and_log_debug!(MEM_SORT_2); - - let mut idx = 0; - while aligned[idx].address < RAM_ADDR && idx < aligned.len() { - idx += 1; - } - - let (_input_aligned, aligned) = aligned.split_at_mut(idx); - - // Step 4. 
Prove the aligned memory accesses using mem state machine - self.mem_sm.prove(aligned); - - Ok(()) - } - - #[inline(always)] - fn get_aligned_accesses( - &self, - unaligned_access: &ZiskRequiredMemory, - mem_op: MemOp, - aligned_accesses: &[ZiskRequiredMemory], - new_aligned_accesses: &[ZiskRequiredMemory], - ) -> Vec { - // Align down to a 8 byte addres - let addr = unaligned_access.address & !7; - match mem_op { - MemOp::OneRead => { - // Look for last write to the same address - let last_write_addr = Self::get_last_write( - addr, - unaligned_access.step, - aligned_accesses, - Some(new_aligned_accesses), - ); - let mut last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { - step: unaligned_access.step, - is_write: false, - address: addr, - width: 8, - value: 0, - }); - - last_write_addr.step = unaligned_access.step; + let (mem_module_id, next_addr_to_reevaluate) = if mem_operations.is_empty() { + (0, 0) + } else { + self.get_mem_module_id(mem_operations[0].address) + }; - vec![last_write_addr] - } - MemOp::OneWrite => { - // Look for last write to the same address - let last_write_addr = Self::get_last_write( - addr, - unaligned_access.step, - aligned_accesses, - Some(new_aligned_accesses), + for mem_op in mem_operations.iter_mut() { + let mut aligned_mem_address = mem_op.address & MEM_ADDR_MASK; + + // Check if there are open mem align operations to be processed in this moment. Two possible + // conditions to process open mem align operations: + // 1) the address of open operation is less than the aligned address. + // 2) the address of open operation is equal to the aligned address, but the step of the open + // operation is less than the step of the current operation. 
+ + while open_mem_align_ops.len() > 0 + || open_mem_align_ops[0].address < aligned_mem_address + || (open_mem_align_ops[0].address == aligned_mem_address + && open_mem_align_ops[0].mem_op.step < mem_op.step) + { + let open_op = open_mem_align_ops.pop_front().unwrap(); + let mem_value = if open_op.address == last_addr { last_value } else { 0 }; + + // call to mem_align to get information of the aligned memory access needed + // to prove the unaligned open operation. + let mem_align_op = Self::mem_align_call(&open_op.mem_op, [mem_value, 0], 1); + + // remove element from top of queue, because we are on last phase, phase 1. + open_mem_align_ops.pop_front(); + + // push the aligned memory operations for current address (read or read+write) and + // update last_address and last_value. + last_value = self.push_mem_align_op( + open_op.address, + mem_value, + &mem_op, + &mem_align_op, + &mut mem_module_inputs[mem_module_id], ); - - // Modify the value of the write to the same address - let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { - step: unaligned_access.step, - is_write: true, - address: addr, - width: 8, - value: 0, - }); - - let mut last_write_addr_r = last_write_addr.clone(); - last_write_addr_r.step = unaligned_access.step; - last_write_addr_r.is_write = false; - - let mut last_write_addr_w = last_write_addr; - last_write_addr_w.step = unaligned_access.step; - Self::write_value(&unaligned_access, &mut last_write_addr_w); - - vec![last_write_addr_r, last_write_addr_w] + last_addr = open_op.address; + // TODO: check if flush is needed } - MemOp::TwoReads => { - // Look for last write to the same address and same address + 8 - let last_write_addr = Self::get_last_write( - addr, - unaligned_access.step, - aligned_accesses, - Some(new_aligned_accesses), - ); - let last_write_addr_p = Self::get_last_write( - addr + 8, - unaligned_access.step, - aligned_accesses, - Some(new_aligned_accesses), - ); - let mut last_write_addr = 
last_write_addr.unwrap_or(ZiskRequiredMemory { - step: unaligned_access.step, - is_write: false, - address: addr, - width: 8, - value: 0, - }); - - let mut last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { - step: unaligned_access.step, - is_write: false, - address: addr + 8, - width: 8, - value: 0, - }); - - last_write_addr.step = unaligned_access.step; - last_write_addr_p.step = unaligned_access.step; - - vec![last_write_addr, last_write_addr_p] - } - MemOp::TwoWrites => { - // Look for last write to the same address and same address + 8 - let last_write_addr = Self::get_last_write( - addr, - unaligned_access.step, - aligned_accesses, - Some(new_aligned_accesses), - ); - let last_write_addr_p = Self::get_last_write( - addr + 8, - unaligned_access.step, - aligned_accesses, - Some(new_aligned_accesses), - ); + aligned_mem_address = mem_op.address & MEM_ADDR_MASK; - // Modify the value of the write to the same address - let last_write_addr = last_write_addr.unwrap_or(ZiskRequiredMemory { - step: unaligned_access.step, - is_write: true, - address: addr, - width: 8, - value: 0, - }); - - let mut last_write_addr_r = last_write_addr.clone(); - last_write_addr_r.step = unaligned_access.step; - last_write_addr_r.is_write = false; - - let mut last_write_addr_w = last_write_addr; - last_write_addr_w.step = unaligned_access.step; - Self::write_value(&unaligned_access, &mut last_write_addr_w); - - let last_write_addr_p = last_write_addr_p.unwrap_or(ZiskRequiredMemory { - step: unaligned_access.step, - is_write: true, - address: addr + 8, - width: 8, - value: 0, - }); - - let mut last_write_addr_p_r = last_write_addr_p.clone(); - last_write_addr_p_r.step = unaligned_access.step; - last_write_addr_p_r.is_write = false; - - let mut last_write_addr_p_w = last_write_addr_p; - last_write_addr_p_w.step = unaligned_access.step; - Self::write_value(&unaligned_access, &mut last_write_addr_p_w); - - Self::write_values( - &unaligned_access, - &mut last_write_addr_w, - 
&mut last_write_addr_p_w, + // check if the aligned address is the last address to avoid processing the last fake mem_op + if aligned_mem_address == MEM_ADDR_MASK { + assert!( + open_mem_align_ops.len() == 0, + "open_mem_align_ops not empty, has {} elements", + open_mem_align_ops.len() ); - vec![last_write_addr_r, last_write_addr_w, last_write_addr_p_r, last_write_addr_p_w] + break; } - } - } - #[inline(always)] - fn get_last_write( - addr: u64, - step: u64, - aligned_accesses: &[ZiskRequiredMemory], - new_aligned_accesses: Option<&[ZiskRequiredMemory]>, - ) -> Option { - // Step 1: Find the start of the range for `addr` - let start_index = - match aligned_accesses.binary_search_by_key(&addr, |access| access.address) { - Ok(mut index) => { - // Backtrack to find the first occurrence of `addr` - while index > 0 && aligned_accesses[index - 1].address == addr { - index -= 1; - } - index + let mem_value = if aligned_mem_address == last_addr { last_value } else { 0 }; + + // all open mem align operations are processed, check if new mem operation is aligned + if !Self::is_aligned(&mem_op) { + // In this point found non-aligned memory access, phase-0 + let mem_align_op = Self::mem_align_call(mem_op, [mem_value, 0], 0); + if mem_align_op.more_address { + open_mem_align_ops.push_back(MemAlignOperation { + address: aligned_mem_address + MEM_BYTES, + mem_op: mem_op.clone(), + mem_value: [mem_value, 0], + }); } - Err(index) => index, // If no match, use the insertion point as before - }; - - // Step 2: Iterate from start_index forward, storing the last valid write - let mut last_write = None; - for access in &aligned_accesses[start_index..] 
{ - if access.address != addr { - break; // Stop if we move past the given address - } - if access.step >= step { - break; // Stop if step is not less than the given step - } - if access.is_write { - last_write = Some(access.clone()); // Update last write if conditions are met + self.push_mem_align_op( + aligned_mem_address, + mem_value, + &mem_op, + &mem_align_op, + &mut mem_module_inputs[mem_module_id], + ); + } else { + mem_module_inputs[mem_module_id].push(mem_op.clone()); } - } - - // Step 3: If `new_aligned_accesses` exists, check for a more recent write - if let None = new_aligned_accesses { - return last_write; - } - - let new_aligned_accesses = new_aligned_accesses.unwrap(); - let last_new_write = Self::get_last_write(addr, step, new_aligned_accesses, None); - - if let None = last_write { - return last_new_write; - } - - if let Some(last_new_write) = last_new_write { - if last_new_write.step > last_write.as_ref().unwrap().step { - return Some(last_new_write); + if (mem_module_inputs[mem_module_id].len() as u64) + >= self.modules_data[mem_module_id].flush_input_size + { + let module = &self.modules[mem_module_id]; + module.send_inputs(&mem_module_inputs[mem_module_id]); } } - last_write - } - - #[inline(always)] - fn write_value(unaligned: &ZiskRequiredMemory, aligned: &mut ZiskRequiredMemory) { - let offset = unaligned.address & 7; - let width_in_bits = unaligned.width * 8; - - let mask = !(((1u64 << width_in_bits) - 1) << (offset * 8)); - - aligned.value = (aligned.value & mask) - | ((unaligned.value & ((1u64 << width_in_bits) - 1)) << (offset * 8)); - } - #[inline(always)] - fn write_values( - unaligned: &ZiskRequiredMemory, - aligned: &mut ZiskRequiredMemory, - aligned_next: &mut ZiskRequiredMemory, - ) { - let offset = unaligned.address & 7; - let bytes_to_write = 8 - offset; - let right_bits = (unaligned.width - bytes_to_write) * 8; - - // Left write - let left_value = unaligned.value << right_bits; - let left_memory = - ZiskRequiredMemory { width: 
bytes_to_write, value: left_value, ..*unaligned }; - Self::write_value(&left_memory, aligned); - - // Right write - let right_value = unaligned.value >> (bytes_to_write * 8); - - let right_memory = ZiskRequiredMemory { - address: 0, - width: unaligned.width - bytes_to_write, - value: right_value, - ..*unaligned - }; - Self::write_value(&right_memory, aligned_next); + Ok(()) } } diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index 0d3aa1a4..50f8fae4 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -3,13 +3,14 @@ use std::sync::{ Arc, Mutex, }; +use crate::MemModule; use p3_field::PrimeField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use rayon::prelude::*; use sm_common::create_prover_buffer; -use zisk_core::ZiskRequiredMemory; +use zisk_core::{Mem, ZiskRequiredMemory}; use zisk_pil::{MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct MemSM { @@ -255,4 +256,16 @@ impl MemSM { } } +impl MemModule for MemSM { + fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]) {} + fn get_addr_ranges(&self) -> Vec<(u64, u64)> { + vec![] + } + fn get_flush_input_size(&self) -> u64 { + 0 + } + fn unregister_predecessor(&self) {} + fn register_predecessor(&self) {} +} + impl WitnessComponent for MemSM {} From ad7381eccecce9bad6c01dbae4c63d07055cb21e Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Fri, 8 Nov 2024 17:13:16 +0000 Subject: [PATCH 24/44] rebase --- Cargo.lock | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 67f9ca1d..cdb6f094 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2232,7 +2232,9 @@ name = "sm-mem" version = "0.1.0" dependencies = [ "log", + "num-bigint", "p3-field", + "pil-std-lib", "proofman", "proofman-common", "proofman-macros", From 257fe755d6e57e37de6781ae0b008e1f7659e12b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Fri, 8 Nov 2024 20:20:26 +0000 Subject: [PATCH 25/44] Range check done --- 
state-machines/mem/src/mem_align_sm.rs | 30 +++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 34cc6eae..634ac19a 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -689,18 +689,39 @@ impl MemAlignSM { ) .unwrap(); + let mut reg_range_check: HashMap = HashMap::new(); + // Add the input rows to the trace for (i, &row) in rows.iter().enumerate() { + // Store the entire row trace_buffer[i] = row; + + // Store the value of all reg columns so that they can be range checked + for j in 0..CHUNK_NUM { + *reg_range_check.entry(row.reg[j]).or_insert(0) += 1; + } } // Pad the remaining rows with trivially satisfying rows let padding_row = MemAlignRow::::default(); + let padding_size = air_mem_align_rows - rows_len; + + // Store the padding rows for i in rows_len..air_mem_align_rows { trace_buffer[i] = padding_row; } - // TODO: Treat the range check here of both standard and padding rows!! 
+ // Store the value of all reg columns so that they can be range checked + for j in 0..CHUNK_NUM { + *reg_range_check.entry(padding_row.reg[j]).or_insert(0) += padding_size as u64; + } + + // Perform the range checks + let std = self.std.clone(); + let range_id = std.get_range(BigInt::from(0), BigInt::from(CHUNK_BITS_MASK), None); + for (&value, &multiplicity) in reg_range_check.iter() { + std.range_check(value, F::from_canonical_u64(multiplicity), range_id); + } // TODO: Treate the ROM multiplicity @@ -715,13 +736,6 @@ impl MemAlignSM { // rom_multiplicity[row as usize] += multiplicity; // } - // TODO: Perform the range checks - // let std = self.std.clone(); - // let range_id = std.get_range(BigInt::from(0), BigInt::from((1 << CHUNK_BITS) - 1), None); - // for (&value, &multiplicity) in reg_range_check.iter() { - // std.range_check(value, F::from_canonical_u64(multiplicity), range_id); - // } - info!( "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", Self::MY_NAME, From 125b205853735fc9a12c2270b548d2efcf6babaa Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Sat, 9 Nov 2024 09:27:35 +0000 Subject: [PATCH 26/44] minor bugs,logs on memory proxy --- core/src/zisk_required_operation.rs | 17 ++ pil/src/pil_helpers/traces.rs | 2 + state-machines/mem/src/mem_align_rom_sm.rs | 2 +- state-machines/mem/src/mem_align_sm.rs | 13 +- state-machines/mem/src/mem_proxy.rs | 176 +++++++++++++++++---- state-machines/mem/src/mem_sm.rs | 20 ++- 6 files changed, 188 insertions(+), 42 deletions(-) diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index 59a7aee6..7702c71a 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::fmt; #[derive(Clone)] pub struct ZiskRequiredOperation { @@ -17,6 +18,22 @@ pub struct ZiskRequiredMemory { pub value: u64, } +impl fmt::Debug for ZiskRequiredMemory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> 
fmt::Result { + let label = if self.is_write { "WR" } else { "RD" }; + write!( + f, + "{0} addr:{1:#08X}({1}) with:{2} value:{3:#016X}({3}) step:{4} offset:{5}", + label, + self.address, + self.width, + self.value, + self.step, + self.address & 0x07 + ) + } +} + #[derive(Clone, Default)] pub struct ZiskRequired { pub arith: Vec, diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 1545cdca..da9b392c 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -46,3 +46,5 @@ trace!(SpecifiedRangesRow, SpecifiedRangesTrace { trace!(U8AirRow, U8AirTrace { mul: F, }); + + diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 5a0e63b1..606d0dad 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -12,7 +12,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use sm_common::create_prover_buffer; -use zisk_pil::{MemAlignRomRow, MemAlignRomTrace, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{MemAlignRomTrace, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; #[derive(Debug, Clone, Copy)] pub enum MemOp { diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 312e2d02..2597bb15 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -33,12 +33,6 @@ pub struct MemAlignResponse { pub step: u64, pub value: Option, } - -pub struct MemAlignResponse { - pub more_address: bool, - pub step: u64, - pub mem_value: u64, -} pub struct MemAlignSM { // Witness computation manager wcm: Arc>, @@ -693,11 +687,18 @@ impl MemAlignSM { // Add the input rows to the trace for (i, &row) in rows.iter().enumerate() { + assert!( + row.sel_up_to_down.is_zero() || row.sel_down_to_up.is_zero(), + "sel_up_to_down:{:?} sel_down_to_up:{:?}", + row.sel_up_to_down, + row.sel_down_to_up + ); trace_buffer[i] = row; } // Pad the 
remaining rows with trivially satisfying rows let padding_row = MemAlignRow::::default(); + assert!(padding_row.sel_up_to_down.is_zero() || padding_row.sel_down_to_up.is_zero()); for i in rows_len..air_mem_align_rows { trace_buffer[i] = padding_row; } diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 33c4ce9a..79c671ad 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,16 +1,15 @@ use std::collections::VecDeque; -use std::default; +use std::fmt; use std::sync::{ atomic::{AtomicU32, Ordering}, Arc, }; -use crate::{MemAlignResponse, MemAlignRomSM, MemAlignSM, MemOp, MemSM}; +use crate::{MemAlignResponse, MemAlignRomSM, MemAlignSM, MemSM}; use p3_field::PrimeField; use pil_std_lib::Std; -use proofman_common::StepsParams; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; -use zisk_core::{elf2rom, ZiskRequiredMemory, ZiskRequiredOperation, RAM_ADDR}; +use zisk_core::ZiskRequiredMemory; use proofman::{WitnessComponent, WitnessManager}; @@ -34,6 +33,18 @@ struct MemModuleData { pub flush_input_size: u64, } +impl fmt::Debug for MemAlignResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "more:{} step:{} value:{:2}({:3})", + self.more_address, + self.step, + format_hex(self.value.unwrap_or(0)), + self.value.unwrap_or(0) + ) + } +} pub struct MemProxy { // Count of registered predecessors registered_predecessors: AtomicU32, @@ -95,12 +106,15 @@ impl MemProxy { pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + for module in self.modules.iter() { + module.unregister_predecessor(); + } // self.mem_sm.unregister_predecessor(); self.mem_align_sm.unregister_predecessor(); } } - pub fn init_module(module: &Arc>) -> MemModuleData { + fn init_module(module: &Arc>) -> MemModuleData { module.register_predecessor(); let ranges = module.get_addr_ranges(); let flush_input_size = 
module.get_flush_input_size(); @@ -128,37 +142,42 @@ impl MemProxy { input: &mut Vec, ) -> u64 { // Prepare aligned memory access - input.push(ZiskRequiredMemory { + let read = ZiskRequiredMemory { step: mem_align_op.step, is_write: false, address: mem_addr, width: MEM_BYTES, value: mem_value, - }); + }; + println!(" ##SEND2## mem_op: {0:?}", read); + input.push(read); + if mem_op.is_write { - input.push(ZiskRequiredMemory { + let mem_value = mem_align_op.value.expect("value returned by mem_align"); + let write = ZiskRequiredMemory { step: mem_align_op.step + 1, is_write: true, address: mem_addr, width: MEM_BYTES, - value: mem_align_op.mem_value, - }); - mem_align_op.mem_value + value: mem_value, + }; + println!(" ##SEND2## mem_op: {0:?}", write); + input.push(write); + mem_value } else { mem_value } } - fn mem_align_call( - mem_op: &ZiskRequiredMemory, - mem_values: [u64; 2], - phase: u8, - ) -> MemAlignResponse { - let mem_align_res = MemAlignResponse { more_address: false, step: 0, mem_value: 0 }; - mem_align_res + fn create_modules_inputs(&self) -> Vec> { + let mut mem_module_inputs: Vec> = Default::default(); + for module in self.modules.iter() { + mem_module_inputs.push(Vec::new()); + } + mem_module_inputs } fn get_mem_module_id(&self, address: u64) -> (usize, u64) { let mem_module_id = 0; - let next_addr_to_reevaluate = 0; + let next_addr_to_reevaluate = 0xFFFF_FFFF_FFFF; (mem_module_id, next_addr_to_reevaluate) } pub fn prove( @@ -166,7 +185,7 @@ impl MemProxy { mem_operations: &mut Vec, ) -> Result<(), Box> { let mut open_mem_align_ops: VecDeque = VecDeque::new(); - let mut mem_module_inputs: [Vec; 2] = Default::default(); + let mut mem_module_inputs = self.create_modules_inputs(); // Step 1. 
Sort the aligned memory accesses // original vector is sorted by step, sort_by_key is stable, no reordering of elements with @@ -188,13 +207,19 @@ impl MemProxy { value: 0, }); - let (mem_module_id, next_addr_to_reevaluate) = if mem_operations.is_empty() { + // Initialize the module id and next module address to reevaluate the module id, it's done + // to avoid check on each loop if memory address is inside one range or other + let (mut mem_module_id, mut next_module_addr) = if mem_operations.is_empty() { (0, 0) } else { self.get_mem_module_id(mem_operations[0].address) }; for mem_op in mem_operations.iter_mut() { + println!( + "##LOOP## mem_op: {0:?} 0x{1:#08X}({1}) 0x{2:#016X}({2})", + mem_op, last_addr, last_value + ); let mut aligned_mem_address = mem_op.address & MEM_ADDR_MASK; // Check if there are open mem align operations to be processed in this moment. Two possible @@ -204,20 +229,24 @@ impl MemProxy { // operation is less than the step of the current operation. while open_mem_align_ops.len() > 0 - || open_mem_align_ops[0].address < aligned_mem_address - || (open_mem_align_ops[0].address == aligned_mem_address - && open_mem_align_ops[0].mem_op.step < mem_op.step) + && (open_mem_align_ops[0].address < aligned_mem_address + || (open_mem_align_ops[0].address == aligned_mem_address + && open_mem_align_ops[0].mem_op.step < mem_op.step)) { let open_op = open_mem_align_ops.pop_front().unwrap(); let mem_value = if open_op.address == last_addr { last_value } else { 0 }; // call to mem_align to get information of the aligned memory access needed // to prove the unaligned open operation. - let mem_align_op = Self::mem_align_call(&open_op.mem_op, [mem_value, 0], 1); + let mem_align_op = mem_align_call(&open_op.mem_op, [mem_value, 0], 1); // remove element from top of queue, because we are on last phase, phase 1. 
open_mem_align_ops.pop_front(); + // check if need to reevaluate the module id + if open_op.address >= next_module_addr { + (mem_module_id, next_module_addr) = self.get_mem_module_id(open_op.address); + } // push the aligned memory operations for current address (read or read+write) and // update last_address and last_value. last_value = self.push_mem_align_op( @@ -228,7 +257,13 @@ impl MemProxy { &mut mem_module_inputs[mem_module_id], ); last_addr = open_op.address; - // TODO: check if flush is needed + + // check if need to flush the inputs of the module + if (mem_module_inputs[mem_module_id].len() as u64) + >= self.modules_data[mem_module_id].flush_input_size + { + self.modules[mem_module_id].send_inputs(&mut mem_module_inputs[mem_module_id]); + } } aligned_mem_address = mem_op.address & MEM_ADDR_MASK; @@ -243,12 +278,17 @@ impl MemProxy { break; } + // check if need to reevaluate the module id + if aligned_mem_address >= next_module_addr { + (mem_module_id, next_module_addr) = self.get_mem_module_id(aligned_mem_address); + } + let mem_value = if aligned_mem_address == last_addr { last_value } else { 0 }; // all open mem align operations are processed, check if new mem operation is aligned if !Self::is_aligned(&mem_op) { // In this point found non-aligned memory access, phase-0 - let mem_align_op = Self::mem_align_call(mem_op, [mem_value, 0], 0); + let mem_align_op = mem_align_call(mem_op, [mem_value, 0], 0); if mem_align_op.more_address { open_mem_align_ops.push_back(MemAlignOperation { address: aligned_mem_address + MEM_BYTES, @@ -256,21 +296,26 @@ impl MemProxy { mem_value: [mem_value, 0], }); } - self.push_mem_align_op( + last_value = self.push_mem_align_op( aligned_mem_address, mem_value, &mem_op, &mem_align_op, &mut mem_module_inputs[mem_module_id], ); + last_addr = aligned_mem_address } else { + println!(" ##SEND1## mem_op: {0:?}", mem_op); mem_module_inputs[mem_module_id].push(mem_op.clone()); + last_value = mem_op.value; + last_addr = 
aligned_mem_address } + + // check if need to flush the inputs of the module if (mem_module_inputs[mem_module_id].len() as u64) >= self.modules_data[mem_module_id].flush_input_size { - let module = &self.modules[mem_module_id]; - module.send_inputs(&mem_module_inputs[mem_module_id]); + self.modules[mem_module_id].send_inputs(&mut mem_module_inputs[mem_module_id]); } } @@ -279,3 +324,74 @@ impl MemProxy { } impl WitnessComponent for MemProxy {} + +fn format_hex(value: u64) -> String { + let hex_str = format!("{:016x}", value); // Format hexadecimal amb 16 dígits i padding de 0s + hex_str + .as_bytes() // Converteix a bytes per manipular fàcilment + .chunks(4) // Separa en grups de 4 caràcters (2 bytes) + .map(|chunk| std::str::from_utf8(chunk).unwrap()) // Converteix cada chunk a &str + .collect::>() // Recull els chunks com a un vector + .join("_") // Uneix amb "_" +} + +fn mem_align_call( + mem_op: &ZiskRequiredMemory, + mem_values: [u64; 2], + phase: u8, +) -> MemAlignResponse { + // DEBUG: only for testing + let offset = (mem_op.address & 0x7) * 8; + let width = (mem_op.width as u64) * 8; + let double_address = (offset + width) > 64; + let mem_value = mem_values[phase as usize]; + let mask = 0xFFFF_FFFF_FFFF_FFFFu64 >> (64 - width); + /*println!("width: {} offset:{}", width, offset); + println!("mem_value {}", format_hex(mem_value)); + println!("mask {}", format_hex(mask));*/ + if mem_op.is_write { + if phase == 0 { + /*println!("mask1 {}", format_hex(mask << offset)); + println!("mask2 {}", format_hex(0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))); + println!( + "mask3 {}", + format_hex((mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset)))) + ); + println!("mask4 {}", format_hex((mem_op.value & mask) << offset));*/ + MemAlignResponse { + more_address: double_address, + step: mem_op.step + 1, + value: Some( + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) + | ((mem_op.value & mask) << offset), + ), + } + } else { + /* println!("{} bits = {} 
bytes", (offset + width - 64), (offset + width - 64) >> 3); + println!("ph1_1 {}", format_hex(mask << offset)); + println!( + "ph1_2 {}", + format_hex(0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64)) + ); + println!( + "ph1_3 {}", + format_hex(mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64))) + ); + println!("ph1_4 {}", format_hex((mem_op.value & mask) >> (128 - offset - width)));*/ + MemAlignResponse { + more_address: false, + step: mem_op.step + 1, + value: Some( + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64))) + | ((mem_op.value & mask) >> (128 - offset - width)), + ), + } + } + } else { + MemAlignResponse { + more_address: double_address && phase == 0, + step: mem_op.step + 1, + value: None, + } + } +} diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index 50f8fae4..fe54e310 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -10,13 +10,14 @@ use proofman_common::AirInstance; use rayon::prelude::*; use sm_common::create_prover_buffer; -use zisk_core::{Mem, ZiskRequiredMemory}; +use zisk_core::ZiskRequiredMemory; use zisk_pil::{MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct MemSM { // Witness computation manager wcm: Arc>, + num_rows: usize, // Count of registered predecessors registered_predecessors: AtomicU32, } @@ -24,7 +25,13 @@ pub struct MemSM { #[allow(unused, unused_variables)] impl MemSM { pub fn new(wcm: Arc>) -> Arc { - let mem_sm = Self { wcm: wcm.clone(), registered_predecessors: AtomicU32::new(0) }; + let pctx = wcm.get_pctx(); + let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let mem_sm = Self { + wcm: wcm.clone(), + num_rows: air.num_rows(), + registered_predecessors: AtomicU32::new(0), + }; let mem_sm = Arc::new(mem_sm); wcm.register_component(mem_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MEM_AIR_IDS)); @@ -40,7 +47,7 @@ impl MemSM { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 {} } - pub fn 
prove(&self, mem_accesses: &mut [ZiskRequiredMemory]) { + pub fn prove(&self, mem_accesses: &[ZiskRequiredMemory]) { // Sort the (full) aligned memory accesses let pctx = self.wcm.get_pctx(); @@ -202,6 +209,7 @@ impl MemSM { let first_addr_access_is_read = addr_changes && !mem_op.is_write; trace[i].first_addr_access_is_read = if first_addr_access_is_read { F::one() } else { F::zero() }; + assert!(trace[i].sel.is_zero() || trace[i].sel.is_one()); } // STEP3. Add dummy rows to the output vector to fill the remaining rows @@ -257,12 +265,14 @@ impl MemSM { } impl MemModule for MemSM { - fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]) {} + fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]) { + self.prove(&mem_op); + } fn get_addr_ranges(&self) -> Vec<(u64, u64)> { vec![] } fn get_flush_input_size(&self) -> u64 { - 0 + self.num_rows as u64 } fn unregister_predecessor(&self) {} fn register_predecessor(&self) {} From 00fd718a6be24650c735fbfa91f668aca4b2d16b Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Sat, 9 Nov 2024 10:53:09 +0000 Subject: [PATCH 27/44] filter some address to test --- state-machines/mem/src/mem_proxy.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 79c671ad..1ea0bd01 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -222,6 +222,11 @@ impl MemProxy { ); let mut aligned_mem_address = mem_op.address & MEM_ADDR_MASK; + // ONLY TO TEST + if aligned_mem_address < 0xA0000000 { + continue; + } + // Check if there are open mem align operations to be processed in this moment. Two possible // conditions to process open mem align operations: // 1) the address of open operation is less than the aligned address. 
From f6fbc3db497d1422bf5a248c6acb8ea7f6b6e781 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Sat, 9 Nov 2024 16:44:41 +0000 Subject: [PATCH 28/44] wip --- state-machines/mem/src/mem_align_sm.rs | 194 ++++++++++++++++--------- state-machines/mem/src/mem_proxy.rs | 37 +---- 2 files changed, 132 insertions(+), 99 deletions(-) diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 76f637cb..07d034a3 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -24,6 +24,7 @@ const CHUNK_NUM: usize = 8; const CHUNK_NUM_U64: u64 = CHUNK_NUM as u64; const CHUNK_BITS: usize = 8; const CHUNK_BITS_U64: u64 = CHUNK_BITS as u64; +const OFFSET_MASK: u64 = CHUNK_NUM_U64 - 1; const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; const ALLOWED_WIDTHS: [u64; 4] = [1, 2, 4, 8]; @@ -45,6 +46,7 @@ pub struct MemAlignSM { // Computed rows rows: Mutex>>, + num_computed_rows: Mutex, // TODO: DEBUG!!! // Secondary State machines mem_align_rom_sm: Arc>, @@ -63,6 +65,7 @@ impl MemAlignSM { std: std.clone(), registered_predecessors: AtomicU32::new(0), rows: Mutex::new(Vec::new()), + num_computed_rows: Mutex::new(0), mem_align_rom_sm, }; let mem_align_sm = Arc::new(mem_align_sm); @@ -111,11 +114,11 @@ impl MemAlignSM { pub fn get_mem_op( &self, input: &ZiskRequiredMemory, - mem_values: Vec, + mem_values: [u64; 2], phase: usize, ) -> MemAlignResponse { // Sanity check - assert!(mem_values.len() == phase + 1); // TODO: Debug mode + // assert!(mem_values.len() == phase + 1); // TODO: Debug mode let addr = input.address; let width = input.width; @@ -126,18 +129,24 @@ impl MemAlignSM { }; // Compute the offset - let offset = addr & CHUNK_BITS_MASK; + let offset = addr & OFFSET_MASK; let offset = if offset <= usize::MAX as u64 { offset as usize } else { panic!("Offset={} is too large", offset); }; - // main: [mem_op, addr, 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset, bytes, ...value] - 
// mem: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * MEM_BYTES, step, MEM_BYTES, ...value] - // mem_align: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val] + let num_rows = self.num_computed_rows.lock().unwrap(); // TODO: DEBUG!!! + match (input.is_write, offset + width > CHUNK_NUM) { (false, false) => { + println!("ONE READ"); + println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 1); + drop(num_rows); + println!("INPUT: {:?}", input); + println!("MEM_VALUES: {:?}", mem_values); + println!("PHASE: {:?}\n", phase); + // RV assert!(phase == 0); // TODO: Debug mode @@ -188,20 +197,22 @@ impl MemAlignSM { value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - read_row.sel[i] = F::from_bool(true); + println!("READ_ROW[{}]: {:?}", i, read_row.reg[i]); + if i >= offset && i <= offset + width { + read_row.sel[i] = F::from_bool(true); + } value_row.reg[i] = { F::from_canonical_u64( value & (CHUNK_BITS_MASK - << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), + << (((offset as u64 + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), ) }; - value_row.sel[i] = F::from_bool(i == offset as usize); - - // Store the range check - // *range_check.entry(read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + println!("VALUE_ROW[{}]: {:?}", i, value_row.reg[i]); + if i == offset { + value_row.sel[i] = F::from_bool(true); + } } // Prove the generated rows @@ -210,6 +221,13 @@ impl MemAlignSM { MemAlignResponse { more_address: false, step, value: None } } (true, false) => { + println!("ONE WRITE"); + println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 2); + drop(num_rows); + println!("INPUT: {:?}", input); + println!("MEM_VALUES: {:?}", mem_values); + println!("PHASE: {:?}\n", phase); + // RWV assert!(phase == 0); // TODO: Debug mode @@ -218,7 +236,7 @@ impl MemAlignSM { let value = input.value; // Compute the shift - let shift = ((offset + width) % 
CHUNK_NUM) as u64; + let shift = ((offset + width - 1) % CHUNK_NUM) as u64; // Get the aligned address let addr_read = addr >> CHUNK_BITS; @@ -231,16 +249,23 @@ impl MemAlignSM { // Compute the write value let value_write = { - let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; + // with:1 offset:4 + let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; // 0xFF + println!("WIDTH_BYTES: {:#X}", width_bytes); + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); // 0x00_00_00_FF_00_00_00_00 + println!("MASK: {:#X}", mask); // Get the first width bytes of the unaligned value - let value_to_write = value & width_bytes; + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + println!("VALUE_TO_WRITE: {:#X}", value_to_write); // Write zeroes to value_read from offset to offset + width - let mask: u64 = width_bytes << (offset * CHUNK_BITS); + // and add the value to write to the value read - // Add the value to write to the value read - (value_read & !mask) | value_to_write + let result = (value_read & !mask) | value_to_write; + println!("RESULT: {:#X}", result); + result }; let mut read_row = MemAlignRow:: { @@ -287,28 +312,36 @@ impl MemAlignSM { value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - read_row.sel[i] = F::from_bool(i >= width); + println!("READ_ROW[{}]: {:?}", i, read_row.reg[i]); + if i < offset || i > offset + width { + read_row.sel[i] = F::from_bool(true); + } write_row.reg[i] = { F::from_canonical_u64( value_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - write_row.sel[i] = F::from_bool(i < width); + println!("WRITE_ROW[{}]: {:?}", i, write_row.reg[i]); + if i >= offset && i <= offset + width { + write_row.sel[i] = F::from_bool(true); + } value_row.reg[i] = { - F::from_canonical_u64( - value - & (CHUNK_BITS_MASK - << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), - ) + if i >= offset && i <= offset + width { + write_row.reg[i] + } else { + F::from_canonical_u64( + value + & (CHUNK_BITS_MASK + << 
(((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), + ) + } }; - value_row.sel[i] = F::from_bool(i == offset as usize); - - // Store the range check - // *range_check.entry(read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(write_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; + println!("VALUE_ROW[{}]: {:?}", i, value_row.reg[i]); + if i == offset { + value_row.sel[i] = F::from_bool(true); + } } // Prove the generated rows @@ -326,6 +359,13 @@ impl MemAlignSM { // Otherwise, do the RVR 1 => { + println!("TWO READS"); + println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 2); + drop(num_rows); + println!("INPUT: {:?}", input); + println!("MEM_VALUES: {:?}", mem_values); + println!("PHASE: {:?}\n", phase); + assert!(mem_values.len() == 2); // TODO: Debug mode // Unaligned memory op information thrown into the bus @@ -391,7 +431,9 @@ impl MemAlignSM { value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - first_read_row.sel[i] = F::from_bool(true); + if i >= offset && i <= offset + width { + first_read_row.sel[i] = F::from_bool(true); + } value_row.reg[i] = { F::from_canonical_u64( @@ -400,19 +442,18 @@ impl MemAlignSM { << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), ) }; - value_row.sel[i] = F::from_bool(i == offset); + if i == offset { + value_row.sel[i] = F::from_bool(true); + } second_read_row.reg[i] = { F::from_canonical_u64( value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - second_read_row.sel[i] = F::from_bool(true); - - // Store the range check - // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; + if pos < shift { + second_read_row.sel[i] = F::from_bool(true); + } } // Prove the generated rows @@ -430,7 +471,7 @@ impl MemAlignSM { match phase { // If phase == 0, compute the resulting write value and ask for more 0 => { - 
assert!(mem_values.len() == 1); // TODO: Debug mode + // assert!(mem_values.len() == 1); // TODO: Debug mode // Unaligned memory op information thrown into the bus let value = input.value; @@ -441,15 +482,19 @@ impl MemAlignSM { // Compute the write value let value_first_write = { - let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; + // Normalize the width + let width_norm = CHUNK_NUM - offset; + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + // Get the first width bytes of the unaligned value - let value_to_write = value & width_bytes; - + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + // Write zeroes to value_read from offset to offset + width - let mask = width_bytes << (offset * CHUNK_BITS); - - // Add the value to write to the value read + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write }; @@ -461,6 +506,13 @@ impl MemAlignSM { } // Otherwise, do the RWVRW 1 => { + println!("TWO WRITES"); + println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 4); + drop(num_rows); + println!("INPUT: {:?}", input); + println!("MEM_VALUES: {:?}", mem_values); + println!("PHASE: {:?}\n", phase); + assert!(mem_values.len() == 2); // TODO: Debug mode // Unaligned memory op information thrown into the bus @@ -479,15 +531,19 @@ impl MemAlignSM { // Recompute the first write value let value_first_write = { - let width_bytes = (1 << (width * CHUNK_BITS)) - 1; + // Normalize the width + let width_norm = CHUNK_NUM - offset; + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + // Get the first width bytes of the unaligned value - let value_to_write = value & width_bytes; - + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + // Write zeroes to value_read from offset to offset + width - let mask = width_bytes << (offset * CHUNK_BITS); - - // Add the value 
to write to the value read + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write }; @@ -495,16 +551,20 @@ impl MemAlignSM { let value_second_read = mem_values[1]; // Compute the second write value - let value_second_write = { - let width_bytes = (1 << (width * CHUNK_BITS)) - 1; - + let value_second_write = { // TODO: Fix + // Normalize the width + let width_norm = CHUNK_NUM - offset; + + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + + let mask: u64 = width_bytes << (offset * CHUNK_BITS); + // Get the first width bytes of the unaligned value - let value_to_write = value & width_bytes; - + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + // Write zeroes to value_read from offset to offset + width - let mask = width_bytes << (offset * CHUNK_BITS); - - // Add the value to write to the value read + // and add the value to write to the value read + (value_second_read & !mask) | value_to_write }; @@ -615,13 +675,6 @@ impl MemAlignSM { second_read_row.sel[i] = F::from_bool(pos >= shift); } - // Store the range check - // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; - // Prove the generated rows self.prove(&[ first_read_row, @@ -647,6 +700,9 @@ impl MemAlignSM { if let Ok(mut rows) = self.rows.lock() { rows.extend_from_slice(computed_rows); + let mut num_rows = self.num_computed_rows.lock().unwrap(); // TODO: DEBUG!!! 
+ *num_rows += computed_rows.len(); + let pctx = self.wcm.get_pctx(); let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); @@ -689,7 +745,7 @@ impl MemAlignSM { // Add the input rows to the trace for (i, &row) in rows.iter().enumerate() { - // Store the entire row + // Store the entire row trace_buffer[i] = row; // Store the value of all reg columns so that they can be range checked diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 79c671ad..612c2e29 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -149,7 +149,6 @@ impl MemProxy { width: MEM_BYTES, value: mem_value, }; - println!(" ##SEND2## mem_op: {0:?}", read); input.push(read); if mem_op.is_write { @@ -161,7 +160,6 @@ impl MemProxy { width: MEM_BYTES, value: mem_value, }; - println!(" ##SEND2## mem_op: {0:?}", write); input.push(write); mem_value } else { @@ -216,12 +214,13 @@ impl MemProxy { }; for mem_op in mem_operations.iter_mut() { - println!( - "##LOOP## mem_op: {0:?} 0x{1:#08X}({1}) 0x{2:#016X}({2})", - mem_op, last_addr, last_value - ); let mut aligned_mem_address = mem_op.address & MEM_ADDR_MASK; + // ONLY TO TEST + if aligned_mem_address < 0xA0000000 { + continue; + } + // Check if there are open mem align operations to be processed in this moment. Two possible // conditions to process open mem align operations: // 1) the address of open operation is less than the aligned address. @@ -238,7 +237,7 @@ impl MemProxy { // call to mem_align to get information of the aligned memory access needed // to prove the unaligned open operation. - let mem_align_op = mem_align_call(&open_op.mem_op, [mem_value, 0], 1); + let mem_align_op = self.mem_align_sm.get_mem_op(&open_op.mem_op, [mem_value, 0], 1); // remove element from top of queue, because we are on last phase, phase 1. 
open_mem_align_ops.pop_front(); @@ -288,7 +287,7 @@ impl MemProxy { // all open mem align operations are processed, check if new mem operation is aligned if !Self::is_aligned(&mem_op) { // In this point found non-aligned memory access, phase-0 - let mem_align_op = mem_align_call(mem_op, [mem_value, 0], 0); + let mem_align_op = self.mem_align_sm.get_mem_op(mem_op, [mem_value, 0], 0); if mem_align_op.more_address { open_mem_align_ops.push_back(MemAlignOperation { address: aligned_mem_address + MEM_BYTES, @@ -305,7 +304,6 @@ impl MemProxy { ); last_addr = aligned_mem_address } else { - println!(" ##SEND1## mem_op: {0:?}", mem_op); mem_module_inputs[mem_module_id].push(mem_op.clone()); last_value = mem_op.value; last_addr = aligned_mem_address @@ -346,18 +344,8 @@ fn mem_align_call( let double_address = (offset + width) > 64; let mem_value = mem_values[phase as usize]; let mask = 0xFFFF_FFFF_FFFF_FFFFu64 >> (64 - width); - /*println!("width: {} offset:{}", width, offset); - println!("mem_value {}", format_hex(mem_value)); - println!("mask {}", format_hex(mask));*/ if mem_op.is_write { if phase == 0 { - /*println!("mask1 {}", format_hex(mask << offset)); - println!("mask2 {}", format_hex(0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))); - println!( - "mask3 {}", - format_hex((mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset)))) - ); - println!("mask4 {}", format_hex((mem_op.value & mask) << offset));*/ MemAlignResponse { more_address: double_address, step: mem_op.step + 1, @@ -367,17 +355,6 @@ fn mem_align_call( ), } } else { - /* println!("{} bits = {} bytes", (offset + width - 64), (offset + width - 64) >> 3); - println!("ph1_1 {}", format_hex(mask << offset)); - println!( - "ph1_2 {}", - format_hex(0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64)) - ); - println!( - "ph1_3 {}", - format_hex(mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64))) - ); - println!("ph1_4 {}", format_hex((mem_op.value & mask) >> (128 - offset - width)));*/ MemAlignResponse 
{ more_address: false, step: mem_op.step + 1, From 0c0d19a281c0e7cbef16ee73baf63eda6bc7e101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Sun, 10 Nov 2024 01:44:26 +0000 Subject: [PATCH 29/44] Cleaning up --- state-machines/mem/pil/mem_align_rom.pil | 52 ++- state-machines/mem/src/mem_align_rom_sm.rs | 61 +-- state-machines/mem/src/mem_align_sm.rs | 486 +++------------------ 3 files changed, 117 insertions(+), 482 deletions(-) diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil index 322bcd2d..016dd956 100644 --- a/state-machines/mem/pil/mem_align_rom.pil +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -44,13 +44,14 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 // Moreover, offset is set to DEFAULT_OFFSET and width to DEFAULT_WIDTH in aligned memory accesses. + // offset == width == 0 is set at the very first row for padding // size - col fixed OFFSET = [[[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 40 + col fixed OFFSET = [0, [[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 40 [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 100 [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 133 [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3]]...; // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 - col fixed WIDTH = [[[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV + col fixed WIDTH = [0, [[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV [[8,8,1,8,8,2,8,8,4]:5, [8,8,1,8,8,2]:2, [8,8,1]], // RWV [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]]]...; // RWVWR @@ -66,31 +67,31 @@ airtemplate MemAlignRom(const int 
N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = // WIDTH[i] = width; // } - // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | - // 0 | 0 | 1 | 1 | X1 | 0 | // (RV) - // 1 | 1 | -1 | 0 | X1 | 0 | - // 2 | 0 | 3 | 1 | X2 | 0 | // (RV) - // 3 | 3 | -3 | 0 | X2 | 0 | - // 4 | 0 | 5 | 1 | X3 | 0 | // (RV) - // 5 | 5 | -5 | 0 | X3 | 0 | - // 6 | 0 | 7 | 1 | ⋮ | ⋮ | // (RV) + // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | + // 0 | 0 | 0 | 0 | 0 | 0 | // for padding + // 1 | 0 | 1 | 1 | X1 | 0 | // (RV) + // 2 | 1 | -1 | 0 | X1 | 0 | + // 3 | 0 | 3 | 1 | X2 | 0 | // (RV) + // 4 | 3 | -3 | 0 | X2 | 0 | + // 5 | 0 | 5 | 1 | X3 | 0 | // (RV) + // 6 | 5 | -5 | 0 | X3 | 0 | + // 7 | 0 | 7 | 1 | ⋮ | ⋮ | // (RV) // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | - // 40 | 0 | 41 | 1 | X4 | 0 | // (RWV) - // 41 | 41 | 1 | 0 | X4 | 0 | - // 42 | 42 | -42 | 0 | X4 | 0 | - // 43 | 0 | 44 | 1 | X5 | 0 | // (RWV) + // 41 | 0 | 41 | 1 | X4 | 0 | // (RWV) + // 42 | 41 | 1 | 0 | X4 | 0 | + // 43 | 42 | -42 | 0 | X4 | 0 | + // 44 | 0 | 44 | 1 | X5 | 0 | // (RWV) // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | - // 100 | 0 | 101 | 1 | X6 | 0 | // (RVR) - // 101 |101 | 1 | 0 | X6 | 0 | - // 102 |102 | -102 | 0 | X6+1 | 1 | + // 101 | 0 | 101 | 1 | X6 | 0 | // (RVR) + // 102 |101 | 1 | 0 | X6 | 0 | + // 103 |102 | -102 | 0 | X6+1 | 1 | // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | - // 133 | 0 | 134 | 1 | X7 | 0 | // (RWVWR) - // 134 |134 | 1 | 0 | X7 | 0 | - // 135 |135 | 1 | 0 | X7 | 0 | - // 136 |136 | 1 | 0 | X7+1 | 1 | - // 137 |137 | -137 | 0 | X7+1 | 1 | + // 134 | 0 | 134 | 1 | X7 | 0 | // (RWVWR) + // 135 |134 | 1 | 0 | X7 | 0 | + // 136 |135 | 1 | 0 | X7 | 0 | + // 137 |136 | 1 | 0 | X7+1 | 1 | + // 138 |137 | -137 | 0 | X7+1 | 1 | // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | - // 188 | 0 | 0 | 0 | 0 | 0 | // for padding col fixed PC; col fixed DELTA_PC; @@ -112,7 +113,10 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = const int line = i; const int next = i+1; - if (line < tsize[0]) // RV + if (line 
== 0) { // padding + // Do nothing + } + else if (line < tsize[0]) // RV { if (line % 2 == 0) { // pc = 0; diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 606d0dad..59b46bcd 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -12,7 +12,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use sm_common::create_prover_buffer; -use zisk_pil::{MemAlignRomTrace, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{MemAlignRomRow, MemAlignRomTrace, MEM_ALIGN_ROM_AIR_IDS, ZISK_AIRGROUP_ID}; #[derive(Debug, Clone, Copy)] pub enum MemOp { @@ -146,11 +146,19 @@ impl MemAlignRomSM { } } - pub fn calculate_next_pc(op: MemOp, offset: usize, width: usize) -> u64 { - let rows = Self::calculate_rom_rows(op, offset, width); + pub fn calculate_next_pc(&self, op: MemOp, offset: usize, width: usize) -> u64 { + let row_idxs = Self::calculate_rom_rows(op, offset, width); + + // Update the multiplicity + self.update_multiplicity_by_idx(&row_idxs); // The "next" pc is always found on the second row of the program being executed - rows[1] + row_idxs[1] + } + + pub fn update_padding_row(&self, padding_len: u64) { + // Update entry at the padding row (pos = 0) with the given padding length + self.update_multiplicity(&[padding_len]); } pub fn update_multiplicity_by_input(&self, opcode: MemOp, offset: usize, width: usize) { @@ -175,32 +183,34 @@ impl MemAlignRomSM { } pub fn create_air_instance(&self) { - let pctx = self.wcm.get_pctx(); + // Get the contexts + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); + // Get the Mem Align ROM AIR let air_mem_align_rom = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + let air_mem_align_rom_rows = air_mem_align_rom.num_rows(); - // Create the prover buffer - let (mut prover_buffer, offset) = create_prover_buffer( - 
&self.wcm.get_ectx(), - &self.wcm.get_sctx(), - ZISK_AIRGROUP_ID, - MEM_ALIGN_ROM_AIR_IDS[0], - ); + // Create a prover buffer + let (mut prover_buffer, offset) = + create_prover_buffer(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); + // Create the Mem Align ROM trace buffer let mut trace_buffer = MemAlignRomTrace::::map_buffer( &mut prover_buffer, - air_mem_align_rom.num_rows(), + air_mem_align_rom_rows, offset as usize, ) .unwrap(); - let mut multiplicity = self.multiplicity.lock().unwrap(); - - // for row_idx in multiplicity.keys() { - // trace_buffer[*row_idx as usize] = MemAlignRomRow { - // multiplicity: multiplicity - // }; - // } + if let Ok(multiplicity) = self.multiplicity.lock() { + for (row_idx, multiplicity) in multiplicity.iter() { + trace_buffer[*row_idx as usize] = + MemAlignRomRow { multiplicity: F::from_canonical_u64(*multiplicity) }; + } + } info!( "{}: ··· Creating Mem Align ROM instance [{} rows filled 100%]", @@ -208,14 +218,9 @@ impl MemAlignRomSM { self.num_rows, ); - let air_instance = AirInstance::new( - self.wcm.get_sctx(), - ZISK_AIRGROUP_ID, - MEM_ALIGN_ROM_AIR_IDS[0], - None, - prover_buffer, - ); - self.wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, None); + let air_instance = + AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0], None, prover_buffer); + pctx.air_instance_repo.add_air_instance(air_instance, None); } } diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 07d034a3..8fc146e7 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -154,9 +154,6 @@ impl MemAlignSM { let step = input.step; let value = input.value; - // Compute the shift - let shift = ((offset + width) % CHUNK_NUM) as u64; - // Get the aligned address let addr_read = addr >> CHUNK_BITS; @@ -164,7 +161,8 @@ impl MemAlignSM { let value_read = mem_values[phase]; // Get the next pc - let next_pc = 
MemAlignRomSM::::calculate_next_pc(MemOp::OneRead, offset, width); + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::OneRead, offset, width); let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), @@ -245,7 +243,8 @@ impl MemAlignSM { let value_read = mem_values[phase]; // Get the next pc - let next_pc = MemAlignRomSM::::calculate_next_pc(MemOp::OneWrite, offset, width); + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::OneWrite, offset, width); // Compute the write value let value_write = { @@ -385,7 +384,7 @@ impl MemAlignSM { // Get the next pc let next_pc = - MemAlignRomSM::::calculate_next_pc(MemOp::TwoReads, offset, width); + self.mem_align_rom_sm.calculate_next_pc(MemOp::TwoReads, offset, width); let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), @@ -486,15 +485,14 @@ impl MemAlignSM { let width_norm = CHUNK_NUM - offset; let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; - + let mask: u64 = width_bytes << (offset * CHUNK_BITS); - + // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - + // Write zeroes to value_read from offset to offset + width // and add the value to write to the value read - (value_first_read & !mask) | value_to_write }; @@ -520,7 +518,7 @@ impl MemAlignSM { let value = input.value; // Compute the shift - let shift = ((offset + width) % CHUNK_NUM) as u64; + let shift = (offset + width) % CHUNK_NUM; // Get the aligned address let addr_first_read_write = addr >> CHUNK_BITS; @@ -535,15 +533,14 @@ impl MemAlignSM { let width_norm = CHUNK_NUM - offset; let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; - + let mask: u64 = width_bytes << (offset * CHUNK_BITS); - + // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - + // Write zeroes to value_read from offset to offset + width // and add the value to write to the value 
read - (value_first_read & !mask) | value_to_write }; @@ -551,26 +548,27 @@ impl MemAlignSM { let value_second_read = mem_values[1]; // Compute the second write value - let value_second_write = { // TODO: Fix + let value_second_write = { + // TODO: Fix // Normalize the width - let width_norm = CHUNK_NUM - offset; + let width_bytes = (1 << (shift * CHUNK_BITS)) - 1; - let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; - let mask: u64 = width_bytes << (offset * CHUNK_BITS); - + // Get the first width bytes of the unaligned value - let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - + let value_to_write = value & width_bytes; + // Write zeroes to value_read from offset to offset + width // and add the value to write to the value read - (value_second_read & !mask) | value_to_write }; // Get the next pc - let next_pc = - MemAlignRomSM::::calculate_next_pc(MemOp::TwoWrites, offset, width); + let next_pc = self.mem_align_rom_sm.calculate_next_pc( + MemOp::TwoWrites, + offset, + width, + ); // RWVWR let mut first_read_row = MemAlignRow:: { @@ -641,23 +639,36 @@ impl MemAlignSM { value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - first_read_row.sel[i] = F::from_bool(i < offset); + if i < offset { + first_read_row.sel[i] = F::from_bool(true); + } first_write_row.reg[i] = { F::from_canonical_u64( value_first_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - first_write_row.sel[i] = F::from_bool(i >= offset); + if i >= offset { + first_write_row.sel[i] = F::from_bool(true); + } value_row.reg[i] = { - F::from_canonical_u64( - value - & (CHUNK_BITS_MASK - << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), - ) + if i < shift { + second_write_row.reg[i] + } else if i >= offset { + first_write_row.reg[i] + } else { + F::from_canonical_u64( + value + & (CHUNK_BITS_MASK + << (((shift as u64 + pos) % CHUNK_NUM_U64) + * CHUNK_BITS_U64)), + ) + } }; - value_row.sel[i] = F::from_bool(i == offset); + if i == offset { + 
value_row.sel[i] = F::from_bool(true); + } second_write_row.reg[i] = { F::from_canonical_u64( @@ -665,14 +676,18 @@ impl MemAlignSM { & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - second_write_row.sel[i] = F::from_bool(pos < shift); + if i < shift { + second_write_row.sel[i] = F::from_bool(true); + } second_read_row.reg[i] = { F::from_canonical_u64( value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), ) }; - second_read_row.sel[i] = F::from_bool(pos >= shift); + if i >= shift { + second_read_row.sel[i] = F::from_bool(true); + } } // Prove the generated rows @@ -729,8 +744,8 @@ impl MemAlignSM { assert!(rows_len <= air_mem_align_rows); // Get the execution and setup context - let ectx = self.wcm.get_ectx(); - let sctx = self.wcm.get_sctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); // Create a prover buffer let (mut prover_buffer, offset) = @@ -775,18 +790,9 @@ impl MemAlignSM { std.range_check(value, F::from_canonical_u64(multiplicity), range_id); } - // TODO: Treate the ROM multiplicity - - // TODO: Store the padding multiplicity - // let mem_align_rom_sm = self.mem_align_rom_sm.clone(); - // let _padding_size = air_mem_align.num_rows() - rows_processed; - // for i in 0..8 { - // let multiplicity = padding_size as u64; - // let row = MemAlignRomSM::::calculate_rom_row( - // op, offset, width - // ); - // rom_multiplicity[row as usize] += multiplicity; - // } + // Compute the padding multiplicity + let mem_align_rom_sm = self.mem_align_rom_sm.clone(); + mem_align_rom_sm.update_padding_row(padding_size as u64); info!( "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", @@ -799,388 +805,8 @@ impl MemAlignSM { // Add a new Mem Align instance let air_instance = AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0], None, prover_buffer); - wcm.get_pctx().air_instance_repo.add_air_instance(air_instance, None); + pctx.air_instance_repo.add_air_instance(air_instance, None); } - - // #[inline(always)] - // pub fn 
process_input( - // unaligned_input: &ZiskRequiredMemory, - // aligned_inputs: &[ZiskRequiredMemory], - // mem_align_rom_sm: &MemAlignRomSM, - // range_check: &mut HashMap, - // ) -> Vec> { - // // Get the unaligned address - // let addr = unaligned_input.address; - - // // Get the unaligned value - // let value = unaligned_input.value.to_le_bytes(); - - // // Get the unaligned step - // let step = unaligned_input.step; - - // // Get the unaligned width - // let width = unaligned_input.width; - // let width = if width <= CHUNK_NUM_U64 { - // width as usize - // } else { - // panic!("Invalid width={}", width); - // }; - - // // Compute the offset - // let offset = addr % CHUNK_NUM_U64; - // let offset = if offset <= usize::MAX as u64 { - // offset as usize - // } else { - // panic!("Invalid offset={}", offset); - // }; - - // // Compute the shift - // let shift = (offset + width) % CHUNK_NUM; - - // // Get the op to be executed, its size and the pc to jump to - // let op = Self::get_mem_op(&unaligned_input); - // let op_size = MemAlignRomSM::::get_mem_align_op_size(op); - // let next_pc = MemAlignRomSM::::calculate_next_pc(op, offset, width); - - // println!("OP: {:?}", op); - // println!("UNALIGNED INPUT:\n {:?}", unaligned_input); - // println!(" OFFSET: {:?}", offset); - // println!(" value: {:?}", unaligned_input.value.to_le_bytes()); - // println!("ALIGNED INPUTS:"); - // for aligned_input in aligned_inputs { - // println!(" {:?}", aligned_input); - // println!(" value: {:?}", aligned_input.value.to_le_bytes()); - // } - // println!(""); - - // // Initialize and set the rows of the corresponding op - // let mut rows: Vec> = Vec::with_capacity(op_size); - // // TODO: Can I detatch the "shape" of the program from the mem_align and do it in the mem_align_rom? 
- // match op { - // MemOp::OneRead => { - // // RV - // // Sanity check - // assert!(aligned_inputs.len() == 1); - - // // Get the aligned address - // let addr_read = aligned_inputs[0].address; - - // // Get the aligned values - // let value_read = aligned_inputs[0].value.to_le_bytes(); - - // // Get the aligned step - // let step_read = aligned_inputs[0].step; - - // let mut read_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_read), - // addr: F::from_canonical_u64(addr_read), - // // offset: F::from_canonical_u64(0), - // // wr: F::from_bool(false), - // // pc: F::from_canonical_u64(0), - // reset: F::from_bool(true), - // sel_up_to_down: F::from_bool(true), - // ..Default::default() - // }; - - // let mut value_row = MemAlignRow:: { - // step: F::from_canonical_u64(step), - // addr: F::from_canonical_u64(addr), - // offset: F::from_canonical_usize(offset), - // width: F::from_canonical_usize(width), - // // wr: F::from_bool(false), - // pc: F::from_canonical_u64(next_pc), - // // reset: F::from_bool(false), - // sel_prove: F::from_bool(true), - // ..Default::default() - // }; - - // for i in 0..CHUNK_NUM { - // read_row.reg[i] = F::from_canonical_u8(value_read[i]); - // read_row.sel[i] = F::from_bool(true); - - // value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); - // value_row.sel[i] = F::from_bool(i == offset); - - // // Store the range check - // *range_check.entry(read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - // } - - // // Store the rows - // rows.push(read_row); - // rows.push(value_row); - // } - // MemOp::OneWrite => { - // // RWV - // // Sanity check - // assert!(aligned_inputs.len() == 2); - - // // Get the aligned address - // let addr_read_write = aligned_inputs[0].address; - - // // Get the aligned values - // let value_read = aligned_inputs[0].value.to_le_bytes(); - // let value_write = aligned_inputs[1].value.to_le_bytes(); - - // // Get the aligned step - 
// let step_read = aligned_inputs[0].step; - // let step_write = aligned_inputs[1].step; - - // // RWV - // let mut read_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_read), - // addr: F::from_canonical_u64(addr_read_write), - // // offset: F::from_canonical_u64(0), - // width: F::from_canonical_u64(CHUNK_NUM_U64), - // // wr: F::from_bool(false), - // // pc: F::from_canonical_u64(0), - // reset: F::from_bool(true), - // sel_up_to_down: F::from_bool(true), - // ..Default::default() - // }; - - // let mut write_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_write), - // addr: F::from_canonical_u64(addr_read_write), - // // offset: F::from_canonical_u64(0), - // width: F::from_canonical_u64(CHUNK_NUM_U64), - // wr: F::from_bool(true), - // pc: F::from_canonical_u64(next_pc), - // // reset: F::from_bool(false), - // sel_up_to_down: F::from_bool(true), - // ..Default::default() - // }; - - // let mut value_row = MemAlignRow:: { - // step: F::from_canonical_u64(step), - // addr: F::from_canonical_u64(addr), - // offset: F::from_canonical_usize(offset), - // width: F::from_canonical_usize(width), - // // wr: F::from_bool(false), - // pc: F::from_canonical_u64(next_pc + 1), - // // reset: F::from_bool(false), - // sel_prove: F::from_bool(true), - // ..Default::default() - // }; - - // for i in 0..CHUNK_NUM { - // read_row.reg[i] = F::from_canonical_u8(value_read[i]); - // read_row.sel[i] = F::from_bool(i >= width); - - // write_row.reg[i] = F::from_canonical_u8(value_write[i]); - // write_row.sel[i] = F::from_bool(i < width); - - // value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); - // value_row.sel[i] = F::from_bool(i == offset); - - // // Store the range check - // *range_check.entry(read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(write_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - // } - - // // Store the rows - // rows.push(read_row); - // 
rows.push(write_row); - // rows.push(value_row); - // } - // MemOp::TwoReads => { - // // RVR - // // Sanity check - // assert!(aligned_inputs.len() == 2); - - // // Get the aligned address - // let addr_first_read = aligned_inputs[0].address; - // let addr_second_read = aligned_inputs[1].address; - - // // Get the aligned values - // let value_first_read = aligned_inputs[0].value.to_le_bytes(); - // let value_second_read = aligned_inputs[1].value.to_le_bytes(); - - // // Get the aligned step - // let step_first_read = aligned_inputs[0].step; - // let step_second_read = aligned_inputs[1].step; - - // // RVR - // let mut first_read_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_first_read), - // addr: F::from_canonical_u64(addr_first_read), - // // offset: F::from_canonical_u64(0), - // width: F::from_canonical_u64(CHUNK_NUM_U64), - // // wr: F::from_bool(false), - // // pc: F::from_canonical_u64(0), - // reset: F::from_bool(true), - // sel_up_to_down: F::from_bool(true), - // ..Default::default() - // }; - - // let mut value_row = MemAlignRow:: { - // step: F::from_canonical_u64(step), - // addr: F::from_canonical_u64(addr), - // offset: F::from_canonical_usize(offset), - // width: F::from_canonical_usize(width), - // // wr: F::from_bool(false), - // pc: F::from_canonical_u64(next_pc), - // // reset: F::from_bool(false), - // sel_prove: F::from_bool(true), - // ..Default::default() - // }; - - // let mut second_read_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_second_read), - // addr: F::from_canonical_u64(addr_second_read), - // // offset: F::from_canonical_u64(0), - // width: F::from_canonical_u64(CHUNK_NUM_U64), - // // wr: F::from_bool(false), - // pc: F::from_canonical_u64(next_pc + 1), - // // reset: F::from_bool(false), - // sel_down_to_up: F::from_bool(true), - // ..Default::default() - // }; - - // for i in 0..CHUNK_NUM { - // first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); - // first_read_row.sel[i] = 
F::from_bool(true); - - // value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); - // value_row.sel[i] = F::from_bool(i == offset); - - // second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); - // second_read_row.sel[i] = F::from_bool(true); - - // // Store the range check - // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; - // } - - // // Store the rows - // rows.push(first_read_row); - // rows.push(value_row); - // rows.push(second_read_row); - // } - // MemOp::TwoWrites => { - // // RWVWR - // // Sanity check - // assert!(aligned_inputs.len() == 4); - - // // Get the aligned address - // let addr_first_read_write = aligned_inputs[0].address; - // let addr_second_read_write = aligned_inputs[2].address; - - // // Get the aligned values - // // TODO: I do not need to establish an order, I can use the field is_write!!! 
- // let value_first_read = aligned_inputs[0].value.to_le_bytes(); - // let value_first_write = aligned_inputs[1].value.to_le_bytes(); - // let value_second_read = aligned_inputs[2].value.to_le_bytes(); - // let value_second_write = aligned_inputs[3].value.to_le_bytes(); - - // // Get the aligned step - // let step_first_read = aligned_inputs[0].step; - // let step_first_write = aligned_inputs[1].step; - // let step_second_read = aligned_inputs[2].step; - // let step_second_write = aligned_inputs[3].step; - - // // RWVWR - // let mut first_read_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_first_read), - // addr: F::from_canonical_u64(addr_first_read_write), - // // offset: F::from_canonical_u64(0), - // width: F::from_canonical_u64(CHUNK_NUM_U64), - // // wr: F::from_bool(false), - // // pc: F::from_canonical_u64(0), - // reset: F::from_bool(true), - // sel_up_to_down: F::from_bool(true), - // ..Default::default() - // }; - - // let mut first_write_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_first_write), - // addr: F::from_canonical_u64(addr_first_read_write), - // // offset: F::from_canonical_u64(0), - // width: F::from_canonical_u64(CHUNK_NUM_U64), - // wr: F::from_bool(true), - // pc: F::from_canonical_u64(next_pc), - // // reset: F::from_bool(false), - // sel_up_to_down: F::from_bool(true), - // ..Default::default() - // }; - - // let mut value_row = MemAlignRow:: { - // step: F::from_canonical_u64(step), - // addr: F::from_canonical_u64(addr), - // offset: F::from_canonical_usize(offset), - // width: F::from_canonical_usize(width), - // // wr: F::from_bool(false), - // pc: F::from_canonical_u64(next_pc + 1), - // // reset: F::from_bool(false), - // sel_prove: F::from_bool(true), - // ..Default::default() - // }; - - // let mut second_write_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_second_write), - // addr: F::from_canonical_u64(addr_second_read_write), - // // offset: F::from_canonical_u64(0), - // width: 
F::from_canonical_u64(CHUNK_NUM_U64), - // wr: F::from_bool(true), - // pc: F::from_canonical_u64(next_pc + 2), - // // reset: F::from_bool(false), - // sel_down_to_up: F::from_bool(true), - // ..Default::default() - // }; - - // let mut second_read_row = MemAlignRow:: { - // step: F::from_canonical_u64(step_second_read), - // addr: F::from_canonical_u64(addr_second_read_write), - // // offset: F::from_canonical_u64(0), - // width: F::from_canonical_u64(CHUNK_NUM_U64), - // // wr: F::from_bool(false), - // pc: F::from_canonical_u64(next_pc + 3), - // reset: F::from_bool(false), - // sel_down_to_up: F::from_bool(true), - // ..Default::default() - // }; - - // for i in 0..CHUNK_NUM { - // first_read_row.reg[i] = F::from_canonical_u8(value_first_read[i]); - // first_read_row.sel[i] = F::from_bool(i < offset); - - // first_write_row.reg[i] = F::from_canonical_u8(value_first_write[i]); - // first_write_row.sel[i] = F::from_bool(i >= offset); - - // value_row.reg[i] = F::from_canonical_u8(value[(shift + i) % CHUNK_NUM]); - // value_row.sel[i] = F::from_bool(i == offset); - - // second_write_row.reg[i] = F::from_canonical_u8(value_second_write[i]); - // second_write_row.sel[i] = F::from_bool(i < shift); - - // second_read_row.reg[i] = F::from_canonical_u8(value_second_read[i]); - // second_read_row.sel[i] = F::from_bool(i >= shift); - - // // Store the range check - // *range_check.entry(first_read_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(first_write_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(value_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(second_write_row.reg[i]).or_insert(0) += 1; - // *range_check.entry(second_read_row.reg[i]).or_insert(0) += 1; - // } - - // // Store the rows - // rows.push(first_read_row); - // rows.push(first_write_row); - // rows.push(value_row); - // rows.push(second_write_row); - // rows.push(second_read_row); - // } - // } - - // // Update the ROM row multiplicity - // 
mem_align_rom_sm.update_multiplicity_by_input(op, offset, width); - - // // Return successfully - // rows - // } } impl WitnessComponent for MemAlignSM {} From 66d20db0c7adab82b6069207266a77b42b8bb802 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Sun, 10 Nov 2024 04:02:08 +0000 Subject: [PATCH 30/44] Fix errors --- state-machines/mem/src/mem_align_sm.rs | 236 ++++++++++++++----------- 1 file changed, 128 insertions(+), 108 deletions(-) diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 8fc146e7..446f240d 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -44,7 +44,7 @@ pub struct MemAlignSM { // Count of registered predecessors registered_predecessors: AtomicU32, - // Computed rows + // Computed row information rows: Mutex>>, num_computed_rows: Mutex, // TODO: DEBUG!!! @@ -91,7 +91,6 @@ impl MemAlignSM { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { let pctx = self.wcm.get_pctx(); - // TODO: Fix this... 
// If there are remaining rows, generate the last instance if let Ok(mut rows) = self.rows.lock() { // Get the Mem Align AIR @@ -147,7 +146,15 @@ impl MemAlignSM { println!("MEM_VALUES: {:?}", mem_values); println!("PHASE: {:?}\n", phase); - // RV + /* RV with offset=2, width=4 + +----+----+====+====+====+====+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+====+====+====+====+----+----+ + ⇓ + +----+----+====+====+====+====+----+----+ + | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | + +----+----+====+====+====+====+----+----+ + */ assert!(phase == 0); // TODO: Debug mode // Unaligned memory op information thrown into the bus @@ -158,7 +165,7 @@ impl MemAlignSM { let addr_read = addr >> CHUNK_BITS; // Get the aligned value - let value_read = mem_values[phase]; + let read_value = mem_values[phase]; // Get the next pc let next_pc = @@ -188,25 +195,14 @@ impl MemAlignSM { }; for i in 0..CHUNK_NUM { - let pos = i as u64; - - read_row.reg[i] = { - F::from_canonical_u64( - value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(read_value, i, 0)); println!("READ_ROW[{}]: {:?}", i, read_row.reg[i]); - if i >= offset && i <= offset + width { + if i >= offset && i < offset + width { read_row.sel[i] = F::from_bool(true); } - value_row.reg[i] = { - F::from_canonical_u64( - value - & (CHUNK_BITS_MASK - << (((offset as u64 + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), - ) - }; + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); println!("VALUE_ROW[{}]: {:?}", i, value_row.reg[i]); if i == offset { value_row.sel[i] = F::from_bool(true); @@ -226,21 +222,30 @@ impl MemAlignSM { println!("MEM_VALUES: {:?}", mem_values); println!("PHASE: {:?}\n", phase); - // RWV + /* RWV with offset=3, width=4 + +----+----+----+====+====+====+====+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+====+====+====+====+----+ + ⇓ + 
+----+----+----+====+====+====+====+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +----+----+----+====+====+====+====+----+ + ⇓ + +----+----+----+====+====+====+====+----+ + | V5 | V6 | V7 | V0 | V1 | V2 | V3 | V4 | + +----+----+----+====+====+====+====+----+ + */ assert!(phase == 0); // TODO: Debug mode // Unaligned memory op information thrown into the bus let step = input.step; let value = input.value; - // Compute the shift - let shift = ((offset + width - 1) % CHUNK_NUM) as u64; - // Get the aligned address let addr_read = addr >> CHUNK_BITS; // Get the aligned value - let value_read = mem_values[phase]; + let read_value = mem_values[phase]; // Get the next pc let next_pc = @@ -259,10 +264,10 @@ impl MemAlignSM { let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); println!("VALUE_TO_WRITE: {:#X}", value_to_write); - // Write zeroes to value_read from offset to offset + width + // Write zeroes to read_value from offset to offset + width // and add the value to write to the value read - let result = (value_read & !mask) | value_to_write; + let result = (read_value & !mask) | value_to_write; println!("RESULT: {:#X}", result); result }; @@ -304,37 +309,23 @@ impl MemAlignSM { }; for i in 0..CHUNK_NUM { - let pos = i as u64; - - read_row.reg[i] = { - F::from_canonical_u64( - value_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(read_value, i, 0)); println!("READ_ROW[{}]: {:?}", i, read_row.reg[i]); - if i < offset || i > offset + width { + if i < offset || i >= offset + width { read_row.sel[i] = F::from_bool(true); } - write_row.reg[i] = { - F::from_canonical_u64( - value_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; + write_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_write, i, 0)); println!("WRITE_ROW[{}]: {:?}", i, write_row.reg[i]); - if i >= offset && i <= offset + width { + if i >= offset && i < offset + width { write_row.sel[i] = 
F::from_bool(true); } value_row.reg[i] = { - if i >= offset && i <= offset + width { + if i >= offset && i < offset + width { write_row.reg[i] } else { - F::from_canonical_u64( - value - & (CHUNK_BITS_MASK - << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), - ) + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)) } }; println!("VALUE_ROW[{}]: {:?}", i, value_row.reg[i]); @@ -349,7 +340,19 @@ impl MemAlignSM { MemAlignResponse { more_address: false, step, value: Some(value_write) } } (false, true) => { - // RVR + /* RVR with offset=5, width=8 + +----+----+----+----+----+====+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+====+====+====+ + ⇓ + +====+====+====+====+====+====+====+====+ + | V3 | V4 | V5 | V6 | V7 | V0 | V1 | V2 | + +====+====+====+====+====+====+====+====+ + ⇓ + +====+====+====+====+====+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+====+====+====+----+----+----+ + */ assert!(phase == 0 || phase == 1); // TODO: Debug mode match phase { @@ -371,8 +374,8 @@ impl MemAlignSM { let step = input.step; let value = input.value; - // Compute the shift - let shift = ((offset + width) % CHUNK_NUM) as u64; + // Compute the remaining bytes + let rem_bytes = (offset + width) % CHUNK_NUM; // Get the aligned address let addr_first_read = addr >> CHUNK_BITS; @@ -423,34 +426,21 @@ impl MemAlignSM { }; for i in 0..CHUNK_NUM { - let pos = i as u64; - - first_read_row.reg[i] = { - F::from_canonical_u64( - value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; - if i >= offset && i <= offset + width { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i >= offset { first_read_row.sel[i] = F::from_bool(true); } - value_row.reg[i] = { - F::from_canonical_u64( - value - & (CHUNK_BITS_MASK - << (((shift + pos) % CHUNK_NUM_U64) * CHUNK_BITS_U64)), - ) - }; + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, 
CHUNK_NUM - offset)); if i == offset { value_row.sel[i] = F::from_bool(true); } - second_read_row.reg[i] = { - F::from_canonical_u64( - value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; - if pos < shift { + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i < rem_bytes { second_read_row.sel[i] = F::from_bool(true); } } @@ -464,7 +454,27 @@ impl MemAlignSM { } } (true, true) => { - // RWVWR + /* RWVWR with offset=6, width=4 + +----+----+----+----+----+----+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+----+====+====+ + ⇓ + +----+----+----+----+----+----+====+====+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +----+----+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+====+====+ + | V2 | V3 | V4 | V5 | V6 | V7 | V0 | V1 | + +====+====+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+----+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +====+====+----+----+----+----+----+----+ + ⇓ + +====+====+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+----+----+----+----+----+----+ + */ assert!(phase == 0 || phase == 1); // TODO: Debug mode match phase { @@ -491,7 +501,7 @@ impl MemAlignSM { // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - // Write zeroes to value_read from offset to offset + width + // Write zeroes to read_value from offset to offset + width // and add the value to write to the value read (value_first_read & !mask) | value_to_write }; @@ -504,6 +514,27 @@ impl MemAlignSM { } // Otherwise, do the RWVRW 1 => { + /* RWVWR with offset=6, width=4 + +----+----+----+----+----+----+====+====+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +----+----+----+----+----+----+====+====+ + ⇓ + +----+----+----+----+----+----+====+====+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + 
+----+----+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+====+====+ + | V2 | V3 | V4 | V5 | V6 | V7 | V0 | V1 | + +====+====+----+----+----+----+====+====+ + ⇓ + +====+====+----+----+----+----+----+----+ + | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | + +====+====+----+----+----+----+----+----+ + ⇓ + +====+====+----+----+----+----+----+----+ + | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | + +====+====+----+----+----+----+----+----+ + */ println!("TWO WRITES"); println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 4); drop(num_rows); @@ -518,7 +549,7 @@ impl MemAlignSM { let value = input.value; // Compute the shift - let shift = (offset + width) % CHUNK_NUM; + let rem_bytes = (offset + width) % CHUNK_NUM; // Get the aligned address let addr_first_read_write = addr >> CHUNK_BITS; @@ -539,7 +570,7 @@ impl MemAlignSM { // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - // Write zeroes to value_read from offset to offset + width + // Write zeroes to read_value from offset to offset + width // and add the value to write to the value read (value_first_read & !mask) | value_to_write }; @@ -551,14 +582,14 @@ impl MemAlignSM { let value_second_write = { // TODO: Fix // Normalize the width - let width_bytes = (1 << (shift * CHUNK_BITS)) - 1; + let width_bytes = (1 << (rem_bytes * CHUNK_BITS)) - 1; let mask: u64 = width_bytes << (offset * CHUNK_BITS); // Get the first width bytes of the unaligned value let value_to_write = value & width_bytes; - // Write zeroes to value_read from offset to offset + width + // Write zeroes to read_value from offset to offset + width // and add the value to write to the value read (value_second_read & !mask) | value_to_write }; @@ -632,60 +663,44 @@ impl MemAlignSM { }; for i in 0..CHUNK_NUM { - let pos = i as u64; - - first_read_row.reg[i] = { - F::from_canonical_u64( - value_first_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; + 
first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); if i < offset { first_read_row.sel[i] = F::from_bool(true); } - first_write_row.reg[i] = { - F::from_canonical_u64( - value_first_write & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; + first_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_write, i, 0)); if i >= offset { first_write_row.sel[i] = F::from_bool(true); } value_row.reg[i] = { - if i < shift { + if i < rem_bytes { second_write_row.reg[i] } else if i >= offset { first_write_row.reg[i] } else { - F::from_canonical_u64( - value - & (CHUNK_BITS_MASK - << (((shift as u64 + pos) % CHUNK_NUM_U64) - * CHUNK_BITS_U64)), - ) + F::from_canonical_u64(Self::get_byte( + value_first_write, + i, + CHUNK_NUM - offset, + )) } }; if i == offset { value_row.sel[i] = F::from_bool(true); } - second_write_row.reg[i] = { - F::from_canonical_u64( - value_second_write - & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; - if i < shift { + second_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); + if i < rem_bytes { second_write_row.sel[i] = F::from_bool(true); } - second_read_row.reg[i] = { - F::from_canonical_u64( - value_second_read & (CHUNK_BITS_MASK << (pos * CHUNK_BITS_U64)), - ) - }; - if i >= shift { + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i >= rem_bytes { second_read_row.sel[i] = F::from_bool(true); } } @@ -711,6 +726,11 @@ impl MemAlignSM { } } + fn get_byte(value: u64, index: usize, offset: usize) -> u64 { + let chunk = (offset + index) % CHUNK_NUM; + (value >> (chunk * CHUNK_BITS)) & CHUNK_BITS_MASK + } + pub fn prove(&self, computed_rows: &[MemAlignRow]) { if let Ok(mut rows) = self.rows.lock() { rows.extend_from_slice(computed_rows); From 2b511ece17873f4579f1131eeea146cfac5b2271 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Mon, 11 Nov 2024 06:41:45 +0000 Subject: [PATCH 31/44] 
mem align working --- state-machines/mem/src/mem_align_sm.rs | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 446f240d..fec8a776 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -425,6 +425,10 @@ impl MemAlignSM { ..Default::default() }; + println!("VALUE_FIRST_READ: {:?}", value_first_read.to_le_bytes()); + println!("VALUE: {:?}", value.to_le_bytes()); + println!("VALUE_SECOND_READ: {:?}", value_second_read.to_le_bytes()); + for i in 0..CHUNK_NUM { first_read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); @@ -433,7 +437,7 @@ impl MemAlignSM { } value_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + F::from_canonical_u64(Self::get_byte(value, i, offset)); if i == offset { value_row.sel[i] = F::from_bool(true); } @@ -580,16 +584,15 @@ impl MemAlignSM { // Compute the second write value let value_second_write = { - // TODO: Fix // Normalize the width - let width_bytes = (1 << (rem_bytes * CHUNK_BITS)) - 1; + let width_norm = CHUNK_NUM - offset; - let mask: u64 = width_bytes << (offset * CHUNK_BITS); + let mask: u64 = (1 << (rem_bytes * CHUNK_BITS)) - 1; // Get the first width bytes of the unaligned value - let value_to_write = value & width_bytes; + let value_to_write = (value >> width_norm * CHUNK_BITS) & mask; - // Write zeroes to read_value from offset to offset + width + // Write zeroes to read_value from 0 to offset + width // and add the value to write to the value read (value_second_read & !mask) | value_to_write }; @@ -662,6 +665,11 @@ impl MemAlignSM { ..Default::default() }; + println!("VALUE_FIRST_READ: {:?}", value_first_read.to_le_bytes()); + println!("VALUE_FIRST_WRITE: {:?}", value_first_write.to_le_bytes()); + println!("VALUE: {:?}", value.to_le_bytes()); + println!("VALUE_SECOND_WRITE: {:?}", 
value_second_write.to_le_bytes()); + println!("VALUE_SECOND_READ: {:?}", value_second_read.to_le_bytes()); for i in 0..CHUNK_NUM { first_read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); @@ -682,7 +690,7 @@ impl MemAlignSM { first_write_row.reg[i] } else { F::from_canonical_u64(Self::get_byte( - value_first_write, + value, i, CHUNK_NUM - offset, )) From c5a86830c0409aa78ce7ed27326d7c6f1e5e15be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Wed, 20 Nov 2024 17:26:00 +0000 Subject: [PATCH 32/44] CArgo fmt --- core/src/zisk_required_operation.rs | 3 +- emulator/src/emu.rs | 28 +++++++-------- pil/src/pil_helpers/traces.rs | 2 -- state-machines/mem/src/mem_proxy.rs | 53 ++++++++++++++++------------- state-machines/mem/src/mem_sm.rs | 4 +-- state-machines/rom/src/rom.rs | 11 ++++-- witness-computation/src/executor.rs | 11 ++++-- 7 files changed, 62 insertions(+), 50 deletions(-) diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index 7702c71a..41056a6a 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -1,5 +1,4 @@ -use std::collections::HashMap; -use std::fmt; +use std::{collections::HashMap, fmt}; #[derive(Clone)] pub struct ZiskRequiredOperation { diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index 3e3bfc2e..10da5f38 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -488,9 +488,9 @@ impl<'a> Emu<'a> { } // Log emulation step, if requested - if options.print_step.is_some() - && (options.print_step.unwrap() != 0) - && ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) + if options.print_step.is_some() && + (options.print_step.unwrap() != 0) && + ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) { println!("step={}", self.ctx.inst_ctx.step); } @@ -689,9 +689,9 @@ impl<'a> Emu<'a> { // Increment step counter self.ctx.inst_ctx.step += 1; - if self.ctx.inst_ctx.end - || ((self.ctx.inst_ctx.step - 
self.ctx.last_callback_step) - == self.ctx.callback_steps) + if self.ctx.inst_ctx.end || + ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) == + self.ctx.callback_steps) { // In run() we have checked the callback consistency with ctx.do_callback let callback = callback.as_ref().unwrap(); @@ -903,11 +903,11 @@ impl<'a> Emu<'a> { let mut current_box_id = 0; let mut current_step_idx = loop { - if current_box_id == vec_traces.len() - 1 - || vec_traces[current_box_id + 1].start_state.step >= emu_trace_start.step + if current_box_id == vec_traces.len() - 1 || + vec_traces[current_box_id + 1].start_state.step >= emu_trace_start.step { - break emu_trace_start.step as usize - - vec_traces[current_box_id].start_state.step as usize; + break emu_trace_start.step as usize - + vec_traces[current_box_id].start_state.step as usize; } current_box_id += 1; }; @@ -1018,8 +1018,8 @@ impl<'a> Emu<'a> { let b = [inst_ctx.b & 0xFFFFFFFF, (inst_ctx.b >> 32) & 0xFFFFFFFF]; let c = [inst_ctx.c & 0xFFFFFFFF, (inst_ctx.c >> 32) & 0xFFFFFFFF]; - let addr1 = (inst.b_offset_imm0 as i64 - + if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; + let addr1 = (inst.b_offset_imm0 as i64 + + if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; let jmp_offset1 = if inst.jmp_offset1 >= 0 { F::from_canonical_u64(inst.jmp_offset1 as u64) @@ -1097,8 +1097,8 @@ impl<'a> Emu<'a> { m32: F::from_bool(inst.m32), addr1: F::from_canonical_u64(addr1), __debug_operation_bus_enabled: F::from_bool( - inst.op_type == ZiskOperationType::Binary - || inst.op_type == ZiskOperationType::BinaryE, + inst.op_type == ZiskOperationType::Binary || + inst.op_type == ZiskOperationType::BinaryE, ), } } diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index da9b392c..1545cdca 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -46,5 +46,3 @@ trace!(SpecifiedRangesRow, SpecifiedRangesTrace { trace!(U8AirRow, U8AirTrace { mul: F, }); - - diff 
--git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 612c2e29..482ede2a 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,8 +1,10 @@ -use std::collections::VecDeque; -use std::fmt; -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, +use std::{ + collections::VecDeque, + fmt, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, }; use crate::{MemAlignResponse, MemAlignRomSM, MemAlignSM, MemSM}; @@ -127,12 +129,13 @@ impl MemProxy { let aligned_mem_address = mem_op.address & MEM_ADDR_MASK; aligned_mem_address == mem_op.address && mem_op.width == MEM_BYTES } - /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible situations: + /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible + /// situations: /// 1) read, only on single mem_op is pushed /// 2) read+write, two mem_op are pushed, one read and one write. /// - /// This process is used for each aligned memory address, means that the "second part" of non aligned memory - /// operation is processed on addr + MEM_BYTES. + /// This process is used for each aligned memory address, means that the "second part" of non + /// aligned memory operation is processed on addr + MEM_BYTES. fn push_mem_align_op( &self, mem_addr: u64, @@ -221,16 +224,17 @@ impl MemProxy { continue; } - // Check if there are open mem align operations to be processed in this moment. Two possible - // conditions to process open mem align operations: + // Check if there are open mem align operations to be processed in this moment. Two + // possible conditions to process open mem align operations: // 1) the address of open operation is less than the aligned address. 
- // 2) the address of open operation is equal to the aligned address, but the step of the open + // 2) the address of open operation is equal to the aligned address, but the step of the + // open // operation is less than the step of the current operation. - while open_mem_align_ops.len() > 0 - && (open_mem_align_ops[0].address < aligned_mem_address - || (open_mem_align_ops[0].address == aligned_mem_address - && open_mem_align_ops[0].mem_op.step < mem_op.step)) + while open_mem_align_ops.len() > 0 && + (open_mem_align_ops[0].address < aligned_mem_address || + (open_mem_align_ops[0].address == aligned_mem_address && + open_mem_align_ops[0].mem_op.step < mem_op.step)) { let open_op = open_mem_align_ops.pop_front().unwrap(); let mem_value = if open_op.address == last_addr { last_value } else { 0 }; @@ -258,8 +262,8 @@ impl MemProxy { last_addr = open_op.address; // check if need to flush the inputs of the module - if (mem_module_inputs[mem_module_id].len() as u64) - >= self.modules_data[mem_module_id].flush_input_size + if (mem_module_inputs[mem_module_id].len() as u64) >= + self.modules_data[mem_module_id].flush_input_size { self.modules[mem_module_id].send_inputs(&mut mem_module_inputs[mem_module_id]); } @@ -267,7 +271,8 @@ impl MemProxy { aligned_mem_address = mem_op.address & MEM_ADDR_MASK; - // check if the aligned address is the last address to avoid processing the last fake mem_op + // check if the aligned address is the last address to avoid processing the last fake + // mem_op if aligned_mem_address == MEM_ADDR_MASK { assert!( open_mem_align_ops.len() == 0, @@ -310,8 +315,8 @@ impl MemProxy { } // check if need to flush the inputs of the module - if (mem_module_inputs[mem_module_id].len() as u64) - >= self.modules_data[mem_module_id].flush_input_size + if (mem_module_inputs[mem_module_id].len() as u64) >= + self.modules_data[mem_module_id].flush_input_size { self.modules[mem_module_id].send_inputs(&mut mem_module_inputs[mem_module_id]); } @@ -350,8 +355,8 @@ 
fn mem_align_call( more_address: double_address, step: mem_op.step + 1, value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) - | ((mem_op.value & mask) << offset), + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) | + ((mem_op.value & mask) << offset), ), } } else { @@ -359,8 +364,8 @@ fn mem_align_call( more_address: false, step: mem_op.step + 1, value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64))) - | ((mem_op.value & mask) >> (128 - offset - width)), + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64))) | + ((mem_op.value & mask) >> (128 - offset - width)), ), } } diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index fe54e310..50dc029b 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -202,8 +202,8 @@ impl MemSM { let addr_changes = trace[i - 1].addr != trace[i].addr; trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; - let same_value = trace[i - 1].value[0] == trace[i].value[0] - && trace[i - 1].value[1] == trace[i].value[1]; + let same_value = trace[i - 1].value[0] == trace[i].value[0] && + trace[i - 1].value[1] == trace[i].value[1]; trace[i].same_value = if same_value { F::one() } else { F::zero() }; let first_addr_access_is_read = addr_changes && !mem_op.is_write; diff --git a/state-machines/rom/src/rom.rs b/state-machines/rom/src/rom.rs index 6970655d..db45391c 100644 --- a/state-machines/rom/src/rom.rs +++ b/state-machines/rom/src/rom.rs @@ -41,8 +41,13 @@ impl RomSM { let main_trace_len = self.wcm.get_pctx().pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows(); - let prover_buffer = - Self::compute_trace_rom(rom, buffer_allocator, &sctx, pc_histogram, main_trace_len as u64)?; + let prover_buffer = Self::compute_trace_rom( + rom, + buffer_allocator, + &sctx, + pc_histogram, + main_trace_len as u64, + )?; let air_instance = AirInstance::new(sctx.clone(), ZISK_AIRGROUP_ID, 
ROM_AIR_IDS[0], None, prover_buffer); @@ -179,4 +184,4 @@ impl RomSM { } } -impl WitnessComponent for RomSM {} \ No newline at end of file +impl WitnessComponent for RomSM {} diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 48e3fe56..9adf2504 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -83,7 +83,8 @@ impl ZiskExecutor { // TODO - If there is more than one Main AIR available, the MAX_ACCUMULATED will be the one // with the highest num_rows. It has to be a power of 2. - let main_sm = MainSM::new(wcm.clone(), mem_proxy_sm.clone(), arith_sm.clone(), binary_sm.clone()); + let main_sm = + MainSM::new(wcm.clone(), mem_proxy_sm.clone(), arith_sm.clone(), binary_sm.clone()); Self { zisk_rom, main_sm, rom_sm, mem_proxy_sm, binary_sm, arith_sm } } @@ -197,7 +198,11 @@ impl ZiskExecutor { // ---------------------------------------------- let mem_thread = thread::spawn({ let mem_proxy_sm = self.mem_proxy_sm.clone(); - move || mem_proxy_sm.prove(&mut mem_required).expect("Error during Memory witness computation") + move || { + mem_proxy_sm + .prove(&mut mem_required) + .expect("Error during Memory witness computation") + } }); // ROM State Machine @@ -290,4 +295,4 @@ impl ZiskExecutor { self.binary_sm.unregister_predecessor(); // self.arith_sm.register_predecessor(scope); } -} \ No newline at end of file +} From 99f4b8ca1c571922901c2e099dfa4c054265ee36 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 21 Nov 2024 06:28:30 +0000 Subject: [PATCH 33/44] WIP memory proxy --- core/src/zisk_required_operation.rs | 6 +- emulator/src/emu.rs | 14 +- pil/src/pil_helpers/traces.rs | 2 - state-machines/common/src/field.rs | 8 + state-machines/main/src/main_sm.rs | 33 ++ state-machines/mem/src/lib.rs | 4 + state-machines/mem/src/mem_align_sm.rs | 83 ++-- state-machines/mem/src/mem_constants.rs | 12 + state-machines/mem/src/mem_helpers.rs | 65 +++ state-machines/mem/src/mem_proxy.rs | 579 
+++++++++++++----------- state-machines/mem/src/mem_sm.rs | 4 +- witness-computation/src/executor.rs | 26 +- 12 files changed, 514 insertions(+), 322 deletions(-) create mode 100644 state-machines/common/src/field.rs create mode 100644 state-machines/mem/src/mem_constants.rs create mode 100644 state-machines/mem/src/mem_helpers.rs diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index 7702c71a..f518824e 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -11,10 +11,10 @@ pub struct ZiskRequiredOperation { #[derive(Clone)] pub struct ZiskRequiredMemory { - pub step: u64, + pub address: u32, pub is_write: bool, - pub address: u64, - pub width: u64, + pub width: u8, + pub step: u64, pub value: u64, } diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index 3e3bfc2e..9c073204 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -111,7 +111,7 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, is_write: false, - address: addr, + address: addr as u32, width: 8, value: self.ctx.inst_ctx.a, }; @@ -185,7 +185,7 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, is_write: false, - address: addr, + address: addr as u32, width: 8, value: self.ctx.inst_ctx.b, }; @@ -204,8 +204,8 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, is_write: false, - address: addr, - width: instruction.ind_width, + address: addr as u32, + width: instruction.ind_width as u8, value: self.ctx.inst_ctx.b, }; emu_mem.push(required_memory); @@ -284,7 +284,7 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, is_write: true, - address: addr as u64, + address: addr as u32, width: 8, value: val as u64, }; @@ -306,8 +306,8 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, is_write: true, - address: 
addr as u64, - width: instruction.ind_width, + address: addr as u32, + width: instruction.ind_width as u8, value: val as u64, }; emu_mem.push(required_memory); diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index da9b392c..1545cdca 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -46,5 +46,3 @@ trace!(SpecifiedRangesRow, SpecifiedRangesTrace { trace!(U8AirRow, U8AirTrace { mul: F, }); - - diff --git a/state-machines/common/src/field.rs b/state-machines/common/src/field.rs new file mode 100644 index 00000000..55d2c919 --- /dev/null +++ b/state-machines/common/src/field.rs @@ -0,0 +1,8 @@ +pub fn i64_to_u64_field(value: i64) -> u64 { + const PRIME_MINUS_ONE: u64 = 0xFFFF_FFFF_0000_0000; + if value >= 0 { + value as u64 + } else { + PRIME_MINUS_ONE - (0xFFFF_FFFF_FFFF_FFFF - value as u64) + } +} diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index ca3d5c0a..000f830a 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -150,6 +150,39 @@ impl MainSM { segment_trace.steps[slice_start..slice_end].iter().enumerate() { partial_trace[i] = emu.step_slice_full_trace(emu_trace_step); + // if partial_trace[i].a_src_mem == F::one() { + // println!( + // "A=MEM_OP_RD({}) [{},{}] PC:{}", + // partial_trace[i].a_offset_imm0, + // partial_trace[i].a[0], + // partial_trace[i].a[1], + // partial_trace[i].pc + // ); + // } + // if partial_trace[i].b_src_mem == F::one() || partial_trace[i].b_src_ind == F::one() + // { + // println!( + // "B=MEM_OP_RD({0}) [{1},{2}] PC:{3}", + // partial_trace[i].addr1, + // partial_trace[i].b[0], + // partial_trace[i].b[1], + // partial_trace[i].pc + // ); + // } + // if partial_trace[i].b_src_mem == F::one() || partial_trace[i].b_src_ind == F::one() + // { + // println!( + // "MEM_OP_WR({}) [{}, {}] PC:{}", + // partial_trace[i].store_offset + // + partial_trace[i].store_ind * partial_trace[i].a[0], + // 
partial_trace[i].store_ra + // * (partial_trace[i].pc + partial_trace[i].jmp_offset2 + // - partial_trace[i].c[0]) + // + partial_trace[i].c[0], + // (F::one() - partial_trace[i].store_ra) * partial_trace[i].c[1], + // partial_trace[i].pc + // ); + // } } // if there are steps in the chunk update last row diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index f117ca7c..ab1d4209 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,9 +1,13 @@ mod mem_align_rom_sm; mod mem_align_sm; +mod mem_constants; +mod mem_helpers; mod mem_proxy; mod mem_sm; pub use mem_align_rom_sm::*; pub use mem_align_sm::*; +pub use mem_constants::*; +pub use mem_helpers::*; pub use mem_proxy::*; pub use mem_sm::*; diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 07d034a3..12ca9dc0 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -27,7 +27,7 @@ const CHUNK_BITS_U64: u64 = CHUNK_BITS as u64; const OFFSET_MASK: u64 = CHUNK_NUM_U64 - 1; const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; -const ALLOWED_WIDTHS: [u64; 4] = [1, 2, 4, 8]; +const ALLOWED_WIDTHS: [u8; 4] = [1, 2, 4, 8]; pub struct MemAlignResponse { pub more_address: bool, @@ -129,7 +129,7 @@ impl MemAlignSM { }; // Compute the offset - let offset = addr & OFFSET_MASK; + let offset = addr as u64 & OFFSET_MASK; let offset = if offset <= usize::MAX as u64 { offset as usize } else { @@ -168,7 +168,7 @@ impl MemAlignSM { let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr_read), + addr: F::from_canonical_u32(addr_read), // offset: F::from_canonical_u64(0), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), @@ -179,7 +179,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), + addr: F::from_canonical_u32(addr), offset: 
F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -270,7 +270,7 @@ impl MemAlignSM { let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr_read), + addr: F::from_canonical_u32(addr_read), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -282,7 +282,7 @@ impl MemAlignSM { let mut write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), - addr: F::from_canonical_u64(addr_read), + addr: F::from_canonical_u32(addr_read), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), @@ -294,7 +294,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), + addr: F::from_canonical_u32(addr), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -389,7 +389,7 @@ impl MemAlignSM { let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr_first_read), + addr: F::from_canonical_u32(addr_first_read), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -401,7 +401,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), + addr: F::from_canonical_u32(addr), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -413,7 +413,7 @@ impl MemAlignSM { let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr_second_read), + addr: F::from_canonical_u32(addr_second_read), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -486,15 +486,15 @@ impl MemAlignSM { let width_norm = CHUNK_NUM - 
offset; let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; - + let mask: u64 = width_bytes << (offset * CHUNK_BITS); - + // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - + // Write zeroes to value_read from offset to offset + width // and add the value to write to the value read - + (value_first_read & !mask) | value_to_write }; @@ -535,15 +535,15 @@ impl MemAlignSM { let width_norm = CHUNK_NUM - offset; let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; - + let mask: u64 = width_bytes << (offset * CHUNK_BITS); - + // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - + // Write zeroes to value_read from offset to offset + width // and add the value to write to the value read - + (value_first_read & !mask) | value_to_write }; @@ -551,20 +551,21 @@ impl MemAlignSM { let value_second_read = mem_values[1]; // Compute the second write value - let value_second_write = { // TODO: Fix + let value_second_write = { + // TODO: Fix // Normalize the width let width_norm = CHUNK_NUM - offset; let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; - + let mask: u64 = width_bytes << (offset * CHUNK_BITS); - + // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - + // Write zeroes to value_read from offset to offset + width // and add the value to write to the value read - + (value_second_read & !mask) | value_to_write }; @@ -575,7 +576,7 @@ impl MemAlignSM { // RWVWR let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr_first_read_write), + addr: F::from_canonical_u32(addr_first_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -587,7 +588,7 @@ impl MemAlignSM { let mut first_write_row = MemAlignRow:: { step: 
F::from_canonical_u64(step + 1), - addr: F::from_canonical_u64(addr_first_read_write), + addr: F::from_canonical_u32(addr_first_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), @@ -599,7 +600,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), + addr: F::from_canonical_u32(addr), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -611,7 +612,7 @@ impl MemAlignSM { let mut second_write_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr_second_read_write), + addr: F::from_canonical_u32(addr_second_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), wr: F::from_bool(true), @@ -623,7 +624,7 @@ impl MemAlignSM { let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), - addr: F::from_canonical_u64(addr_second_read_write), + addr: F::from_canonical_u32(addr_second_read_write), // offset: F::from_canonical_u64(0), width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(false), @@ -873,7 +874,7 @@ impl MemAlignSM { // let mut read_row = MemAlignRow:: { // step: F::from_canonical_u64(step_read), - // addr: F::from_canonical_u64(addr_read), + // addr: F::from_canonical_u32(addr_read), // // offset: F::from_canonical_u64(0), // // wr: F::from_bool(false), // // pc: F::from_canonical_u64(0), @@ -884,7 +885,7 @@ impl MemAlignSM { // let mut value_row = MemAlignRow:: { // step: F::from_canonical_u64(step), - // addr: F::from_canonical_u64(addr), + // addr: F::from_canonical_u32(addr), // offset: F::from_canonical_usize(offset), // width: F::from_canonical_usize(width), // // wr: F::from_bool(false), @@ -929,7 +930,7 @@ impl MemAlignSM { // // RWV // let mut read_row = MemAlignRow:: { // step: F::from_canonical_u64(step_read), - // addr: 
F::from_canonical_u64(addr_read_write), + // addr: F::from_canonical_u32(addr_read_write), // // offset: F::from_canonical_u64(0), // width: F::from_canonical_u64(CHUNK_NUM_U64), // // wr: F::from_bool(false), @@ -941,7 +942,7 @@ impl MemAlignSM { // let mut write_row = MemAlignRow:: { // step: F::from_canonical_u64(step_write), - // addr: F::from_canonical_u64(addr_read_write), + // addr: F::from_canonical_u32(addr_read_write), // // offset: F::from_canonical_u64(0), // width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(true), @@ -953,7 +954,7 @@ impl MemAlignSM { // let mut value_row = MemAlignRow:: { // step: F::from_canonical_u64(step), - // addr: F::from_canonical_u64(addr), + // addr: F::from_canonical_u32(addr), // offset: F::from_canonical_usize(offset), // width: F::from_canonical_usize(width), // // wr: F::from_bool(false), @@ -1004,7 +1005,7 @@ impl MemAlignSM { // // RVR // let mut first_read_row = MemAlignRow:: { // step: F::from_canonical_u64(step_first_read), - // addr: F::from_canonical_u64(addr_first_read), + // addr: F::from_canonical_u32(addr_first_read), // // offset: F::from_canonical_u64(0), // width: F::from_canonical_u64(CHUNK_NUM_U64), // // wr: F::from_bool(false), @@ -1016,7 +1017,7 @@ impl MemAlignSM { // let mut value_row = MemAlignRow:: { // step: F::from_canonical_u64(step), - // addr: F::from_canonical_u64(addr), + // addr: F::from_canonical_u32(addr), // offset: F::from_canonical_usize(offset), // width: F::from_canonical_usize(width), // // wr: F::from_bool(false), @@ -1028,7 +1029,7 @@ impl MemAlignSM { // let mut second_read_row = MemAlignRow:: { // step: F::from_canonical_u64(step_second_read), - // addr: F::from_canonical_u64(addr_second_read), + // addr: F::from_canonical_u32(addr_second_read), // // offset: F::from_canonical_u64(0), // width: F::from_canonical_u64(CHUNK_NUM_U64), // // wr: F::from_bool(false), @@ -1084,7 +1085,7 @@ impl MemAlignSM { // // RWVWR // let mut first_read_row = MemAlignRow:: { // 
step: F::from_canonical_u64(step_first_read), - // addr: F::from_canonical_u64(addr_first_read_write), + // addr: F::from_canonical_u32(addr_first_read_write), // // offset: F::from_canonical_u64(0), // width: F::from_canonical_u64(CHUNK_NUM_U64), // // wr: F::from_bool(false), @@ -1096,7 +1097,7 @@ impl MemAlignSM { // let mut first_write_row = MemAlignRow:: { // step: F::from_canonical_u64(step_first_write), - // addr: F::from_canonical_u64(addr_first_read_write), + // addr: F::from_canonical_u32(addr_first_read_write), // // offset: F::from_canonical_u64(0), // width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(true), @@ -1108,7 +1109,7 @@ impl MemAlignSM { // let mut value_row = MemAlignRow:: { // step: F::from_canonical_u64(step), - // addr: F::from_canonical_u64(addr), + // addr: F::from_canonical_u32(addr), // offset: F::from_canonical_usize(offset), // width: F::from_canonical_usize(width), // // wr: F::from_bool(false), @@ -1120,7 +1121,7 @@ impl MemAlignSM { // let mut second_write_row = MemAlignRow:: { // step: F::from_canonical_u64(step_second_write), - // addr: F::from_canonical_u64(addr_second_read_write), + // addr: F::from_canonical_u32(addr_second_read_write), // // offset: F::from_canonical_u64(0), // width: F::from_canonical_u64(CHUNK_NUM_U64), // wr: F::from_bool(true), @@ -1132,7 +1133,7 @@ impl MemAlignSM { // let mut second_read_row = MemAlignRow:: { // step: F::from_canonical_u64(step_second_read), - // addr: F::from_canonical_u64(addr_second_read_write), + // addr: F::from_canonical_u32(addr_second_read_write), // // offset: F::from_canonical_u64(0), // width: F::from_canonical_u64(CHUNK_NUM_U64), // // wr: F::from_bool(false), diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs new file mode 100644 index 00000000..4e177ee3 --- /dev/null +++ b/state-machines/mem/src/mem_constants.rs @@ -0,0 +1,12 @@ +pub const MEM_ADDR_MASK: u64 = 0xFFFF_FFFF_FFFF_FFF8; +pub const MEM_BYTES: u64 = 8; 
+ +pub const MAX_MEM_STEP_OFFSET: u64 = 2; +pub const MAX_MEM_OPS_PER_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 2; + +pub const MEM_STEP_BITS: u64 = 34; // with step_slot = 8 => 2GB steps ( +pub const MEM_STEP_MASK: u64 = (1 << MEM_STEP_BITS) - 1; // 256 MB +pub const MEM_ADDR_BITS: u64 = 64 - MEM_STEP_BITS; + +pub const MAX_MEM_STEP: u64 = (1 << MEM_STEP_BITS) - 1; +pub const MAX_MEM_ADDR: u64 = (1 << MEM_ADDR_BITS) - 1; diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs new file mode 100644 index 00000000..ac4ca198 --- /dev/null +++ b/state-machines/mem/src/mem_helpers.rs @@ -0,0 +1,65 @@ +use crate::MemAlignResponse; +use std::fmt; +use zisk_core::ZiskRequiredMemory; + +fn format_u64_hex(value: u64) -> String { + let hex_str = format!("{:016x}", value); + hex_str + .as_bytes() + .chunks(4) + .map(|chunk| std::str::from_utf8(chunk).unwrap()) + .collect::>() + .join("_") +} + +impl fmt::Debug for MemAlignResponse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "more:{0} step:{1} value:{2:016X}({2:})", + self.more_address, + self.step, + self.value.unwrap_or(0) + ) + } +} + +pub fn mem_align_call( + mem_op: &ZiskRequiredMemory, + mem_values: [u64; 2], + phase: u8, +) -> MemAlignResponse { + // DEBUG: only for testing + let offset = (mem_op.address & 0x7) * 8; + let width = (mem_op.width as u64) * 8; + let double_address = (offset + width as u32) > 64; + let mem_value = mem_values[phase as usize]; + let mask = 0xFFFF_FFFF_FFFF_FFFFu64 >> (64 - width); + if mem_op.is_write { + if phase == 0 { + MemAlignResponse { + more_address: double_address, + step: mem_op.step + 1, + value: Some( + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) + | ((mem_op.value & mask) << offset), + ), + } + } else { + MemAlignResponse { + more_address: false, + step: mem_op.step + 1, + value: Some( + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width as u32 - 64))) + | ((mem_op.value & mask) >> (128 
- (offset + width as u32))), + ), + } + } + } else { + MemAlignResponse { + more_address: double_address && phase == 0, + step: mem_op.step + 1, + value: None, + } + } +} diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 612c2e29..8b4e40a8 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,23 +1,22 @@ -use std::collections::VecDeque; -use std::fmt; -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, +use std::{ + collections::VecDeque, + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, }; -use crate::{MemAlignResponse, MemAlignRomSM, MemAlignSM, MemSM}; +use crate::{ + mem_align_call, MemAlignResponse, MemAlignRomSM, MemAlignSM, MemSM, MAX_MEM_ADDR, + MAX_MEM_OPS_PER_MAIN_STEP, MAX_MEM_STEP, MEM_ADDR_BITS, MEM_ADDR_MASK, MEM_BYTES, +}; use p3_field::PrimeField; use pil_std_lib::Std; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use zisk_core::ZiskRequiredMemory; use proofman::{WitnessComponent, WitnessManager}; - -const MEM_ADDR_MASK: u64 = 0xFFFF_FFFF_FFFF_FFF8; -const MEM_BYTES: u64 = 8; - -const MAX_MEM_STEP_OFFSET: u64 = 2; -const MAX_MEM_OPS_PER_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 2; +use zisk_pil::QUICKOPS_AIRGROUP_ID; pub trait MemModule: Send + Sync { fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]); @@ -27,348 +26,400 @@ pub trait MemModule: Send + Sync { fn register_predecessor(&self); } +trait MemAlignSm { + fn get_mem_op( + &self, + mem_op: &ZiskRequiredMemory, + mem_values: [u64; 2], + phase: u8, + ) -> MemAlignResponse; +} + struct MemModuleData { + pub name: String, pub inputs: Vec, pub addr_ranges: Vec<(u64, u64)>, pub flush_input_size: u64, } -impl fmt::Debug for MemAlignResponse { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "more:{} step:{} value:{:2}({:3})", - self.more_address, - self.step, - format_hex(self.value.unwrap_or(0)), - self.value.unwrap_or(0) - ) - } +struct 
MemAlignOperation { + addr: u32, + mem_op: ZiskRequiredMemory, + mem_value: [u64; 2], } + pub struct MemProxy { // Count of registered predecessors registered_predecessors: AtomicU32, // Secondary State machines - // mem_sm: Arc>, + mem_sm: Arc>, mem_align_sm: Arc>, - modules: Vec>>, - modules_data: Vec, + mem_align_rom_sm: Arc>, } -pub struct MemOperation { - pub step: u64, - pub is_write: bool, - pub address: u64, - pub width: u64, - pub value: u64, -} - -pub struct MemAlignOperation { - pub address: u64, - pub mem_op: ZiskRequiredMemory, - pub mem_value: [u64; 2], +pub struct MemProxyEngine { + modules: Vec>>, + modules_data: Vec, + open_mem_align_ops: VecDeque, + last_addr: u32, + last_addr_value: u64, + current_module_id: usize, + current_module: String, + module_end_addr: u32, } impl MemProxy { pub fn new(wcm: Arc>, std: Arc>) -> Arc { let mem_align_rom_sm = MemAlignRomSM::new(wcm.clone()); - let mem_align_sm = MemAlignSM::new(wcm.clone(), std, mem_align_rom_sm); - - let mut modules: Vec>> = Vec::new(); - - modules.push(MemSM::new(wcm.clone()).clone()); - let mut modules_data: Vec = Vec::new(); + let mem_align_sm = MemAlignSM::new(wcm.clone(), std, mem_align_rom_sm.clone()); + let mem_sm = MemSM::new(wcm.clone()); - for module in modules.iter_mut() { - modules_data.push(Self::init_module(module)); - } let mem_proxy = Self { registered_predecessors: AtomicU32::new(0), mem_align_sm, - modules, - modules_data, + mem_align_rom_sm, + mem_sm, }; let mem_proxy = Arc::new(mem_proxy); wcm.register_component(mem_proxy.clone(), None, None); // For all the secondary state machines, register the main state machine as a predecessor + mem_proxy.mem_align_rom_sm.register_predecessor(); mem_proxy.mem_align_sm.register_predecessor(); + mem_proxy.mem_sm.register_predecessor(); mem_proxy } - pub fn main_step_to_mem_step(step: u64, step_offset: u8) -> u64 { - 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 - } pub fn register_predecessor(&self) { 
self.registered_predecessors.fetch_add(1, Ordering::SeqCst); } pub fn unregister_predecessor(&self) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - for module in self.modules.iter() { - module.unregister_predecessor(); - } - // self.mem_sm.unregister_predecessor(); + self.mem_align_rom_sm.unregister_predecessor(); self.mem_align_sm.unregister_predecessor(); + self.mem_sm.unregister_predecessor(); } } - fn init_module(module: &Arc>) -> MemModuleData { - module.register_predecessor(); - let ranges = module.get_addr_ranges(); - let flush_input_size = module.get_flush_input_size(); - MemModuleData { inputs: Vec::new(), addr_ranges: ranges, flush_input_size } + pub fn prove( + &self, + mem_operations: &mut Vec, + ) -> Result<(), Box> { + let mut engine = MemProxyEngine::::new(); + engine.add_module("mem", self.mem_sm.clone()); + engine.prove(&self.mem_align_sm, mem_operations) } +} - /// Static method to decide it the memory operation needs to be processed by - /// memAlign, because it isn't a 8-byte and 8-byte aligned memory access. - fn is_aligned(mem_op: &ZiskRequiredMemory) -> bool { - let aligned_mem_address = mem_op.address & MEM_ADDR_MASK; - aligned_mem_address == mem_op.address && mem_op.width == MEM_BYTES - } - /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible situations: - /// 1) read, only on single mem_op is pushed - /// 2) read+write, two mem_op are pushed, one read and one write. - /// - /// This process is used for each aligned memory address, means that the "second part" of non aligned memory - /// operation is processed on addr + MEM_BYTES. 
- fn push_mem_align_op( - &self, - mem_addr: u64, - mem_value: u64, - mem_op: &ZiskRequiredMemory, - mem_align_op: &MemAlignResponse, - input: &mut Vec, - ) -> u64 { - // Prepare aligned memory access - let read = ZiskRequiredMemory { - step: mem_align_op.step, - is_write: false, - address: mem_addr, - width: MEM_BYTES, - value: mem_value, - }; - input.push(read); +impl MemProxyEngine { + pub fn new() -> Self { + let mut modules: Vec>> = Vec::new(); + let mut modules_data: Vec = Vec::new(); - if mem_op.is_write { - let mem_value = mem_align_op.value.expect("value returned by mem_align"); - let write = ZiskRequiredMemory { - step: mem_align_op.step + 1, - is_write: true, - address: mem_addr, - width: MEM_BYTES, - value: mem_value, - }; - input.push(write); - mem_value - } else { - mem_value + Self { + modules, + modules_data, + last_addr: 0, + last_addr_value: 0, + current_module_id: 0, + current_module: String::new(), + module_end_addr: 0, + open_mem_align_ops: VecDeque::new(), } } - fn create_modules_inputs(&self) -> Vec> { - let mut mem_module_inputs: Vec> = Default::default(); - for module in self.modules.iter() { - mem_module_inputs.push(Vec::new()); + + pub fn add_module(&mut self, name: &str, module: Arc>) { + if self.modules.is_empty() { + self.current_module = String::from(name); } - mem_module_inputs - } - fn get_mem_module_id(&self, address: u64) -> (usize, u64) { - let mem_module_id = 0; - let next_addr_to_reevaluate = 0xFFFF_FFFF_FFFF; - (mem_module_id, next_addr_to_reevaluate) + self.modules.push(module.clone()); + self.modules_data.push(Self::init_module(name, &module)); } pub fn prove( - &self, + &mut self, + mem_align_sm: &MemAlignSM, mem_operations: &mut Vec, ) -> Result<(), Box> { - let mut open_mem_align_ops: VecDeque = VecDeque::new(); - let mut mem_module_inputs = self.create_modules_inputs(); + self.init_prove(&mem_operations); // Step 1. 
Sort the aligned memory accesses // original vector is sorted by step, sort_by_key is stable, no reordering of elements with // the same key. timer_start_debug!(MEM_SORT); - mem_operations.sort_by_key(|mem| (mem.address & 0xFFFF_FFFF_FFFF_FFF8)); + mem_operations.sort_by_key(|mem| (mem.address & 0xFFFF_FFF8)); timer_stop_and_log_debug!(MEM_SORT); - // Initialize the last values of address and value on the sorted memory operations - let mut last_addr = 0xFFFF_FFFF_FFFF_FFFFu64; - let mut last_value = 0u64; - - // Add a final fake mem_op to force flush of open_mem_align_ops - mem_operations.push(ZiskRequiredMemory { - step: 0, - is_write: false, - address: MEM_ADDR_MASK, - width: 8, - value: 0, - }); - - // Initialize the module id and next module address to reevaluate the module id, it's done - // to avoid check on each loop if memory address is inside one range or other - let (mut mem_module_id, mut next_module_addr) = if mem_operations.is_empty() { - (0, 0) - } else { - self.get_mem_module_id(mem_operations[0].address) - }; + // Step2. Add a final mark mem_op to force flush of open_mem_align_ops, because always the + // last operation is mem_op. + mem_operations.push(Self::end_of_memory_mark()); + + // Step3. Process each memory operation ordered by address and step. When a non-aligned + // memory access there are two possible situations: + // + // 1) the operation applies only applies to one memory address (read or read+write). In + // this case mem_align helper return the aligned operation for this address, and loop + // continues. + // 2) the operation applies to two consecutive memory addresses, mem_align helper returns + // the aligned operation involved for the current address, and the second part of the + // operation is enqueued to open_mem_align_ops, it will processed when processing next + // address. 
+ // + // Inside loop, first of all, we verify if exists "previous" open mem align operations that + // be processed before current mem_op, in this case process all "previous" and after process + // the current mem_op. for mem_op in mem_operations.iter_mut() { - let mut aligned_mem_address = mem_op.address & MEM_ADDR_MASK; - - // ONLY TO TEST - if aligned_mem_address < 0xA0000000 { - continue; - } + self.log_mem_op(mem_op); - // Check if there are open mem align operations to be processed in this moment. Two possible - // conditions to process open mem align operations: - // 1) the address of open operation is less than the aligned address. - // 2) the address of open operation is equal to the aligned address, but the step of the open - // operation is less than the step of the current operation. - - while open_mem_align_ops.len() > 0 - && (open_mem_align_ops[0].address < aligned_mem_address - || (open_mem_align_ops[0].address == aligned_mem_address - && open_mem_align_ops[0].mem_op.step < mem_op.step)) - { - let open_op = open_mem_align_ops.pop_front().unwrap(); - let mem_value = if open_op.address == last_addr { last_value } else { 0 }; - - // call to mem_align to get information of the aligned memory access needed - // to prove the unaligned open operation. - let mem_align_op = self.mem_align_sm.get_mem_op(&open_op.mem_op, [mem_value, 0], 1); - - // remove element from top of queue, because we are on last phase, phase 1. - open_mem_align_ops.pop_front(); - - // check if need to reevaluate the module id - if open_op.address >= next_module_addr { - (mem_module_id, next_module_addr) = self.get_mem_module_id(open_op.address); - } - // push the aligned memory operations for current address (read or read+write) and - // update last_address and last_value. 
- last_value = self.push_mem_align_op( - open_op.address, - mem_value, - &mem_op, - &mem_align_op, - &mut mem_module_inputs[mem_module_id], - ); - last_addr = open_op.address; + let aligned_mem_addr = Self::to_aligned_addr(mem_op.address); + let mem_step = mem_op.step; - // check if need to flush the inputs of the module - if (mem_module_inputs[mem_module_id].len() as u64) - >= self.modules_data[mem_module_id].flush_input_size - { - self.modules[mem_module_id].send_inputs(&mut mem_module_inputs[mem_module_id]); - } + if aligned_mem_addr < 0xA0000000 { + // only for testing purposes + continue; } - aligned_mem_address = mem_op.address & MEM_ADDR_MASK; + // Check if there are open mem align operations to be processed in this moment, with + // address (or step) less than the aligned of current mem_op. + self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step); - // check if the aligned address is the last address to avoid processing the last fake mem_op - if aligned_mem_address == MEM_ADDR_MASK { - assert!( - open_mem_align_ops.len() == 0, - "open_mem_align_ops not empty, has {} elements", - open_mem_align_ops.len() - ); + // check if we are at end of loop + if self.check_if_end_of_memory_mark(mem_op) { break; } - // check if need to reevaluate the module id - if aligned_mem_address >= next_module_addr { - (mem_module_id, next_module_addr) = self.get_mem_module_id(aligned_mem_address); - } - - let mem_value = if aligned_mem_address == last_addr { last_value } else { 0 }; + // TODO: edge case special memory with free-input memory data as input data + let mem_value = self.get_mem_value(aligned_mem_addr, mem_op); // all open mem align operations are processed, check if new mem operation is aligned if !Self::is_aligned(&mem_op) { // In this point found non-aligned memory access, phase-0 - let mem_align_op = self.mem_align_sm.get_mem_op(mem_op, [mem_value, 0], 0); + let mem_align_op = mem_align_sm.get_mem_op(mem_op, [mem_value, 0], 0); + + // if operation 
applies to two consecutive memory addresses, add the second part + // is enqueued to be processed in future when processing next address on phase-1 if mem_align_op.more_address { - open_mem_align_ops.push_back(MemAlignOperation { - address: aligned_mem_address + MEM_BYTES, - mem_op: mem_op.clone(), - mem_value: [mem_value, 0], - }); + self.push_open_mem_align_op(aligned_mem_addr, mem_value, mem_op); } - last_value = self.push_mem_align_op( - aligned_mem_address, + self.push_mem_align_response_ops( + aligned_mem_addr, mem_value, - &mem_op, + mem_op, &mem_align_op, - &mut mem_module_inputs[mem_module_id], ); - last_addr = aligned_mem_address } else { - mem_module_inputs[mem_module_id].push(mem_op.clone()); - last_value = mem_op.value; - last_addr = aligned_mem_address + self.push_mem_op(mem_op); } + } + self.finish_prove(); + Ok(()) + } - // check if need to flush the inputs of the module - if (mem_module_inputs[mem_module_id].len() as u64) - >= self.modules_data[mem_module_id].flush_input_size - { - self.modules[mem_module_id].send_inputs(&mut mem_module_inputs[mem_module_id]); - } + fn process_all_previous_open_mem_align_ops(&mut self, mem_addr: u32, mem_step: u64) { + // Two possible situations to process open mem align operations: + // + // 1) the address of open operation is less than the aligned address. + // 2) the address of open operation is equal to the aligned address, but the step of the + // open operation is less than the step of the current operation. + + while self.has_open_mem_align_lt(mem_addr, mem_step) { + let open_op = self.open_mem_align_ops.pop_front().unwrap(); + let mem_value = if open_op.addr == self.last_addr { self.last_addr_value } else { 0 }; + + // call to mem_align to get information of the aligned memory access needed + // to prove the unaligned open operation. + let mem_align_op = mem_align_call(&open_op.mem_op, [mem_value, 0], 1); + + // remove element from top of queue, because we are on last phase, phase 1. 
+ self.open_mem_align_ops.pop_front(); + + // push the aligned memory operations for current address (read or read+write) and + // update last_address and last_value. + self.push_mem_align_response_ops( + open_op.addr, + mem_value, + &open_op.mem_op, + &mem_align_op, + ); } + } - Ok(()) + pub fn main_step_to_mem_step(step: u64, step_offset: u8) -> u64 { + 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 } -} -impl WitnessComponent for MemProxy {} + fn init_module(name: &str, module: &Arc>) -> MemModuleData { + // module.register_predecessor(); + let ranges = module.get_addr_ranges(); + let flush_input_size = module.get_flush_input_size(); + MemModuleData { + name: String::from(name), + inputs: Vec::new(), + addr_ranges: ranges, + flush_input_size, + } + } -fn format_hex(value: u64) -> String { - let hex_str = format!("{:016x}", value); // Format hexadecimal amb 16 dígits i padding de 0s - hex_str - .as_bytes() // Converteix a bytes per manipular fàcilment - .chunks(4) // Separa en grups de 4 caràcters (2 bytes) - .map(|chunk| std::str::from_utf8(chunk).unwrap()) // Converteix cada chunk a &str - .collect::>() // Recull els chunks com a un vector - .join("_") // Uneix amb "_" -} + /// Static method to decide it the memory operation needs to be processed by + /// memAlign, because it isn't a 8-byte and 8-byte aligned memory access. 
+ fn is_aligned(mem_op: &ZiskRequiredMemory) -> bool { + let aligned_mem_address = (mem_op.address as u64 & MEM_ADDR_MASK) as u32; + aligned_mem_address == mem_op.address && mem_op.width == MEM_BYTES as u8 + } + fn push_mem_op(&mut self, mem_op: &ZiskRequiredMemory) { + self.push_aligned_op(mem_op.is_write, mem_op.address, mem_op.value, mem_op.step); + } -fn mem_align_call( - mem_op: &ZiskRequiredMemory, - mem_values: [u64; 2], - phase: u8, -) -> MemAlignResponse { - // DEBUG: only for testing - let offset = (mem_op.address & 0x7) * 8; - let width = (mem_op.width as u64) * 8; - let double_address = (offset + width) > 64; - let mem_value = mem_values[phase as usize]; - let mask = 0xFFFF_FFFF_FFFF_FFFFu64 >> (64 - width); - if mem_op.is_write { - if phase == 0 { - MemAlignResponse { - more_address: double_address, - step: mem_op.step + 1, - value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) - | ((mem_op.value & mask) << offset), - ), - } + fn push_aligned_op(&mut self, is_write: bool, addr: u32, value: u64, step: u64) { + self.update_last_addr(addr, value); + let mem_op = ZiskRequiredMemory { + step, + is_write, + address: addr as u32, + width: MEM_BYTES as u8, + value, + }; + println!(" ##SEND[{0}]## mem_op: {1:?}", self.current_module, mem_op); + self.modules_data[self.current_module_id].inputs.push(mem_op); + self.last_addr_value = value; + self.check_flush_inputs(); + } + // method to add aligned read operation + #[inline(always)] + fn push_aligned_read(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(false, addr, value, step); + } + // method to add aligned write operation + #[inline(always)] + fn push_aligned_write(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(true, addr, value, step); + } + /// Process information of mem_op and mem_align_op to push mem_op operation. 
Only two possible + /// situations: + /// 1) read, only on single mem_op is pushed + /// 2) read+write, two mem_op are pushed, one read and one write. + /// + /// This process is used for each aligned memory address, means that the "second part" of non + /// aligned memory operation is processed on addr + MEM_BYTES. + fn push_mem_align_response_ops( + &mut self, + mem_addr: u32, + mem_value: u64, + mem_op: &ZiskRequiredMemory, + mem_align_op: &MemAlignResponse, + ) { + self.push_aligned_read(mem_addr, mem_value, mem_align_op.step); + if mem_op.is_write { + let mem_value = mem_align_op.value.expect("value returned by mem_align"); + self.push_aligned_write(mem_addr, mem_value, mem_align_op.step + 1); + } + } + fn create_modules_inputs(&self) -> Vec> { + let mut mem_module_inputs: Vec> = Default::default(); + for module in self.modules.iter() { + mem_module_inputs.push(Vec::new()); + } + mem_module_inputs + } + fn get_mem_module_id(&self, address: u32) -> (usize, u32) { + (0, MAX_MEM_ADDR as u32 + 1) + } + fn update_mem_module_id(&mut self, addr: u32) { + (self.current_module_id, self.module_end_addr) = self.get_mem_module_id(addr); + } + fn update_last_addr(&mut self, addr: u32, value: u64) { + self.last_addr = addr; + // check if need to reevaluate the module id + if addr >= self.module_end_addr { + self.update_mem_module_id(addr); + } + } + fn check_flush_inputs(&self) { + // check if need to flush the inputs of the module + let mid = self.current_module_id; + if (self.modules_data[mid].inputs.len() as u64) >= self.modules_data[mid].flush_input_size { + self.modules[mid].send_inputs(&self.modules_data[mid].inputs); + } + } + + fn has_open_mem_align_lt(&self, addr: u32, step: u64) -> bool { + self.open_mem_align_ops.len() > 0 && + (self.open_mem_align_ops[0].addr < addr || + (self.open_mem_align_ops[0].addr == addr && + self.open_mem_align_ops[0].mem_op.step < step)) + } + // method to process open mem align operations, second part of non aligned memory operations + 
// applies to two consecutive memory addresses. + + fn end_of_memory_mark() -> ZiskRequiredMemory { + ZiskRequiredMemory { + step: MAX_MEM_STEP, + is_write: false, + address: MAX_MEM_ADDR as u32, + width: MEM_BYTES as u8, + value: 0, + } + } + #[inline(always)] + fn check_if_end_of_memory_mark(&self, mem_op: &ZiskRequiredMemory) -> bool { + if mem_op.step == MAX_MEM_STEP && mem_op.address == MAX_MEM_ADDR as u32 { + assert!( + self.open_mem_align_ops.len() == 0, + "open_mem_align_ops not empty, has {} elements", + self.open_mem_align_ops.len() + ); + true } else { - MemAlignResponse { - more_address: false, - step: mem_op.step + 1, - value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64))) - | ((mem_op.value & mask) >> (128 - offset - width)), - ), - } + false } - } else { - MemAlignResponse { - more_address: double_address && phase == 0, - step: mem_op.step + 1, - value: None, + } + fn init_prove(&mut self, mem_operations: &Vec) { + // Initialize the last values of address and value on the sorted memory operations + let mut last_addr = 0xFFFF_FFFF_FFFF_FFFFu64; + let mut last_value = 0u64; + + // Initialize the module id and next module address to reevaluate the module id, it's done + // to avoid check on each loop if memory address is inside one range or other + let (mut mem_module_id, mut next_module_addr) = if mem_operations.is_empty() { + (0, 0) + } else { + self.get_mem_module_id(mem_operations[0].address) + }; + } + fn finish_prove(&self) {} + fn get_mem_value(&self, addr: u32, mem_op: &ZiskRequiredMemory) -> u64 { + if addr == self.last_addr { + self.last_addr_value + } else { + 0 } } + + #[inline(always)] + fn push_open_mem_align_op( + &mut self, + aligned_mem_addr: u32, + mem_value: u64, + mem_op: &ZiskRequiredMemory, + ) { + self.open_mem_align_ops.push_back(MemAlignOperation { + addr: aligned_mem_addr + MEM_BYTES as u32, + mem_op: mem_op.clone(), + mem_value: [mem_value, 0], + }); + } + fn log_mem_op(&self, mem_op: 
&ZiskRequiredMemory) { + println!( + "##LOOP## mem_op: {0:?} 0x{1:#08X}({1}) 0x{2:#016X}({2})", + mem_op, self.last_addr, self.last_addr_value + ); + } + #[inline(always)] + fn to_aligned_addr(addr: u32) -> u32 { + (addr as u64 & MEM_ADDR_MASK) as u32 + } } + +impl WitnessComponent for MemProxy {} diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index fe54e310..af8ec365 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -158,7 +158,7 @@ impl MemSM { // trace[0].mem_segment = segment_id_field; // trace[0].mem_last_segment = is_last_segment_field; - trace[0].addr = F::from_canonical_u64(mem_first_row.address); + trace[0].addr = F::from_canonical_u32(mem_first_row.address); trace[0].step = F::from_canonical_u64(mem_first_row.step); trace[0].sel = F::zero(); trace[0].wr = F::zero(); @@ -184,7 +184,7 @@ impl MemSM { // trace[i].mem_segment = segment_id_field; // trace[i].mem_last_segment = is_last_segment_field; - trace[i].addr = F::from_canonical_u64(mem_op.address); // n-byte address, real address = addr * MEM_BYTES + trace[i].addr = F::from_canonical_u32(mem_op.address); // n-byte address, real address = addr * MEM_BYTES trace[i].step = F::from_canonical_u64(mem_op.step); trace[i].sel = F::one(); trace[i].wr = F::from_bool(mem_op.is_write); diff --git a/witness-computation/src/executor.rs b/witness-computation/src/executor.rs index 48e3fe56..43d71b76 100644 --- a/witness-computation/src/executor.rs +++ b/witness-computation/src/executor.rs @@ -83,7 +83,8 @@ impl ZiskExecutor { // TODO - If there is more than one Main AIR available, the MAX_ACCUMULATED will be the one // with the highest num_rows. It has to be a power of 2. 
- let main_sm = MainSM::new(wcm.clone(), mem_proxy_sm.clone(), arith_sm.clone(), binary_sm.clone()); + let main_sm = + MainSM::new(wcm.clone(), mem_proxy_sm.clone(), arith_sm.clone(), binary_sm.clone()); Self { zisk_rom, main_sm, rom_sm, mem_proxy_sm, binary_sm, arith_sm } } @@ -197,7 +198,11 @@ impl ZiskExecutor { // ---------------------------------------------- let mem_thread = thread::spawn({ let mem_proxy_sm = self.mem_proxy_sm.clone(); - move || mem_proxy_sm.prove(&mut mem_required).expect("Error during Memory witness computation") + move || { + mem_proxy_sm + .prove(&mut mem_required) + .expect("Error during Memory witness computation") + } }); // ROM State Machine @@ -282,6 +287,21 @@ impl ZiskExecutor { mem_thread.join().expect("Error during Memory witness computation"); + // match mem_thread.join() { + // Ok(_) => println!("El thread ha finalitzat correctament."), + // Err(e) => { + // println!("El thread ha fet panic!"); + // + // // Converteix l'error en una cadena llegible (opcional) + // if let Some(missatge) = e.downcast_ref::<&str>() { + // println!("Missatge d'error: {}", missatge); + // } else if let Some(missatge) = e.downcast_ref::() { + // println!("Missatge d'error: {}", missatge); + // } else { + // println!("No es pot determinar el tipus d'error."); + // } + // } + // } if let Some(thread) = rom_thread { let _ = thread.join().expect("Error during ROM witness computation"); } @@ -290,4 +310,4 @@ impl ZiskExecutor { self.binary_sm.unregister_predecessor(); // self.arith_sm.register_predecessor(scope); } -} \ No newline at end of file +} From 9d8e75cdb13f71108296442fdfeb6bb0002c34ac Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 21 Nov 2024 19:51:24 +0000 Subject: [PATCH 34/44] WIP memory state machhines --- Cargo.lock | 10 -- Cargo.toml | 24 +-- state-machines/mem/src/lib.rs | 2 + state-machines/mem/src/mem_constants.rs | 2 +- state-machines/mem/src/mem_proxy.rs | 196 +++++++++++++----------- state-machines/mem/src/mem_sm.rs | 9 +- 
state-machines/mem/src/mem_unmapped.rs | 33 ++++ 7 files changed, 159 insertions(+), 117 deletions(-) create mode 100644 state-machines/mem/src/mem_unmapped.rs diff --git a/Cargo.lock b/Cargo.lock index 73f4a1a1..6fb276c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1467,7 +1467,6 @@ dependencies = [ [[package]] name = "pil-std-lib" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "log", "num-bigint", @@ -1485,7 +1484,6 @@ dependencies = [ [[package]] name = "pilout" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "bytes", "log", @@ -1605,7 +1603,6 @@ dependencies = [ [[package]] name = "proofman" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "colored", "env_logger", @@ -1626,7 +1623,6 @@ dependencies = [ [[package]] name = "proofman-common" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "env_logger", "log", @@ -1644,7 +1640,6 @@ dependencies = [ [[package]] name = "proofman-hints" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "p3-field", "proofman-common", @@ -1654,7 +1649,6 @@ dependencies = [ [[package]] name = "proofman-macros" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "proc-macro2", "quote", @@ -1664,7 +1658,6 @@ dependencies = [ [[package]] name = "proofman-starks-lib-c" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" 
dependencies = [ "log", ] @@ -1672,7 +1665,6 @@ dependencies = [ [[package]] name = "proofman-util" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "colored", "sysinfo 0.31.4", @@ -2319,7 +2311,6 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stark" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "log", "p3-field", @@ -2669,7 +2660,6 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?rev=0.0.10#cb182461a9e8dc4077be76b81733b5384d640a21" dependencies = [ "proofman-starks-lib-c", ] diff --git a/Cargo.toml b/Cargo.toml index b97f5cb2..1f8e4798 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,19 +26,19 @@ opt-level = 3 opt-level = 3 [workspace.dependencies] -proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } -stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# proofman-common = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# proofman-macros = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# proofman-util = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# proofman = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# 
pil-std-lib = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } +# stark = { git = "https://github.com/0xPolygonHermez/pil2-proofman.git", rev = "0.0.10" } # Local development -# proofman-common = { path = "../pil2-proofman/common" } -# proofman-macros = { path = "../pil2-proofman/macros" } -# proofman-util = { path = "../pil2-proofman/util" } -# proofman = { path = "../pil2-proofman/proofman" } -# pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } -# stark = { path = "../pil2-proofman/provers/stark" } +proofman-common = { path = "../pil2-proofman/common" } +proofman-macros = { path = "../pil2-proofman/macros" } +proofman-util = { path = "../pil2-proofman/util" } +proofman = { path = "../pil2-proofman/proofman" } +pil-std-lib = { path = "../pil2-proofman/pil2-components/lib/std/rs" } +stark = { path = "../pil2-proofman/provers/stark" } p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "c3d754ef77b9fce585b46b972af751fe6e7a9803" } log = "0.4" diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index ab1d4209..4f3116db 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -4,6 +4,7 @@ mod mem_constants; mod mem_helpers; mod mem_proxy; mod mem_sm; +mod mem_unmapped; pub use mem_align_rom_sm::*; pub use mem_align_sm::*; @@ -11,3 +12,4 @@ pub use mem_constants::*; pub use mem_helpers::*; pub use mem_proxy::*; pub use mem_sm::*; +pub use mem_unmapped::*; diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs index 4e177ee3..cb113775 100644 --- a/state-machines/mem/src/mem_constants.rs +++ b/state-machines/mem/src/mem_constants.rs @@ -9,4 +9,4 @@ pub const MEM_STEP_MASK: u64 = (1 << MEM_STEP_BITS) - 1; // 256 MB pub const MEM_ADDR_BITS: u64 = 64 - MEM_STEP_BITS; pub const MAX_MEM_STEP: u64 = (1 << MEM_STEP_BITS) - 1; -pub const MAX_MEM_ADDR: u64 = (1 << MEM_ADDR_BITS) - 1; +pub const MAX_MEM_ADDR: u64 = 0xFFFF_FFFF; 
diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index edb2db6a..4fabbb9f 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -6,8 +6,10 @@ use std::{ }, }; +const UNMAPPED_MODULE_ID: u8 = 0xFE; + use crate::{ - mem_align_call, MemAlignResponse, MemAlignRomSM, MemAlignSM, MemSM, MAX_MEM_ADDR, + mem_align_call, MemAlignResponse, MemAlignRomSM, MemAlignSM, MemSM, MemUnmapped, MAX_MEM_ADDR, MAX_MEM_OPS_PER_MAIN_STEP, MAX_MEM_STEP, MEM_ADDR_BITS, MEM_ADDR_MASK, MEM_BYTES, }; use p3_field::PrimeField; @@ -16,11 +18,10 @@ use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use zisk_core::ZiskRequiredMemory; use proofman::{WitnessComponent, WitnessManager}; -use zisk_pil::QUICKOPS_AIRGROUP_ID; pub trait MemModule: Send + Sync { fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]); - fn get_addr_ranges(&self) -> Vec<(u64, u64)>; + fn get_addr_ranges(&self) -> Vec<(u32, u32)>; fn get_flush_input_size(&self) -> u64; fn unregister_predecessor(&self); fn register_predecessor(&self); @@ -37,8 +38,9 @@ trait MemAlignSm { struct MemModuleData { pub name: String, + pub id: u8, + pub ranges: Vec<(u32, u32)>, pub inputs: Vec, - pub addr_ranges: Vec<(u64, u64)>, pub flush_input_size: u64, } @@ -58,10 +60,18 @@ pub struct MemProxy { mem_align_rom_sm: Arc>, } +#[derive(Debug)] +pub struct AddressRegion { + from_address: u32, + to_address: u32, + module_id: u8, +} pub struct MemProxyEngine { modules: Vec>>, modules_data: Vec, open_mem_align_ops: VecDeque, + address_map: Vec, + address_map_closed: bool, last_addr: u32, last_addr_value: u64, current_module_id: usize, @@ -109,6 +119,7 @@ impl MemProxy { ) -> Result<(), Box> { let mut engine = MemProxyEngine::::new(); engine.add_module("mem", self.mem_sm.clone()); + engine.close_address_map(); engine.prove(&self.mem_align_sm, mem_operations) } } @@ -127,6 +138,8 @@ impl MemProxyEngine { current_module: String::new(), module_end_addr: 0, 
open_mem_align_ops: VecDeque::new(), + address_map: Vec::new(), + address_map_closed: false, } } @@ -134,9 +147,34 @@ impl MemProxyEngine { if self.modules.is_empty() { self.current_module = String::from(name); } + let module_id = self.modules.len() as u8; self.modules.push(module.clone()); - self.modules_data.push(Self::init_module(name, &module)); + + let ranges = module.get_addr_ranges(); + let flush_input_size = module.get_flush_input_size(); + + for range in ranges.iter() { + println!("## PROXY adding range 0x{:X} 0x{:X} ##", range.0, range.1); + self.insert_address_range(range.0, range.1, module_id); + } + self.modules_data.push(MemModuleData { + name: String::from(name), + id: module_id, + ranges, + inputs: Vec::new(), + flush_input_size, + }); + } + /* insert in sort way the address map and verify that */ + fn insert_address_range(&mut self, from_address: u32, to_address: u32, module_id: u8) { + let region = AddressRegion { from_address, to_address, module_id }; + if let Some(index) = self.address_map.iter().position(|x| x.from_address >= from_address) { + self.address_map.insert(index, region); + } else { + self.address_map.push(region); + } } + pub fn prove( &mut self, mem_align_sm: &MemAlignSM, @@ -176,11 +214,6 @@ impl MemProxyEngine { let aligned_mem_addr = Self::to_aligned_addr(mem_op.address); let mem_step = mem_op.step; - if aligned_mem_addr < 0xA0000000 { - // only for testing purposes - continue; - } - // Check if there are open mem align operations to be processed in this moment, with // address (or step) less than the aligned of current mem_op. 
self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step); @@ -250,18 +283,6 @@ impl MemProxyEngine { 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 } - fn init_module(name: &str, module: &Arc>) -> MemModuleData { - // module.register_predecessor(); - let ranges = module.get_addr_ranges(); - let flush_input_size = module.get_flush_input_size(); - MemModuleData { - name: String::from(name), - inputs: Vec::new(), - addr_ranges: ranges, - flush_input_size, - } - } - /// Static method to decide it the memory operation needs to be processed by /// memAlign, because it isn't a 8-byte and 8-byte aligned memory access. fn is_aligned(mem_op: &ZiskRequiredMemory) -> bool { @@ -281,7 +302,7 @@ impl MemProxyEngine { width: MEM_BYTES as u8, value, }; - println!(" ##SEND[{0}]## mem_op: {1:?}", self.current_module, mem_op); + println!("## PROXY SEND {0} ## {1:?}", self.current_module, mem_op); self.modules_data[self.current_module_id].inputs.push(mem_op); self.last_addr_value = value; self.check_flush_inputs(); @@ -323,24 +344,46 @@ impl MemProxyEngine { } mem_module_inputs } - fn get_mem_module_id(&self, address: u32) -> (usize, u32) { - (0, MAX_MEM_ADDR as u32 + 1) + fn set_active_region(&mut self, region_id: usize) { + self.current_module_id = self.address_map[region_id].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.address_map[region_id].to_address; } fn update_mem_module_id(&mut self, addr: u32) { - (self.current_module_id, self.module_end_addr) = self.get_mem_module_id(addr); + println!( + "## \x1B[31mGET MODULE ID\x1B[0m ## 0x{0:X} module_end_addr:0x{1:X} 0x{2:X}", + addr, self.module_end_addr, MAX_MEM_ADDR as u32 + ); + // println!("{:?}", self.address_map); + if let Some(index) = + self.address_map.iter().position(|x| x.from_address <= addr && x.to_address >= addr) + { + self.set_active_region(index); + } else { + assert!(false, "out-of-memory 0x{:X}", addr); 
+ } } fn update_last_addr(&mut self, addr: u32, value: u64) { self.last_addr = addr; // check if need to reevaluate the module id - if addr >= self.module_end_addr { + if addr > self.module_end_addr { self.update_mem_module_id(addr); } } - fn check_flush_inputs(&self) { + fn check_flush_inputs(&mut self) { // check if need to flush the inputs of the module let mid = self.current_module_id; + println!( + "## PROXY FLUSH ## {0} {1} {2}", + mid, + self.modules_data[mid].inputs.len(), + self.modules_data[mid].flush_input_size + ); if (self.modules_data[mid].inputs.len() as u64) >= self.modules_data[mid].flush_input_size { + // TODO: optimize passing ownership of inputs to module, and creating a new input + // object self.modules[mid].send_inputs(&self.modules_data[mid].inputs); + self.modules_data[mid].inputs.clear(); } } @@ -376,17 +419,17 @@ impl MemProxyEngine { } } fn init_prove(&mut self, mem_operations: &Vec) { - // Initialize the last values of address and value on the sorted memory operations - let mut last_addr = 0xFFFF_FFFF_FFFF_FFFFu64; - let mut last_value = 0u64; - - // Initialize the module id and next module address to reevaluate the module id, it's done - // to avoid check on each loop if memory address is inside one range or other - let (mut mem_module_id, mut next_module_addr) = if mem_operations.is_empty() { - (0, 0) - } else { - self.get_mem_module_id(mem_operations[0].address) - }; + if !self.address_map_closed { + self.close_address_map(); + } + println!( + "## PROXY INIT ## {:?} {} {}", + self.address_map[0], self.current_module_id, self.current_module + ); + self.current_module_id = self.address_map[0].module_id as usize; + println!("## PROXY INIT2 ## {} {}", self.current_module_id, self.modules_data.len()); + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.address_map[0].to_address; } fn finish_prove(&self) {} fn get_mem_value(&self, addr: u32, mem_op: &ZiskRequiredMemory) -> u64 { @@ 
-396,6 +439,28 @@ impl MemProxyEngine { 0 } } + fn close_address_map(&mut self) { + let mut next_address = 0; + let mut unmapped_regions: Vec<(u32, u32)> = Vec::new(); + for address_region in self.address_map.iter() { + if next_address < address_region.from_address { + unmapped_regions.push((next_address, address_region.from_address - 1)); + } + next_address = address_region.to_address + 1; + } + if !unmapped_regions.is_empty() { + let mut unmapped_module = MemUnmapped::::new(); + for unmapped_region in unmapped_regions.iter() { + println!( + "\x1B[36m## PROXY UNMAPPED ## unmapped_region: 0x{0:X} 0x{1:X}\x1B[0m", + unmapped_region.0, unmapped_region.1 + ); + unmapped_module.add_range(unmapped_region.0, unmapped_region.1); + } + self.add_module("unmapped", Arc::new(unmapped_module)); + } + self.address_map_closed = true; + } #[inline(always)] fn push_open_mem_align_op( @@ -412,7 +477,7 @@ impl MemProxyEngine { } fn log_mem_op(&self, mem_op: &ZiskRequiredMemory) { println!( - "##LOOP## mem_op: {0:?} 0x{1:#08X}({1}) 0x{2:#016X}({2})", + "## PROXY LOOP ## mem_op: {0:?} 0x{1:#08X}({1}) 0x{2:#016X}({2})", mem_op, self.last_addr, self.last_addr_value ); } @@ -423,54 +488,3 @@ impl MemProxyEngine { } impl WitnessComponent for MemProxy {} -/* -fn format_hex(value: u64) -> String { - let hex_str = format!("{:016x}", value); // Format hexadecimal amb 16 dígits i padding de 0s - hex_str - .as_bytes() // Converteix a bytes per manipular fàcilment - .chunks(4) // Separa en grups de 4 caràcters (2 bytes) - .map(|chunk| std::str::from_utf8(chunk).unwrap()) // Converteix cada chunk a &str - .collect::>() // Recull els chunks com a un vector - .join("_") // Uneix amb "_" -} - -fn mem_align_call( - mem_op: &ZiskRequiredMemory, - mem_values: [u64; 2], - phase: u8, -) -> MemAlignResponse { - // DEBUG: only for testing - let offset = (mem_op.address & 0x7) * 8; - let width = (mem_op.width as u64) * 8; - let double_address = (offset + width) > 64; - let mem_value = mem_values[phase as 
usize]; - let mask = 0xFFFF_FFFF_FFFF_FFFFu64 >> (64 - width); - if mem_op.is_write { - if phase == 0 { - MemAlignResponse { - more_address: double_address, - step: mem_op.step + 1, - value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) | - ((mem_op.value & mask) << offset), - ), - } - } else { - MemAlignResponse { - more_address: false, - step: mem_op.step + 1, - value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width - 64))) | - ((mem_op.value & mask) >> (128 - offset - width)), - ), - } - } - } else { - MemAlignResponse { - more_address: double_address && phase == 0, - step: mem_op.step + 1, - value: None, - } - } -} -*/ diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index e854ddc5..6cec707e 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -3,6 +3,8 @@ use std::sync::{ Arc, Mutex, }; +const MEM_INITIAL_ADDRESS: u32 = 0xA0000000; +const MEM_FINAL_ADDRESS: u32 = MEM_INITIAL_ADDRESS + 128 * 1024 * 1024; use crate::MemModule; use p3_field::PrimeField; use proofman::{WitnessComponent, WitnessManager}; @@ -268,11 +270,12 @@ impl MemModule for MemSM { fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]) { self.prove(&mem_op); } - fn get_addr_ranges(&self) -> Vec<(u64, u64)> { - vec![] + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + vec![(MEM_INITIAL_ADDRESS, MEM_FINAL_ADDRESS)] } fn get_flush_input_size(&self) -> u64 { - self.num_rows as u64 + // self.num_rows as u64 + 1024 } fn unregister_predecessor(&self) {} fn register_predecessor(&self) {} diff --git a/state-machines/mem/src/mem_unmapped.rs b/state-machines/mem/src/mem_unmapped.rs new file mode 100644 index 00000000..76659ef8 --- /dev/null +++ b/state-machines/mem/src/mem_unmapped.rs @@ -0,0 +1,33 @@ +use std::marker::PhantomData; + +use crate::MemModule; +use p3_field::PrimeField; + +use zisk_core::ZiskRequiredMemory; + +pub struct MemUnmapped { + ranges: Vec<(u32, u32)>, + __data: PhantomData, +} + 
+impl MemUnmapped { + pub fn new() -> Self { + Self { ranges: Vec::new(), __data: PhantomData } + } + pub fn add_range(&mut self, _start: u32, _end: u32) { + self.ranges.push((_start, _end)); + } +} +impl MemModule for MemUnmapped { + fn send_inputs(&self, _mem_op: &[ZiskRequiredMemory]) { + println!("## MemUnmapped ## access {:?}", _mem_op); + } + fn get_addr_ranges(&self) -> Vec<(u32, u32)> { + self.ranges.to_vec() + } + fn get_flush_input_size(&self) -> u64 { + 1024 + } + fn unregister_predecessor(&self) {} + fn register_predecessor(&self) {} +} From 5819463e337e335725c7e866c797b0ad290e2c1d Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 21 Nov 2024 23:49:17 +0000 Subject: [PATCH 35/44] WIP memory proxy --- state-machines/mem/Cargo.toml | 3 +- state-machines/mem/src/lib.rs | 2 + state-machines/mem/src/mem_proxy.rs | 431 +-------------------- state-machines/mem/src/mem_proxy_engine.rs | 422 ++++++++++++++++++++ state-machines/mem/src/mem_sm.rs | 7 +- state-machines/mem/src/mem_unmapped.rs | 8 +- 6 files changed, 435 insertions(+), 438 deletions(-) create mode 100644 state-machines/mem/src/mem_proxy_engine.rs diff --git a/state-machines/mem/Cargo.toml b/state-machines/mem/Cargo.toml index 3da87e7b..97e73def 100644 --- a/state-machines/mem/Cargo.toml +++ b/state-machines/mem/Cargo.toml @@ -21,4 +21,5 @@ num-bigint = { workspace = true } [features] default = [] -no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] \ No newline at end of file +no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] +debug_mem_proxy_engine = [] \ No newline at end of file diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 4f3116db..6e04d6e9 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -3,6 +3,7 @@ mod mem_align_sm; mod mem_constants; mod mem_helpers; mod mem_proxy; +mod mem_proxy_engine; mod mem_sm; mod mem_unmapped; @@ -11,5 +12,6 @@ pub use mem_align_sm::*; pub use mem_constants::*; 
pub use mem_helpers::*; pub use mem_proxy::*; +pub use mem_proxy_engine::*; pub use mem_sm::*; pub use mem_unmapped::*; diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 4fabbb9f..2d2a1dbb 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,55 +1,15 @@ -use std::{ - collections::VecDeque, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, }; -const UNMAPPED_MODULE_ID: u8 = 0xFE; - -use crate::{ - mem_align_call, MemAlignResponse, MemAlignRomSM, MemAlignSM, MemSM, MemUnmapped, MAX_MEM_ADDR, - MAX_MEM_OPS_PER_MAIN_STEP, MAX_MEM_STEP, MEM_ADDR_BITS, MEM_ADDR_MASK, MEM_BYTES, -}; +use crate::{MemAlignRomSM, MemAlignSM, MemProxyEngine, MemSM}; use p3_field::PrimeField; use pil_std_lib::Std; -use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use zisk_core::ZiskRequiredMemory; use proofman::{WitnessComponent, WitnessManager}; -pub trait MemModule: Send + Sync { - fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]); - fn get_addr_ranges(&self) -> Vec<(u32, u32)>; - fn get_flush_input_size(&self) -> u64; - fn unregister_predecessor(&self); - fn register_predecessor(&self); -} - -trait MemAlignSm { - fn get_mem_op( - &self, - mem_op: &ZiskRequiredMemory, - mem_values: [u64; 2], - phase: u8, - ) -> MemAlignResponse; -} - -struct MemModuleData { - pub name: String, - pub id: u8, - pub ranges: Vec<(u32, u32)>, - pub inputs: Vec, - pub flush_input_size: u64, -} - -struct MemAlignOperation { - addr: u32, - mem_op: ZiskRequiredMemory, - mem_value: [u64; 2], -} - pub struct MemProxy { // Count of registered predecessors registered_predecessors: AtomicU32, @@ -60,25 +20,6 @@ pub struct MemProxy { mem_align_rom_sm: Arc>, } -#[derive(Debug)] -pub struct AddressRegion { - from_address: u32, - to_address: u32, - module_id: u8, -} -pub struct MemProxyEngine { - modules: Vec>>, - modules_data: Vec, - open_mem_align_ops: 
VecDeque, - address_map: Vec, - address_map_closed: bool, - last_addr: u32, - last_addr_value: u64, - current_module_id: usize, - current_module: String, - module_end_addr: u32, -} - impl MemProxy { pub fn new(wcm: Arc>, std: Arc>) -> Arc { let mem_align_rom_sm = MemAlignRomSM::new(wcm.clone()); @@ -119,372 +60,8 @@ impl MemProxy { ) -> Result<(), Box> { let mut engine = MemProxyEngine::::new(); engine.add_module("mem", self.mem_sm.clone()); - engine.close_address_map(); engine.prove(&self.mem_align_sm, mem_operations) } } -impl MemProxyEngine { - pub fn new() -> Self { - let mut modules: Vec>> = Vec::new(); - let mut modules_data: Vec = Vec::new(); - - Self { - modules, - modules_data, - last_addr: 0, - last_addr_value: 0, - current_module_id: 0, - current_module: String::new(), - module_end_addr: 0, - open_mem_align_ops: VecDeque::new(), - address_map: Vec::new(), - address_map_closed: false, - } - } - - pub fn add_module(&mut self, name: &str, module: Arc>) { - if self.modules.is_empty() { - self.current_module = String::from(name); - } - let module_id = self.modules.len() as u8; - self.modules.push(module.clone()); - - let ranges = module.get_addr_ranges(); - let flush_input_size = module.get_flush_input_size(); - - for range in ranges.iter() { - println!("## PROXY adding range 0x{:X} 0x{:X} ##", range.0, range.1); - self.insert_address_range(range.0, range.1, module_id); - } - self.modules_data.push(MemModuleData { - name: String::from(name), - id: module_id, - ranges, - inputs: Vec::new(), - flush_input_size, - }); - } - /* insert in sort way the address map and verify that */ - fn insert_address_range(&mut self, from_address: u32, to_address: u32, module_id: u8) { - let region = AddressRegion { from_address, to_address, module_id }; - if let Some(index) = self.address_map.iter().position(|x| x.from_address >= from_address) { - self.address_map.insert(index, region); - } else { - self.address_map.push(region); - } - } - - pub fn prove( - &mut self, - 
mem_align_sm: &MemAlignSM, - mem_operations: &mut Vec, - ) -> Result<(), Box> { - self.init_prove(&mem_operations); - - // Step 1. Sort the aligned memory accesses - // original vector is sorted by step, sort_by_key is stable, no reordering of elements with - // the same key. - timer_start_debug!(MEM_SORT); - mem_operations.sort_by_key(|mem| (mem.address & 0xFFFF_FFF8)); - timer_stop_and_log_debug!(MEM_SORT); - - // Step2. Add a final mark mem_op to force flush of open_mem_align_ops, because always the - // last operation is mem_op. - mem_operations.push(Self::end_of_memory_mark()); - - // Step3. Process each memory operation ordered by address and step. When a non-aligned - // memory access there are two possible situations: - // - // 1) the operation applies only applies to one memory address (read or read+write). In - // this case mem_align helper return the aligned operation for this address, and loop - // continues. - // 2) the operation applies to two consecutive memory addresses, mem_align helper returns - // the aligned operation involved for the current address, and the second part of the - // operation is enqueued to open_mem_align_ops, it will processed when processing next - // address. - // - // Inside loop, first of all, we verify if exists "previous" open mem align operations that - // be processed before current mem_op, in this case process all "previous" and after process - // the current mem_op. - - for mem_op in mem_operations.iter_mut() { - self.log_mem_op(mem_op); - - let aligned_mem_addr = Self::to_aligned_addr(mem_op.address); - let mem_step = mem_op.step; - - // Check if there are open mem align operations to be processed in this moment, with - // address (or step) less than the aligned of current mem_op. 
- self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step); - - // check if we are at end of loop - if self.check_if_end_of_memory_mark(mem_op) { - break; - } - - // TODO: edge case special memory with free-input memory data as input data - let mem_value = self.get_mem_value(aligned_mem_addr, mem_op); - - // all open mem align operations are processed, check if new mem operation is aligned - if !Self::is_aligned(&mem_op) { - // In this point found non-aligned memory access, phase-0 - let mem_align_op = mem_align_sm.get_mem_op(mem_op, [mem_value, 0], 0); - - // if operation applies to two consecutive memory addresses, add the second part - // is enqueued to be processed in future when processing next address on phase-1 - if mem_align_op.more_address { - self.push_open_mem_align_op(aligned_mem_addr, mem_value, mem_op); - } - self.push_mem_align_response_ops( - aligned_mem_addr, - mem_value, - mem_op, - &mem_align_op, - ); - } else { - self.push_mem_op(mem_op); - } - } - self.finish_prove(); - Ok(()) - } - - fn process_all_previous_open_mem_align_ops(&mut self, mem_addr: u32, mem_step: u64) { - // Two possible situations to process open mem align operations: - // - // 1) the address of open operation is less than the aligned address. - // 2) the address of open operation is equal to the aligned address, but the step of the - // open operation is less than the step of the current operation. - - while self.has_open_mem_align_lt(mem_addr, mem_step) { - let open_op = self.open_mem_align_ops.pop_front().unwrap(); - let mem_value = if open_op.addr == self.last_addr { self.last_addr_value } else { 0 }; - - // call to mem_align to get information of the aligned memory access needed - // to prove the unaligned open operation. - let mem_align_op = mem_align_call(&open_op.mem_op, [mem_value, 0], 1); - - // remove element from top of queue, because we are on last phase, phase 1. 
- self.open_mem_align_ops.pop_front(); - - // push the aligned memory operations for current address (read or read+write) and - // update last_address and last_value. - self.push_mem_align_response_ops( - open_op.addr, - mem_value, - &open_op.mem_op, - &mem_align_op, - ); - } - } - - pub fn main_step_to_mem_step(step: u64, step_offset: u8) -> u64 { - 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 - } - - /// Static method to decide it the memory operation needs to be processed by - /// memAlign, because it isn't a 8-byte and 8-byte aligned memory access. - fn is_aligned(mem_op: &ZiskRequiredMemory) -> bool { - let aligned_mem_address = (mem_op.address as u64 & MEM_ADDR_MASK) as u32; - aligned_mem_address == mem_op.address && mem_op.width == MEM_BYTES as u8 - } - fn push_mem_op(&mut self, mem_op: &ZiskRequiredMemory) { - self.push_aligned_op(mem_op.is_write, mem_op.address, mem_op.value, mem_op.step); - } - - fn push_aligned_op(&mut self, is_write: bool, addr: u32, value: u64, step: u64) { - self.update_last_addr(addr, value); - let mem_op = ZiskRequiredMemory { - step, - is_write, - address: addr as u32, - width: MEM_BYTES as u8, - value, - }; - println!("## PROXY SEND {0} ## {1:?}", self.current_module, mem_op); - self.modules_data[self.current_module_id].inputs.push(mem_op); - self.last_addr_value = value; - self.check_flush_inputs(); - } - // method to add aligned read operation - #[inline(always)] - fn push_aligned_read(&mut self, addr: u32, value: u64, step: u64) { - self.push_aligned_op(false, addr, value, step); - } - // method to add aligned write operation - #[inline(always)] - fn push_aligned_write(&mut self, addr: u32, value: u64, step: u64) { - self.push_aligned_op(true, addr, value, step); - } - /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible - /// situations: - /// 1) read, only on single mem_op is pushed - /// 2) read+write, two mem_op are pushed, one read and one write. 
- /// - /// This process is used for each aligned memory address, means that the "second part" of non - /// aligned memory operation is processed on addr + MEM_BYTES. - fn push_mem_align_response_ops( - &mut self, - mem_addr: u32, - mem_value: u64, - mem_op: &ZiskRequiredMemory, - mem_align_op: &MemAlignResponse, - ) { - self.push_aligned_read(mem_addr, mem_value, mem_align_op.step); - if mem_op.is_write { - let mem_value = mem_align_op.value.expect("value returned by mem_align"); - self.push_aligned_write(mem_addr, mem_value, mem_align_op.step + 1); - } - } - fn create_modules_inputs(&self) -> Vec> { - let mut mem_module_inputs: Vec> = Default::default(); - for module in self.modules.iter() { - mem_module_inputs.push(Vec::new()); - } - mem_module_inputs - } - fn set_active_region(&mut self, region_id: usize) { - self.current_module_id = self.address_map[region_id].module_id as usize; - self.current_module = self.modules_data[self.current_module_id].name.clone(); - self.module_end_addr = self.address_map[region_id].to_address; - } - fn update_mem_module_id(&mut self, addr: u32) { - println!( - "## \x1B[31mGET MODULE ID\x1B[0m ## 0x{0:X} module_end_addr:0x{1:X} 0x{2:X}", - addr, self.module_end_addr, MAX_MEM_ADDR as u32 - ); - // println!("{:?}", self.address_map); - if let Some(index) = - self.address_map.iter().position(|x| x.from_address <= addr && x.to_address >= addr) - { - self.set_active_region(index); - } else { - assert!(false, "out-of-memory 0x{:X}", addr); - } - } - fn update_last_addr(&mut self, addr: u32, value: u64) { - self.last_addr = addr; - // check if need to reevaluate the module id - if addr > self.module_end_addr { - self.update_mem_module_id(addr); - } - } - fn check_flush_inputs(&mut self) { - // check if need to flush the inputs of the module - let mid = self.current_module_id; - println!( - "## PROXY FLUSH ## {0} {1} {2}", - mid, - self.modules_data[mid].inputs.len(), - self.modules_data[mid].flush_input_size - ); - if 
(self.modules_data[mid].inputs.len() as u64) >= self.modules_data[mid].flush_input_size { - // TODO: optimize passing ownership of inputs to module, and creating a new input - // object - self.modules[mid].send_inputs(&self.modules_data[mid].inputs); - self.modules_data[mid].inputs.clear(); - } - } - - fn has_open_mem_align_lt(&self, addr: u32, step: u64) -> bool { - self.open_mem_align_ops.len() > 0 && - (self.open_mem_align_ops[0].addr < addr || - (self.open_mem_align_ops[0].addr == addr && - self.open_mem_align_ops[0].mem_op.step < step)) - } - // method to process open mem align operations, second part of non aligned memory operations - // applies to two consecutive memory addresses. - - fn end_of_memory_mark() -> ZiskRequiredMemory { - ZiskRequiredMemory { - step: MAX_MEM_STEP, - is_write: false, - address: MAX_MEM_ADDR as u32, - width: MEM_BYTES as u8, - value: 0, - } - } - #[inline(always)] - fn check_if_end_of_memory_mark(&self, mem_op: &ZiskRequiredMemory) -> bool { - if mem_op.step == MAX_MEM_STEP && mem_op.address == MAX_MEM_ADDR as u32 { - assert!( - self.open_mem_align_ops.len() == 0, - "open_mem_align_ops not empty, has {} elements", - self.open_mem_align_ops.len() - ); - true - } else { - false - } - } - fn init_prove(&mut self, mem_operations: &Vec) { - if !self.address_map_closed { - self.close_address_map(); - } - println!( - "## PROXY INIT ## {:?} {} {}", - self.address_map[0], self.current_module_id, self.current_module - ); - self.current_module_id = self.address_map[0].module_id as usize; - println!("## PROXY INIT2 ## {} {}", self.current_module_id, self.modules_data.len()); - self.current_module = self.modules_data[self.current_module_id].name.clone(); - self.module_end_addr = self.address_map[0].to_address; - } - fn finish_prove(&self) {} - fn get_mem_value(&self, addr: u32, mem_op: &ZiskRequiredMemory) -> u64 { - if addr == self.last_addr { - self.last_addr_value - } else { - 0 - } - } - fn close_address_map(&mut self) { - let mut 
next_address = 0; - let mut unmapped_regions: Vec<(u32, u32)> = Vec::new(); - for address_region in self.address_map.iter() { - if next_address < address_region.from_address { - unmapped_regions.push((next_address, address_region.from_address - 1)); - } - next_address = address_region.to_address + 1; - } - if !unmapped_regions.is_empty() { - let mut unmapped_module = MemUnmapped::::new(); - for unmapped_region in unmapped_regions.iter() { - println!( - "\x1B[36m## PROXY UNMAPPED ## unmapped_region: 0x{0:X} 0x{1:X}\x1B[0m", - unmapped_region.0, unmapped_region.1 - ); - unmapped_module.add_range(unmapped_region.0, unmapped_region.1); - } - self.add_module("unmapped", Arc::new(unmapped_module)); - } - self.address_map_closed = true; - } - - #[inline(always)] - fn push_open_mem_align_op( - &mut self, - aligned_mem_addr: u32, - mem_value: u64, - mem_op: &ZiskRequiredMemory, - ) { - self.open_mem_align_ops.push_back(MemAlignOperation { - addr: aligned_mem_addr + MEM_BYTES as u32, - mem_op: mem_op.clone(), - mem_value: [mem_value, 0], - }); - } - fn log_mem_op(&self, mem_op: &ZiskRequiredMemory) { - println!( - "## PROXY LOOP ## mem_op: {0:?} 0x{1:#08X}({1}) 0x{2:#016X}({2})", - mem_op, self.last_addr, self.last_addr_value - ); - } - #[inline(always)] - fn to_aligned_addr(addr: u32) -> u32 { - (addr as u64 & MEM_ADDR_MASK) as u32 - } -} - impl WitnessComponent for MemProxy {} diff --git a/state-machines/mem/src/mem_proxy_engine.rs b/state-machines/mem/src/mem_proxy_engine.rs new file mode 100644 index 00000000..32638db5 --- /dev/null +++ b/state-machines/mem/src/mem_proxy_engine.rs @@ -0,0 +1,422 @@ +use std::{collections::VecDeque, sync::Arc}; + +use crate::{ + mem_align_call, MemAlignResponse, MemAlignSM, MemUnmapped, MAX_MEM_ADDR, + MAX_MEM_OPS_PER_MAIN_STEP, MAX_MEM_STEP, MEM_ADDR_MASK, MEM_BYTES, +}; +use log::info; +use p3_field::PrimeField; +use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; +use zisk_core::ZiskRequiredMemory; + +macro_rules! 
debug_info { + ($prefix:expr, $($arg:tt)*) => { + #[cfg(feature = "debug_mem_proxy_engine")] + { + info!(concat!("MemPE : ",$prefix), $($arg)*); + } + }; +} + +pub trait MemModule: Send + Sync { + fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]); + fn get_addr_ranges(&self) -> Vec<(u32, u32)>; + fn get_flush_input_size(&self) -> u32; +} + +trait MemAlignSm { + fn get_mem_op( + &self, + mem_op: &ZiskRequiredMemory, + mem_values: [u64; 2], + phase: u8, + ) -> MemAlignResponse; +} + +struct MemModuleData { + pub name: String, + pub inputs: Vec, + pub flush_input_size: u32, +} + +struct MemAlignOperation { + addr: u32, + mem_op: ZiskRequiredMemory, + mem_value: [u64; 2], +} + +#[derive(Debug)] +pub struct AddressRegion { + from_address: u32, + to_address: u32, + module_id: u8, +} +pub struct MemProxyEngine { + modules: Vec>>, + modules_data: Vec, + open_mem_align_ops: VecDeque, + address_map: Vec, + address_map_closed: bool, + last_addr: u32, + last_addr_value: u64, + current_module_id: usize, + current_module: String, + module_end_addr: u32, +} + +impl MemProxyEngine { + pub fn new() -> Self { + Self { + modules: Vec::new(), + modules_data: Vec::new(), + last_addr: 0, + last_addr_value: 0, + current_module_id: 0, + current_module: String::new(), + module_end_addr: 0, + open_mem_align_ops: VecDeque::new(), + address_map: Vec::new(), + address_map_closed: false, + } + } + + pub fn add_module(&mut self, name: &str, module: Arc>) { + if self.modules.is_empty() { + self.current_module = String::from(name); + } + let module_id = self.modules.len() as u8; + self.modules.push(module.clone()); + + let ranges = module.get_addr_ranges(); + let flush_input_size = module.get_flush_input_size(); + + for range in ranges.iter() { + debug_info!("adding range 0x{:X} 0x{:X}", range.0, range.1); + self.insert_address_range(range.0, range.1, module_id); + } + self.modules_data.push(MemModuleData { + name: String::from(name), + inputs: Vec::new(), + flush_input_size, + }); + } + /* 
insert in sort way the address map and verify that */ + fn insert_address_range(&mut self, from_address: u32, to_address: u32, module_id: u8) { + let region = AddressRegion { from_address, to_address, module_id }; + if let Some(index) = self.address_map.iter().position(|x| x.from_address >= from_address) { + self.address_map.insert(index, region); + } else { + self.address_map.push(region); + } + } + + pub fn prove( + &mut self, + mem_align_sm: &MemAlignSM, + mem_operations: &mut Vec, + ) -> Result<(), Box> { + self.init_prove(); + + // Step 1. Sort the aligned memory accesses + // original vector is sorted by step, sort_by_key is stable, no reordering of elements with + // the same key. + timer_start_debug!(MEM_SORT); + mem_operations.sort_by_key(|mem| (mem.address & 0xFFFF_FFF8)); + timer_stop_and_log_debug!(MEM_SORT); + + // Step2. Add a final mark mem_op to force flush of open_mem_align_ops, because always the + // last operation is mem_op. + mem_operations.push(Self::end_of_memory_mark()); + + // Step3. Process each memory operation ordered by address and step. When a non-aligned + // memory access there are two possible situations: + // + // 1) the operation applies only applies to one memory address (read or read+write). In + // this case mem_align helper return the aligned operation for this address, and loop + // continues. + // 2) the operation applies to two consecutive memory addresses, mem_align helper returns + // the aligned operation involved for the current address, and the second part of the + // operation is enqueued to open_mem_align_ops, it will processed when processing next + // address. + // + // Inside loop, first of all, we verify if exists "previous" open mem align operations that + // be processed before current mem_op, in this case process all "previous" and after process + // the current mem_op. 
+ + for mem_op in mem_operations.iter_mut() { + // self.log_mem_op(mem_op); + + let aligned_mem_addr = Self::to_aligned_addr(mem_op.address); + let mem_step = mem_op.step; + + // Check if there are open mem align operations to be processed in this moment, with + // address (or step) less than the aligned of current mem_op. + self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step); + + // check if we are at end of loop + if self.check_if_end_of_memory_mark(mem_op) { + break; + } + + // TODO: edge case special memory with free-input memory data as input data + let mem_value = self.get_mem_value(aligned_mem_addr); + + // all open mem align operations are processed, check if new mem operation is aligned + if !Self::is_aligned(&mem_op) { + // In this point found non-aligned memory access, phase-0 + let mem_align_op = mem_align_sm.get_mem_op(mem_op, [mem_value, 0], 0); + + // if operation applies to two consecutive memory addresses, add the second part + // is enqueued to be processed in future when processing next address on phase-1 + if mem_align_op.more_address { + self.push_open_mem_align_op(aligned_mem_addr, mem_value, mem_op); + } + self.push_mem_align_response_ops( + aligned_mem_addr, + mem_value, + mem_op, + &mem_align_op, + ); + } else { + self.push_mem_op(mem_op); + } + } + self.finish_prove(); + Ok(()) + } + + fn process_all_previous_open_mem_align_ops(&mut self, mem_addr: u32, mem_step: u64) { + // Two possible situations to process open mem align operations: + // + // 1) the address of open operation is less than the aligned address. + // 2) the address of open operation is equal to the aligned address, but the step of the + // open operation is less than the step of the current operation. 
+ + while self.has_open_mem_align_lt(mem_addr, mem_step) { + let open_op = self.open_mem_align_ops.pop_front().unwrap(); + let mem_value = if open_op.addr == self.last_addr { self.last_addr_value } else { 0 }; + + // call to mem_align to get information of the aligned memory access needed + // to prove the unaligned open operation. + let mem_align_op = mem_align_call(&open_op.mem_op, [mem_value, 0], 1); + + // remove element from top of queue, because we are on last phase, phase 1. + self.open_mem_align_ops.pop_front(); + + // push the aligned memory operations for current address (read or read+write) and + // update last_address and last_value. + self.push_mem_align_response_ops( + open_op.addr, + mem_value, + &open_op.mem_op, + &mem_align_op, + ); + } + } + + pub fn main_step_to_mem_step(step: u64, step_offset: u8) -> u64 { + 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 + } + + /// Static method to decide it the memory operation needs to be processed by + /// memAlign, because it isn't a 8-byte and 8-byte aligned memory access. 
+ fn is_aligned(mem_op: &ZiskRequiredMemory) -> bool { + let aligned_mem_address = (mem_op.address as u64 & MEM_ADDR_MASK) as u32; + aligned_mem_address == mem_op.address && mem_op.width == MEM_BYTES as u8 + } + fn push_mem_op(&mut self, mem_op: &ZiskRequiredMemory) { + self.push_aligned_op(mem_op.is_write, mem_op.address, mem_op.value, mem_op.step); + } + + fn push_aligned_op(&mut self, is_write: bool, addr: u32, value: u64, step: u64) { + self.update_last_addr(addr, value); + let mem_op = ZiskRequiredMemory { + step, + is_write, + address: addr as u32, + width: MEM_BYTES as u8, + value, + }; + debug_info!( + "route ==> {}[{:X}] {} {} #{}", + self.current_module, + mem_op.address, + if is_write { "W" } else { "R" }, + value, + step, + ); + self.modules_data[self.current_module_id].inputs.push(mem_op); + self.last_addr_value = value; + self.check_flush_inputs(); + } + // method to add aligned read operation + #[inline(always)] + fn push_aligned_read(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(false, addr, value, step); + } + // method to add aligned write operation + #[inline(always)] + fn push_aligned_write(&mut self, addr: u32, value: u64, step: u64) { + self.push_aligned_op(true, addr, value, step); + } + /// Process information of mem_op and mem_align_op to push mem_op operation. Only two possible + /// situations: + /// 1) read, only on single mem_op is pushed + /// 2) read+write, two mem_op are pushed, one read and one write. + /// + /// This process is used for each aligned memory address, means that the "second part" of non + /// aligned memory operation is processed on addr + MEM_BYTES. 
+ fn push_mem_align_response_ops( + &mut self, + mem_addr: u32, + mem_value: u64, + mem_op: &ZiskRequiredMemory, + mem_align_op: &MemAlignResponse, + ) { + self.push_aligned_read(mem_addr, mem_value, mem_align_op.step); + if mem_op.is_write { + let mem_value = mem_align_op.value.expect("value returned by mem_align"); + self.push_aligned_write(mem_addr, mem_value, mem_align_op.step + 1); + } + } + fn create_modules_inputs(&self) -> Vec> { + let mut mem_module_inputs: Vec> = Default::default(); + for _module in self.modules.iter() { + mem_module_inputs.push(Vec::new()); + } + mem_module_inputs + } + fn set_active_region(&mut self, region_id: usize) { + self.current_module_id = self.address_map[region_id].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.address_map[region_id].to_address; + } + fn update_mem_module_id(&mut self, addr: u32) { + debug_info!("search module for address 0x{:X}", addr); + if let Some(index) = + self.address_map.iter().position(|x| x.from_address <= addr && x.to_address >= addr) + { + self.set_active_region(index); + } else { + assert!(false, "out-of-memory 0x{:X}", addr); + } + } + fn update_last_addr(&mut self, addr: u32, value: u64) { + self.last_addr = addr; + self.last_addr_value = value; + // check if need to reevaluate the module id + if addr > self.module_end_addr { + self.update_mem_module_id(addr); + } + } + fn check_flush_inputs(&mut self) { + // check if need to flush the inputs of the module + let mid = self.current_module_id; + let inputs = self.modules_data[mid].inputs.len() as u32; + if inputs >= self.modules_data[mid].flush_input_size { + // TODO: optimize passing ownership of inputs to module, and creating a new input + // object + debug_info!("flush {} inputs => {}", inputs, self.current_module); + self.modules[mid].send_inputs(&self.modules_data[mid].inputs); + self.modules_data[mid].inputs.clear(); + } + } + + fn has_open_mem_align_lt(&self, 
addr: u32, step: u64) -> bool { + self.open_mem_align_ops.len() > 0 && + (self.open_mem_align_ops[0].addr < addr || + (self.open_mem_align_ops[0].addr == addr && + self.open_mem_align_ops[0].mem_op.step < step)) + } + // method to process open mem align operations, second part of non aligned memory operations + // applies to two consecutive memory addresses. + + fn end_of_memory_mark() -> ZiskRequiredMemory { + ZiskRequiredMemory { + step: MAX_MEM_STEP, + is_write: false, + address: MAX_MEM_ADDR as u32, + width: MEM_BYTES as u8, + value: 0, + } + } + #[inline(always)] + fn check_if_end_of_memory_mark(&self, mem_op: &ZiskRequiredMemory) -> bool { + if mem_op.step == MAX_MEM_STEP && mem_op.address == MAX_MEM_ADDR as u32 { + assert!( + self.open_mem_align_ops.len() == 0, + "open_mem_align_ops not empty, has {} elements", + self.open_mem_align_ops.len() + ); + true + } else { + false + } + } + fn init_prove(&mut self) { + if !self.address_map_closed { + self.close_address_map(); + } + self.current_module_id = self.address_map[0].module_id as usize; + self.current_module = self.modules_data[self.current_module_id].name.clone(); + self.module_end_addr = self.address_map[0].to_address; + } + fn finish_prove(&self) { + for (module_id, module) in self.modules.iter().enumerate() { + module.send_inputs(&self.modules_data[module_id].inputs); + } + } + fn get_mem_value(&self, addr: u32) -> u64 { + if addr == self.last_addr { + self.last_addr_value + } else { + 0 + } + } + fn close_address_map(&mut self) { + let mut next_address = 0; + let mut unmapped_regions: Vec<(u32, u32)> = Vec::new(); + for address_region in self.address_map.iter() { + if next_address < address_region.from_address { + unmapped_regions.push((next_address, address_region.from_address - 1)); + } + next_address = address_region.to_address + 1; + } + if !unmapped_regions.is_empty() { + let mut unmapped_module = MemUnmapped::::new(); + for unmapped_region in unmapped_regions.iter() { + 
unmapped_module.add_range(unmapped_region.0, unmapped_region.1); + } + self.add_module("unmapped", Arc::new(unmapped_module)); + } + self.address_map_closed = true; + } + + #[inline(always)] + fn push_open_mem_align_op( + &mut self, + aligned_mem_addr: u32, + mem_value: u64, + mem_op: &ZiskRequiredMemory, + ) { + self.open_mem_align_ops.push_back(MemAlignOperation { + addr: aligned_mem_addr + MEM_BYTES as u32, + mem_op: mem_op.clone(), + mem_value: [mem_value, 0], + }); + } + fn log_mem_op(&self, mem_op: &ZiskRequiredMemory) { + debug_info!( + "next input [0x{:x}] {} {} {}b #{}", + mem_op.address, + if mem_op.is_write { "W" } else { "R" }, + mem_op.value, + mem_op.width, + mem_op.step, + ); + } + #[inline(always)] + fn to_aligned_addr(addr: u32) -> u32 { + (addr as u64 & MEM_ADDR_MASK) as u32 + } +} diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index 6cec707e..ce7b9f14 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -273,12 +273,9 @@ impl MemModule for MemSM { fn get_addr_ranges(&self) -> Vec<(u32, u32)> { vec![(MEM_INITIAL_ADDRESS, MEM_FINAL_ADDRESS)] } - fn get_flush_input_size(&self) -> u64 { - // self.num_rows as u64 - 1024 + fn get_flush_input_size(&self) -> u32 { + self.num_rows as u32 } - fn unregister_predecessor(&self) {} - fn register_predecessor(&self) {} } impl WitnessComponent for MemSM {} diff --git a/state-machines/mem/src/mem_unmapped.rs b/state-machines/mem/src/mem_unmapped.rs index 76659ef8..988971d6 100644 --- a/state-machines/mem/src/mem_unmapped.rs +++ b/state-machines/mem/src/mem_unmapped.rs @@ -20,14 +20,12 @@ impl MemUnmapped { } impl MemModule for MemUnmapped { fn send_inputs(&self, _mem_op: &[ZiskRequiredMemory]) { - println!("## MemUnmapped ## access {:?}", _mem_op); + // panic!("[MemUnmapped] invalid access to addr {:x}", _mem_op[0].addr); } fn get_addr_ranges(&self) -> Vec<(u32, u32)> { self.ranges.to_vec() } - fn get_flush_input_size(&self) -> u64 { - 1024 + fn 
get_flush_input_size(&self) -> u32 { + 1 } - fn unregister_predecessor(&self) {} - fn register_predecessor(&self) {} } From 632d7cdbe6e67495422e8ad8cadf432d49e374a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Fri, 22 Nov 2024 07:50:16 +0000 Subject: [PATCH 36/44] Mem rom fully working --- state-machines/mem/pil/mem_align_rom.pil | 335 +++++++++++---------- state-machines/mem/src/mem_align_rom_sm.rs | 223 +++++++++----- 2 files changed, 318 insertions(+), 240 deletions(-) diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil index 016dd956..d1e60c7c 100644 --- a/state-machines/mem/pil/mem_align_rom.pil +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -4,15 +4,6 @@ require "constants.pil" const int MEM_ALIGN_ROM_ID = 133; const int MEM_ALIGN_ROM_SIZE = P2_8; -// PROGRAM SIZE -// RV 0 2 -// RWV 1 3 -// RVR 2 3 -// RWVWR 3 5 -// -// Note1: The offset and width are sufficient to group programs with the same number of operations. -// Note2: The first instruction is always a read. 
- airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = 8, const int DEFAULT_OFFSET = 0, const int DEFAULT_WIDTH = 8, const int disable_fixed = 0) { if (N < MEM_ALIGN_ROM_SIZE) { error(`N must be at least ${MEM_ALIGN_ROM_SIZE}, but N=${N} was provided`); @@ -28,77 +19,85 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = return; } - // Define the size of each program: RV, RWV, RVR, RWVWR - const int psize[4] = [2, 3, 3, 5]; + // Define the size of each sub-program: RV, RWV, RVR, RWVWR + const int spsize[4] = [2, 3, 3, 5]; // Not all combinations of offset and width are valid for each program: const int one_word_combinations = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 const int two_word_combinations = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 // table_size = combinations * program_size - const int tsize[4] = [one_word_combinations*psize[0], one_word_combinations*psize[1], two_word_combinations*psize[2], two_word_combinations*psize[3]]; - // size - // RV 6+6*4+4+4+2 = 40 | 40 - // RWV 9+9*4+6+6+3 = 60 | 100 - // RVR 3*4+6+6+9 = 33 | 133 - // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 - - // Moreover, offset is set to DEFAULT_OFFSET and width to DEFAULT_WIDTH in aligned memory accesses. 
- // offset == width == 0 is set at the very first row for padding - // size - col fixed OFFSET = [0, [[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 40 - [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 100 - [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 133 - [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3]]...; // RWVWR 5*4+10+10+15 = 55 | 188 => N = 2^8 - - col fixed WIDTH = [0, [[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV - [[8,8,1,8,8,2,8,8,4]:5, [8,8,1,8,8,2]:2, [8,8,1]], // RWV - [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR - [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]]]...; // RWVWR + const int tsize[4] = [one_word_combinations*spsize[0], one_word_combinations*spsize[1], two_word_combinations*spsize[2], two_word_combinations*spsize[3]]; + const int psize = tsize[0] + tsize[1] + tsize[2] + tsize[3]; - // TODO: Do a less-hardcoded version of the OFFSET and WIDTH computation - // col fixed OFFSET; - // col fixed WIDTH; - // for (int i = 0; i < N; i++) { - // int offset = 0; - // int width = 0; + // Offset is set to DEFAULT_OFFSET and width to DEFAULT_WIDTH in aligned memory accesses. + // Offset and width are set to 0 in padding lines. 
+ // size + col fixed OFFSET = [0, // Padding 1 = 1 | 1 + [[0,0]:3, [0,1]:3, [0,2]:3, [0,3]:3, [0,4]:3, [0,5]:2, [0,6]:2, [0,7]], // RV 6+6*4+4+4+2 = 40 | 41 + [[0,0,0]:3, [0,0,1]:3, [0,0,2]:3, [0,0,3]:3, [0,0,4]:3, [0,0,5]:2, [0,0,6]:2, [0,0,7]], // RWV 9+9*4+6+6+3 = 60 | 101 + [[0,1,0], [0,2,0], [0,3,0], [0,4,0], [0,5,0]:2, [0,6,0]:2, [0,7,0]:3], // RVR 3*4+6+6+9 = 33 | 134 + [[0,0,1,0,0], [0,0,2,0,0], [0,0,3,0,0], [0,0,4,0,0], [0,0,5,0,0]:2, [0,0,6,0,0]:2, [0,0,7,0,0]:3], // RWVWR 5*4+10+10+15 = 55 | 189 => N = 2^8 + 0...]; // Padding - // OFFSET[i] = offset; - // WIDTH[i] = width; - // } + col fixed WIDTH = [0, // Padding + [[8,1,8,2,8,4]:5, [8,1,8,2]:2, [8,1]], // RV + [[8,8,1,8,8,2,8,8,4]:5, [8,8,1,8,8,2]:2, [8,8,1]], // RWV + [[8,8,8]:4, [8,4,8,8,8,8]:2, [8,2,8,8,4,8,8,8,8]], // RVR + [[8,8,8,8,8]:4, [8,8,4,8,8,8,8,8,8,8]:2, [8,8,2,8,8,8,8,4,8,8,8,8,8,8,8]], // RWVWR + 0...]; // Padding - // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | - // 0 | 0 | 0 | 0 | 0 | 0 | // for padding + // line | pc | pc'-pc | reset | addr | (addr-'addr)*(1-reset) | + // 0 | 0 | 0 | 1 | 0 | 0 | // for padding // 1 | 0 | 1 | 1 | X1 | 0 | // (RV) // 2 | 1 | -1 | 0 | X1 | 0 | // 3 | 0 | 3 | 1 | X2 | 0 | // (RV) // 4 | 3 | -3 | 0 | X2 | 0 | // 5 | 0 | 5 | 1 | X3 | 0 | // (RV) // 6 | 5 | -5 | 0 | X3 | 0 | - // 7 | 0 | 7 | 1 | ⋮ | ⋮ | // (RV) + // 7 | 0 | 7 | 1 | ⋮ | ⋮ | // (RV) // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | // 41 | 0 | 41 | 1 | X4 | 0 | // (RWV) // 42 | 41 | 1 | 0 | X4 | 0 | // 43 | 42 | -42 | 0 | X4 | 0 | // 44 | 0 | 44 | 1 | X5 | 0 | // (RWV) + // 45 | 44 | 1 | 0 | X5 | 0 | + // 46 | 45 | -45 | 0 | X5 | 0 | + // 47 | 0 | 47 | 1 | X6 | 0 | // (RWV) // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | - // 101 | 0 | 101 | 1 | X6 | 0 | // (RVR) - // 102 |101 | 1 | 0 | X6 | 0 | - // 103 |102 | -102 | 0 | X6+1 | 1 | + // 101 | 0 | 101 | 1 | X7 | 0 | // (RVR) + // 102 |101 | 1 | 0 | X7 | 0 | + // 103 |102 | -102 | 0 | X7+1 | 1 | + // 104 | 0 | 104 | 1 | X8 | 0 | // (RVR) + // 105 |104 | 1 | 0 | X8 | 0 
| + // 106 |105 | -105 | 0 | X8+1 | 1 | + // 107 | 0 | 107 | 1 | X9 | 0 | // (RVR) // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | - // 134 | 0 | 134 | 1 | X7 | 0 | // (RWVWR) - // 135 |134 | 1 | 0 | X7 | 0 | - // 136 |135 | 1 | 0 | X7 | 0 | - // 137 |136 | 1 | 0 | X7+1 | 1 | - // 138 |137 | -137 | 0 | X7+1 | 1 | + // 134 | 0 | 134 | 1 | X10 | 0 | // (RWVWR) + // 135 |134 | 1 | 0 | X10 | 0 | + // 136 |135 | 1 | 0 | X10 | 0 | + // 137 |136 | 1 | 0 | X10+1 | 1 | + // 138 |137 | -137 | 0 | X10+1 | 0 | + // 139 | 0 | 139 | 1 | X11 | 0 | // (RWVWR) + // 140 |139 | 1 | 0 | X11 | 0 | + // 141 |140 | 1 | 0 | X11 | 0 | + // 142 |141 | 1 | 0 | X11+1 | 1 | + // 143 |142 | -142 | 0 | X11+1 | 0 | + // 144 | 0 | 144 | 1 | X12 | 0 | // (RWVWR) // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + // 188 |187 | -187 | 0 | X13+1 | 0 | + // 189 | 0 | 0 | 1 | 0 | 0 | // for padding + // ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | + + // Note: The overall program contains "holes", meaning that pc can vary + // from program to program by any constant, as long as it is unique for each program. + // For example, the first program has pc=0,1, while the second has pc=0,3. col fixed PC; col fixed DELTA_PC; col fixed DELTA_ADDR; col fixed FLAGS; for (int i = 0; i < N; i++) { - const int [offset, width] = [OFFSET[i], WIDTH[i]]; int pc = 0; int delta_pc = 0; int delta_addr = 0; @@ -111,213 +110,215 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = int sel_up_to_down = 0; int sel_down_to_up = 0; + const int prev_line = i == 0 ? 
0 : i-1; const int line = i; - const int next = i+1; - if (line == 0) { // padding - // Do nothing + if (line == 0 || line > psize) + { + // pc = 0; + // delta_pc = 0; + // delta_addr = 0; + // is_write = 0; + reset = 1; + // sel = [0:CHUNK_NUM] + // sel_up_to_down = 0; + // sel_down_to_up = 0; } - else if (line < tsize[0]) // RV + else if (line < 1+tsize[0]) // RV { - if (line % 2 == 0) { + if (line % 2 == 1) { // pc = 0; - delta_pc = next; + delta_pc = line; // delta_addr = 0; // is_write = 0; reset = 1; - sel = get_selectors(offset, width, program: 0); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1] && j < OFFSET[i+1] + WIDTH[i+1]) { + sel[j] = 1; + } + } sel_up_to_down = 1; // sel_down_to_up = 0; } else { - pc = line; + pc = prev_line; delta_pc = -pc; - delta_addr = 1; + // delta_addr = 0; // is_write = 0; // reset = 0; - // sel = [0:CHUNK_NUM] + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } // sel_up_to_down = 0; // sel_down_to_up = 0; } } - else if (line < tsize[0]+tsize[1]) // RWV + else if (line < 1+tsize[0]+tsize[1]) // RWV { - if (line % 3 == 0) { // R + if (line % 3 == 2) { // pc = 0; - delta_pc = next; + delta_pc = line; // delta_addr = 0; // is_write = 0; reset = 1; - sel = get_selectors(offset, width, program: 1); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < OFFSET[i+2] || j >= OFFSET[i+2] + WIDTH[i+2]) { + sel[j] = 1; + } + } sel_up_to_down = 1; // sel_down_to_up = 0; - } else if (line % 3 == 1) { // W - pc = line; + } else if (line % 3 == 0) { + pc = prev_line; delta_pc = 1; - delta_addr = 1; - is_write = 0; + // delta_addr = 0; + is_write = 1; // reset = 0; - sel = get_selectors(offset, width, program: 1, is_write: 1); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1] && j < OFFSET[i+1] + WIDTH[i+1]) { + sel[j] = 1; + } + } sel_up_to_down = 1; // sel_down_to_up = 0; - } else { // V - pc = line; + } else { + pc = prev_line; delta_pc = -pc; // delta_addr = 0; // is_write = 0; // 
reset = 0; - // sel = [0:CHUNK_NUM] + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } // sel_up_to_down = 0; // sel_down_to_up = 0; } } - else if (line < tsize[0]+tsize[1]+tsize[2]) + else if (line < 1+tsize[0]+tsize[1]+tsize[2]) // RVR { - if (line % 3 == 0) { // R + if (line % 3 == 2) { // pc = 0; - delta_pc = next; + delta_pc = line; // delta_addr = 0; // is_write = 0; reset = 1; - sel = get_selectors(offset, width, program: 2); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1]) { + sel[j] = 1; + } + } sel_up_to_down = 1; // sel_down_to_up = 0; - } else if (line % 3 == 1) { // V - pc = line; + } else if (line % 3 == 0) { + pc = prev_line; delta_pc = 1; - delta_addr = 1; + // delta_addr = 0; // is_write = 0; // reset = 0; - // sel = [0:CHUNK_NUM] - // sel_up_to_down = 1; + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } + // sel_up_to_down = 0; // sel_down_to_up = 0; - } else { // R - pc = line; + } else { + pc = prev_line; delta_pc = -pc; - // delta_addr = 0; + delta_addr = 1; // is_write = 0; // reset = 0; - sel = get_selectors(offset, width, program: 2); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < (OFFSET[i-1] + WIDTH[i-1]) % CHUNK_NUM) { + sel[j] = 1; + } + } // sel_up_to_down = 0; sel_down_to_up = 1; } } - else if (line < tsize[0]+tsize[1]+tsize[2]+tsize[3]) + else if (line < 1+tsize[0]+tsize[1]+tsize[2]+tsize[3]) // RWVWR { - if (next % 5 == 0) { // R + if (line % 5 == 4) { // pc = 0; - delta_pc = next; + delta_pc = line; // delta_addr = 0; // is_write = 0; reset = 1; - sel = get_selectors(offset, width, program: 3, is_write: 0, is_first: 1); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < OFFSET[i+2]) { + sel[j] = 1; + } + } sel_up_to_down = 1; // sel_down_to_up = 0; - } else if (next % 5 == 1) { // W - pc = line; + } else if (line % 5 == 0) { + pc = prev_line; delta_pc = 1; - delta_addr = 1; + // delta_addr = 0; is_write = 1; // reset = 0; - sel = 
get_selectors(offset, width, program: 3, is_write: 1, is_first: 1); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= OFFSET[i+1]) { + sel[j] = 1; + } + } sel_up_to_down = 1; // sel_down_to_up = 0; - } else if (next % 5 == 2) { // V - pc = line; + } else if (line % 5 == 1) { + pc = prev_line; delta_pc = 1; - delta_addr = 1; + // delta_addr = 0; // is_write = 0; // reset = 0; - // sel = [0:CHUNK_NUM] + for (int j = 0; j < CHUNK_NUM; j++) { + if (j == OFFSET[i]) { + sel[j] = 1; + } + } // sel_up_to_down = 0; // sel_down_to_up = 0; - } else if (next % 5 == 3) { // W - pc = line; + } else if (line % 5 == 2) { + pc = prev_line; delta_pc = 1; delta_addr = 1; is_write = 1; // reset = 0; - sel = get_selectors(offset, width, program: 3, is_write: 1, is_first: 0); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j < (OFFSET[i-1] + WIDTH[i-1]) % CHUNK_NUM) { + sel[j] = 1; + } + } // sel_up_to_down = 0; sel_down_to_up = 1; - } else { // R - pc = line; + } else { + pc = prev_line; delta_pc = -pc; // delta_addr = 0; // is_write = 0; // reset = 0; - sel = get_selectors(offset, width, program: 3, is_write: 0, is_first: 0); + for (int j = 0; j < CHUNK_NUM; j++) { + if (j >= (OFFSET[i-2] + WIDTH[i-2]) % CHUNK_NUM) { + sel[j] = 1; + } + } // sel_up_to_down = 0; sel_down_to_up = 1; } } - PC[i] = pc; DELTA_PC[i] = delta_pc; DELTA_ADDR[i] = delta_addr; - FLAGS[i] = 0; + int flags = 0; for (int j = 0; j < CHUNK_NUM; j++) { - FLAGS[i] += sel[j] * 2**j; + flags += sel[j] * 2**j; } - FLAGS[i] += is_write * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); + flags += is_write * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); + FLAGS[i] = flags; } + // Ensure the program is being followed by the MemAlign lookup_proves(MEM_ALIGN_ROM_ID, [PC, DELTA_PC, DELTA_ADDR, OFFSET, WIDTH, FLAGS], multiplicity); -} - -private function get_selectors(const int offset, 
const int width, const int program, const int is_write = 0, const int is_first = 0, const int bytes = 8): int[] { - int _sel[bytes]; - for (int j = 0; j < bytes; j++) { - _sel[j] = 0; - } - - switch (program) { - case 0: // RV - for (int j = 0; j < offset; j++) { - _sel[j] = 1; - } - - case 1: // RWV - if (!is_write) { - for (int j = 0; j < offset; j++) { - _sel[j] = 1; - } - } else { - for (int j = offset; j < offset + width; j++) { - _sel[j] = 1; - } - } - - case 2: // RVR - for (int j = 0; j < offset; j++) { - _sel[j] = 1; - } - - case 3: // RWVWR - if (is_first) { - if (!is_write) { - for (int j = 0; j < offset; j++) { - _sel[j] = 1; - } - } else { - for (int j = offset; j < bytes; j++) { - _sel[j] = 1; - } - } - } else { - const int rem = (offset + width) % bytes; - if (is_write) { - for (int j = 0; j < rem; j++) { - _sel[j] = 1; - } - } else { - for (int j = rem; j < bytes; j++) { - _sel[j] = 1; - } - } - } - - default: - error(`Invalid program ${program}`); - } - - return _sel; } \ No newline at end of file diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 59b46bcd..b5e16545 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -23,7 +23,9 @@ pub enum MemOp { } const CHUNK_NUM: usize = 8; -const OP_SIZES: [usize; 4] = [2, 3, 3, 5]; +const OP_SIZES: [u64; 4] = [2, 3, 3, 5]; +const ONE_WORD_COMBINATIONS: u64 = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 +const TWO_WORD_COMBINATIONS: u64 = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 pub struct MemAlignRomSM { // Witness computation manager @@ -76,109 +78,178 @@ impl MemAlignRomSM { } } - pub fn get_mem_align_op_size(op: MemOp) -> usize { - OP_SIZES[op as usize] + pub fn calculate_next_pc(&self, opcode: MemOp, offset: usize, width: usize) -> u64 { + let row_idxs = Self::get_row_idxs(&self, opcode, offset, width); + + // Update the multiplicity + let ones: 
Vec = vec![1; row_idxs.len()]; + self.update_multiplicity_by_row_idx(&row_idxs, &ones); + + row_idxs[0] } - fn calculate_rom_rows(opcode: MemOp, offset: usize, width: usize) -> Vec { - // Calculate the ROM rows based on the requested opcode, offset, and width + fn get_row_idxs(&self, opcode: MemOp, offset: usize, width: usize) -> Vec { + let opcode_idx = opcode as usize; + let op_size = OP_SIZES[opcode_idx]; match opcode { MemOp::OneRead | MemOp::OneWrite => { // Sanity check assert!(offset + width <= CHUNK_NUM); - let possible_widths = match offset { - x if x <= 4 => vec![1, 2, 4], - x if x <= 6 => vec![1, 2], - x if x == 7 => vec![1], - _ => panic!("Invalid offset={}", offset), + + // Go to the actual operation + let mut value_row = match opcode { + MemOp::OneRead => 1, + MemOp::OneWrite => 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], + _ => unreachable!(), }; - Self::get_row_idxs(opcode, possible_widths, offset, width) + + match opcode { + MemOp::OneRead => { + if offset == 7 && width == 1 + { + println!("OneRead value_row: {}", value_row); + } + }, + MemOp::OneWrite => { + if offset == 3 && width == 1 + { + println!("OneWrite value_row: {}", value_row); + } + }, + _ => {} + } + + // Go to the actual offset + for i in 0..offset { + let possible_widths = Self::calculate_possible_widths(true, i); + value_row += op_size * possible_widths.len() as u64; + } + + match opcode { + MemOp::OneRead => { + if offset == 7 && width == 1 + { + println!("OneRead value_row: {}", value_row); + } + }, + MemOp::OneWrite => { + if offset == 3 && width == 1 + { + println!("OneWrite value_row: {}", value_row); + } + }, + _ => {} + } + + // Go to the right width + let width_idx = Self::calculate_possible_widths(true, offset) + .iter() + .position(|&w| w == width) + .expect("Invalid width"); + value_row += op_size * width_idx as u64; + + match opcode { + MemOp::OneRead => { + if offset == 7 && width == 1 + { + println!("OneRead value_row: {}", value_row); + } + }, + MemOp::OneWrite => { + 
if offset == 3 && width == 1 + { + println!("opsizes: {:?}", op_size); + println!("width_idx: {:?}", width_idx); + println!("OneWrite value_row: {}", value_row); + } + }, + _ => {} + } + + assert!(value_row < self.num_rows as u64); + + match opcode { + MemOp::OneRead => vec![value_row, value_row + 1], + MemOp::OneWrite => vec![value_row, value_row + 1, value_row + 2], + _ => unreachable!(), + } } MemOp::TwoReads | MemOp::TwoWrites => { // Sanity check assert!(offset + width > CHUNK_NUM); - let possible_widths = match offset { - x if x == 0 => panic!("Invalid offset={}", offset), - x if x <= 4 => vec![8], - x if x <= 6 => vec![4, 8], - x if x == 7 => vec![2, 4, 8], - _ => panic!("Invalid offset={}", offset), + + // Go to the actual operation + let mut value_row = match opcode { + MemOp::TwoReads => { + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1] + } + MemOp::TwoWrites => { + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] + + TWO_WORD_COMBINATIONS * OP_SIZES[2] + } + _ => unreachable!(), }; - Self::get_row_idxs(opcode, possible_widths, offset, width) - } - } - } - fn get_row_idxs( - opcode: MemOp, - possible_widths: Vec, - offset: usize, - width: usize, - ) -> Vec { - // Sanity check - assert!(possible_widths.contains(&width)); + // Go to the actual offset + for i in 1..offset { + let possible_widths = Self::calculate_possible_widths(false, i); + value_row += op_size * possible_widths.len() as u64; + } + + assert!(value_row < self.num_rows as u64); + + // Go to the right width + let width_idx = Self::calculate_possible_widths(false, offset) + .iter() + .position(|&w| w == width) + .expect("Invalid width"); + value_row += op_size * width_idx as u64; - let width_idx = possible_widths.iter().position(|&w| w == width).unwrap(); - let opcode_idx = opcode as usize; - match opcode { - MemOp::OneRead | MemOp::OneWrite => { - let value_row = (offset * possible_widths.len() * OP_SIZES[opcode_idx] - + (offset + 
width_idx + 1) * OP_SIZES[opcode_idx] - - 1) as u64; match opcode { - MemOp::OneRead => vec![value_row - 1, value_row], - MemOp::OneWrite => vec![value_row - 2, value_row - 1, value_row], + MemOp::TwoReads => vec![value_row, value_row + 1, value_row + 2], + MemOp::TwoWrites => { + vec![value_row, value_row + 1, value_row + 2, value_row + 3, value_row + 4] + } _ => unreachable!(), } } - MemOp::TwoReads => { - let value_row = (offset * possible_widths.len() * OP_SIZES[opcode_idx] - + (offset + width_idx + 1) * OP_SIZES[opcode_idx] - - 2) as u64; - return vec![value_row - 1, value_row, value_row + 1]; - } - MemOp::TwoWrites => { - let value_row = (offset * possible_widths.len() * OP_SIZES[opcode_idx] - + (offset + width_idx + 1) * OP_SIZES[opcode_idx] - - 3) as u64; - return vec![value_row - 2, value_row - 1, value_row, value_row + 1, value_row + 2]; - } } } - pub fn calculate_next_pc(&self, op: MemOp, offset: usize, width: usize) -> u64 { - let row_idxs = Self::calculate_rom_rows(op, offset, width); - - // Update the multiplicity - self.update_multiplicity_by_idx(&row_idxs); - - // The "next" pc is always found on the second row of the program being executed - row_idxs[1] + fn calculate_possible_widths(one_word: bool, offset: usize) -> Vec { + // Calculate the ROM rows based on the requested opcode, offset, and width + match one_word { + true => match offset { + x if x <= 4 => vec![1, 2, 4], + x if x <= 6 => vec![1, 2], + x if x == 7 => vec![1], + _ => panic!("Invalid offset={}", offset), + }, + false => match offset { + x if x == 0 => panic!("Invalid offset={}", offset), + x if x <= 4 => vec![8], + x if x <= 6 => vec![4, 8], + x if x == 7 => vec![2, 4, 8], + _ => panic!("Invalid offset={}", offset), + }, + } } pub fn update_padding_row(&self, padding_len: u64) { // Update entry at the padding row (pos = 0) with the given padding length - self.update_multiplicity(&[padding_len]); - } - - pub fn update_multiplicity_by_input(&self, opcode: MemOp, offset: usize, width: 
usize) { - let row_idxs = Self::calculate_rom_rows(opcode, offset, width); - self.update_multiplicity_by_idx(&row_idxs); + self.update_multiplicity_by_row_idx(&[0], &[padding_len]); } - pub fn update_multiplicity_by_idx(&self, idxs: &[u64]) { - let mut multiplicity = self.multiplicity.lock().unwrap(); - - for &i in idxs { - *multiplicity.entry(i).or_insert(0) += 1; + pub fn update_multiplicity_by_row_idx(&self, row_idxs: &[u64], muls: &[u64]) { + if row_idxs.len() != muls.len() { + panic!("The number of indices and multiplicities must be the same"); } - } - pub fn update_multiplicity(&self, inputs: &[u64]) { let mut multiplicity = self.multiplicity.lock().unwrap(); - for (idx, mul) in inputs.iter().enumerate() { - *multiplicity.entry(idx as u64).or_insert(0) += *mul; + for (i, &idx) in row_idxs.iter().enumerate() { + *multiplicity.entry(idx).or_insert(0) += muls[i]; } } @@ -205,6 +276,12 @@ impl MemAlignRomSM { ) .unwrap(); + // Initialize the trace buffer to zero + for i in 0..air_mem_align_rom_rows { + trace_buffer[i] = MemAlignRomRow { multiplicity: F::zero() }; + } + + // Fill the trace buffer with the multiplicity values if let Ok(multiplicity) = self.multiplicity.lock() { for (row_idx, multiplicity) in multiplicity.iter() { trace_buffer[*row_idx as usize] = From e92abdd54c709d9298476b7a145408dad255a4a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Fri, 22 Nov 2024 07:53:17 +0000 Subject: [PATCH 37/44] Mem align fully working --- core/src/zisk_required_operation.rs | 2 +- state-machines/mem/src/mem_align_rom_sm.rs | 53 +---- state-machines/mem/src/mem_align_sm.rs | 253 +++++++++++++++++---- 3 files changed, 207 insertions(+), 101 deletions(-) diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index 41056a6a..5ee2ec5f 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -22,7 +22,7 @@ impl fmt::Debug for ZiskRequiredMemory { let label = if self.is_write { "WR" } 
else { "RD" }; write!( f, - "{0} addr:{1:#08X}({1}) with:{2} value:{3:#016X}({3}) step:{4} offset:{5}", + "{0} addr:{1:#08X}({1}) offset:{5} with:{2} value:{3:#016X}({3}) step:{4}", label, self.address, self.width, diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index b5e16545..1953c016 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -103,44 +103,12 @@ impl MemAlignRomSM { _ => unreachable!(), }; - match opcode { - MemOp::OneRead => { - if offset == 7 && width == 1 - { - println!("OneRead value_row: {}", value_row); - } - }, - MemOp::OneWrite => { - if offset == 3 && width == 1 - { - println!("OneWrite value_row: {}", value_row); - } - }, - _ => {} - } - // Go to the actual offset for i in 0..offset { let possible_widths = Self::calculate_possible_widths(true, i); value_row += op_size * possible_widths.len() as u64; } - match opcode { - MemOp::OneRead => { - if offset == 7 && width == 1 - { - println!("OneRead value_row: {}", value_row); - } - }, - MemOp::OneWrite => { - if offset == 3 && width == 1 - { - println!("OneWrite value_row: {}", value_row); - } - }, - _ => {} - } - // Go to the right width let width_idx = Self::calculate_possible_widths(true, offset) .iter() @@ -148,24 +116,6 @@ impl MemAlignRomSM { .expect("Invalid width"); value_row += op_size * width_idx as u64; - match opcode { - MemOp::OneRead => { - if offset == 7 && width == 1 - { - println!("OneRead value_row: {}", value_row); - } - }, - MemOp::OneWrite => { - if offset == 3 && width == 1 - { - println!("opsizes: {:?}", op_size); - println!("width_idx: {:?}", width_idx); - println!("OneWrite value_row: {}", value_row); - } - }, - _ => {} - } - assert!(value_row < self.num_rows as u64); match opcode { @@ -181,7 +131,8 @@ impl MemAlignRomSM { // Go to the actual operation let mut value_row = match opcode { MemOp::TwoReads => { - 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + 
ONE_WORD_COMBINATIONS * OP_SIZES[1] + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] } MemOp::TwoWrites => { 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index fec8a776..6de796a8 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -23,11 +23,12 @@ use crate::{MemAlignRomSM, MemOp}; const CHUNK_NUM: usize = 8; const CHUNK_NUM_U64: u64 = CHUNK_NUM as u64; const CHUNK_BITS: usize = 8; -const CHUNK_BITS_U64: u64 = CHUNK_BITS as u64; const OFFSET_MASK: u64 = CHUNK_NUM_U64 - 1; const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; const ALLOWED_WIDTHS: [u64; 4] = [1, 2, 4, 8]; +const DEFAULT_OFFSET: u64 = 0; +const DEFAULT_WIDTH: u64 = 8; pub struct MemAlignResponse { pub more_address: bool, @@ -144,7 +145,7 @@ impl MemAlignSM { drop(num_rows); println!("INPUT: {:?}", input); println!("MEM_VALUES: {:?}", mem_values); - println!("PHASE: {:?}\n", phase); + println!("PHASE: {:?}", phase); /* RV with offset=2, width=4 +----+----+====+====+====+====+----+----+ @@ -165,7 +166,7 @@ impl MemAlignSM { let addr_read = addr >> CHUNK_BITS; // Get the aligned value - let read_value = mem_values[phase]; + let value_read = mem_values[phase]; // Get the next pc let next_pc = @@ -174,7 +175,8 @@ impl MemAlignSM { let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr_read), - // offset: F::from_canonical_u64(0), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -184,7 +186,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), + addr: F::from_canonical_u64(addr_read), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: 
F::from_bool(false), @@ -194,21 +196,57 @@ impl MemAlignSM { ..Default::default() }; + println!("VALUE_READ: {:?}", value_read.to_le_bytes()); + println!("VALUE: {:?}", value.to_le_bytes()); + for i in 0..CHUNK_NUM { - read_row.reg[i] = F::from_canonical_u64(Self::get_byte(read_value, i, 0)); - println!("READ_ROW[{}]: {:?}", i, read_row.reg[i]); + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); if i >= offset && i < offset + width { read_row.sel[i] = F::from_bool(true); } value_row.reg[i] = F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); - println!("VALUE_ROW[{}]: {:?}", i, value_row.reg[i]); if i == offset { value_row.sel[i] = F::from_bool(true); } } + println!( + "FLAGS READ: {:?}", + [ + read_row.sel[0], + read_row.sel[1], + read_row.sel[2], + read_row.sel[3], + read_row.sel[4], + read_row.sel[5], + read_row.sel[6], + read_row.sel[7], + read_row.wr, + read_row.reset, + read_row.sel_up_to_down, + read_row.sel_down_to_up + ] + ); + println!( + "FLAGS VALUE: {:?}\n", + [ + value_row.sel[0], + value_row.sel[1], + value_row.sel[2], + value_row.sel[3], + value_row.sel[4], + value_row.sel[5], + value_row.sel[6], + value_row.sel[7], + value_row.wr, + value_row.reset, + value_row.sel_up_to_down, + value_row.sel_down_to_up + ] + ); + // Prove the generated rows self.prove(&[read_row, value_row]); @@ -220,7 +258,7 @@ impl MemAlignSM { drop(num_rows); println!("INPUT: {:?}", input); println!("MEM_VALUES: {:?}", mem_values); - println!("PHASE: {:?}\n", phase); + println!("PHASE: {:?}", phase); /* RWV with offset=3, width=4 +----+----+----+====+====+====+====+----+ @@ -245,7 +283,7 @@ impl MemAlignSM { let addr_read = addr >> CHUNK_BITS; // Get the aligned value - let read_value = mem_values[phase]; + let value_read = mem_values[phase]; // Get the next pc let next_pc = @@ -254,29 +292,23 @@ impl MemAlignSM { // Compute the write value let value_write = { // with:1 offset:4 - let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 
1; // 0xFF - println!("WIDTH_BYTES: {:#X}", width_bytes); + let width_bytes: u64 = (1 << (width * CHUNK_BITS)) - 1; - let mask: u64 = width_bytes << (offset * CHUNK_BITS); // 0x00_00_00_FF_00_00_00_00 - println!("MASK: {:#X}", mask); + let mask: u64 = width_bytes << (offset * CHUNK_BITS); // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - println!("VALUE_TO_WRITE: {:#X}", value_to_write); - // Write zeroes to read_value from offset to offset + width + // Write zeroes to value_read from offset to offset + width // and add the value to write to the value read - - let result = (read_value & !mask) | value_to_write; - println!("RESULT: {:#X}", result); - result + (value_read & !mask) | value_to_write }; let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr_read), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -287,8 +319,8 @@ impl MemAlignSM { let mut write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u64(addr_read), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), pc: F::from_canonical_u64(next_pc), // reset: F::from_bool(false), @@ -298,7 +330,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), + addr: F::from_canonical_u64(addr_read), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -308,15 +340,17 @@ impl MemAlignSM { ..Default::default() }; + println!("VALUE_READ: {:?}", 
value_read.to_le_bytes()); + println!("VALUE_WRITE: {:?}", value_write.to_le_bytes()); + println!("VALUE: {:?}", value.to_le_bytes()); + for i in 0..CHUNK_NUM { - read_row.reg[i] = F::from_canonical_u64(Self::get_byte(read_value, i, 0)); - println!("READ_ROW[{}]: {:?}", i, read_row.reg[i]); + read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); if i < offset || i >= offset + width { read_row.sel[i] = F::from_bool(true); } write_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_write, i, 0)); - println!("WRITE_ROW[{}]: {:?}", i, write_row.reg[i]); if i >= offset && i < offset + width { write_row.sel[i] = F::from_bool(true); } @@ -328,12 +362,29 @@ impl MemAlignSM { F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)) } }; - println!("VALUE_ROW[{}]: {:?}", i, value_row.reg[i]); if i == offset { value_row.sel[i] = F::from_bool(true); } } + println!( + "FLAGS: {:?}\n", + [ + value_row.sel[0], + value_row.sel[1], + value_row.sel[2], + value_row.sel[3], + value_row.sel[4], + value_row.sel[5], + value_row.sel[6], + value_row.sel[7], + F::from_bool(true), + value_row.reset, + value_row.sel_up_to_down, + value_row.sel_down_to_up + ] + ); + // Prove the generated rows self.prove(&[read_row, write_row, value_row]); @@ -366,7 +417,7 @@ impl MemAlignSM { drop(num_rows); println!("INPUT: {:?}", input); println!("MEM_VALUES: {:?}", mem_values); - println!("PHASE: {:?}\n", phase); + println!("PHASE: {:?}", phase); assert!(mem_values.len() == 2); // TODO: Debug mode @@ -392,8 +443,8 @@ impl MemAlignSM { let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr_first_read), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -403,7 +454,7 @@ impl MemAlignSM { let mut 
value_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), + addr: F::from_canonical_u64(addr_first_read), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -416,8 +467,8 @@ impl MemAlignSM { let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr_second_read), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), pc: F::from_canonical_u64(next_pc + 1), // reset: F::from_bool(false), @@ -449,6 +500,24 @@ impl MemAlignSM { } } + println!( + "FLAGS: {:?}\n", + [ + value_row.sel[0], + value_row.sel[1], + value_row.sel[2], + value_row.sel[3], + value_row.sel[4], + value_row.sel[5], + value_row.sel[6], + value_row.sel[7], + F::from_bool(false), + value_row.reset, + value_row.sel_up_to_down, + value_row.sel_down_to_up + ] + ); + // Prove the generated rows self.prove(&[first_read_row, value_row, second_read_row]); @@ -505,7 +574,7 @@ impl MemAlignSM { // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - // Write zeroes to read_value from offset to offset + width + // Write zeroes to value_read from offset to offset + width // and add the value to write to the value read (value_first_read & !mask) | value_to_write }; @@ -544,7 +613,7 @@ impl MemAlignSM { drop(num_rows); println!("INPUT: {:?}", input); println!("MEM_VALUES: {:?}", mem_values); - println!("PHASE: {:?}\n", phase); + println!("PHASE: {:?}", phase); assert!(mem_values.len() == 2); // TODO: Debug mode @@ -574,7 +643,7 @@ impl MemAlignSM { // Get the first width bytes of the unaligned value let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - // Write zeroes to read_value from offset to offset + width + // Write zeroes 
to value_read from offset to offset + width // and add the value to write to the value read (value_first_read & !mask) | value_to_write }; @@ -592,7 +661,7 @@ impl MemAlignSM { // Get the first width bytes of the unaligned value let value_to_write = (value >> width_norm * CHUNK_BITS) & mask; - // Write zeroes to read_value from 0 to offset + width + // Write zeroes to value_read from 0 to offset + width // and add the value to write to the value read (value_second_read & !mask) | value_to_write }; @@ -608,8 +677,8 @@ impl MemAlignSM { let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr_first_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), // pc: F::from_canonical_u64(0), reset: F::from_bool(true), @@ -620,8 +689,8 @@ impl MemAlignSM { let mut first_write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u64(addr_first_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), pc: F::from_canonical_u64(next_pc), // reset: F::from_bool(false), @@ -631,7 +700,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), - addr: F::from_canonical_u64(addr), + addr: F::from_canonical_u64(addr_first_read_write), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -644,8 +713,8 @@ impl MemAlignSM { let mut second_write_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u64(addr_second_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + 
width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), pc: F::from_canonical_u64(next_pc + 2), // reset: F::from_bool(false), @@ -656,8 +725,8 @@ impl MemAlignSM { let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u64(addr_second_read_write), - // offset: F::from_canonical_u64(0), - width: F::from_canonical_u64(CHUNK_NUM_U64), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), pc: F::from_canonical_u64(next_pc + 3), reset: F::from_bool(false), @@ -713,6 +782,92 @@ impl MemAlignSM { } } + println!( + "FLAGS FIRST READ: {:?}", + [ + first_read_row.sel[0], + first_read_row.sel[1], + first_read_row.sel[2], + first_read_row.sel[3], + first_read_row.sel[4], + first_read_row.sel[5], + first_read_row.sel[6], + first_read_row.sel[7], + F::from_bool(true), + first_read_row.reset, + first_read_row.sel_up_to_down, + first_read_row.sel_down_to_up + ] + ); + println!( + "FLAGS FIRST WRITE: {:?}", + [ + first_write_row.sel[0], + first_write_row.sel[1], + first_write_row.sel[2], + first_write_row.sel[3], + first_write_row.sel[4], + first_write_row.sel[5], + first_write_row.sel[6], + first_write_row.sel[7], + F::from_bool(false), + first_write_row.reset, + first_write_row.sel_up_to_down, + first_write_row.sel_down_to_up + ] + ); + println!( + "FLAGS VALUE: {:?}", + [ + value_row.sel[0], + value_row.sel[1], + value_row.sel[2], + value_row.sel[3], + value_row.sel[4], + value_row.sel[5], + value_row.sel[6], + value_row.sel[7], + F::from_bool(false), + value_row.reset, + value_row.sel_up_to_down, + value_row.sel_down_to_up + ] + ); + println!( + "FLAGS SECOND WRITE: {:?}", + [ + second_write_row.sel[0], + second_write_row.sel[1], + second_write_row.sel[2], + second_write_row.sel[3], + second_write_row.sel[4], + second_write_row.sel[5], + second_write_row.sel[6], + second_write_row.sel[7], + F::from_bool(false), + second_write_row.reset, + 
second_write_row.sel_up_to_down, + second_write_row.sel_down_to_up + ] + ); + println!( + "FLAGS SECOND READ: {:?}\n", + [ + second_read_row.sel[0], + second_read_row.sel[1], + second_read_row.sel[2], + second_read_row.sel[3], + second_read_row.sel[4], + second_read_row.sel[5], + second_read_row.sel[6], + second_read_row.sel[7], + F::from_bool(false), + second_read_row.reset, + second_read_row.sel_up_to_down, + second_read_row.sel_down_to_up + ] + ); + // Prove the generated rows self.prove(&[ first_read_row, @@ -798,7 +953,7 @@ impl MemAlignSM { } // Pad the remaining rows with trivially satisfying rows - let padding_row = MemAlignRow::::default(); + let padding_row = MemAlignRow:: { reset: F::from_bool(true), ..Default::default() }; let padding_size = air_mem_align_rows - rows_len; // Store the padding rows From 4593a0dbe10dcbdd921b3b732670bb38e4d02e0e Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Fri, 22 Nov 2024 08:28:22 +0000 Subject: [PATCH 38/44] WIP mem proxy --- state-machines/mem/src/mem_proxy_engine.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/state-machines/mem/src/mem_proxy_engine.rs b/state-machines/mem/src/mem_proxy_engine.rs index 32638db5..bfa8f8f0 100644 --- a/state-machines/mem/src/mem_proxy_engine.rs +++ b/state-machines/mem/src/mem_proxy_engine.rs @@ -304,6 +304,9 @@ impl MemProxyEngine { fn update_last_addr(&mut self, addr: u32, value: u64) { self.last_addr = addr; self.last_addr_value = value; + self.update_mem_module(addr); + } + fn update_mem_module(&mut self, addr: u32) { // check if need to reevaluate the module id if addr > self.module_end_addr { self.update_mem_module_id(addr); From abcb73c81bf546f0b3a881fe2b57a612596e0119 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Fri, 22 Nov 2024 10:50:59 +0000 Subject: [PATCH 39/44] Cleaning the mem align --- state-machines/mem/src/mem_align_sm.rs | 438 ++++++++++++------------- 1 file changed, 210 insertions(+), 228 deletions(-) diff --git 
a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 6de796a8..87f5b9c4 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -26,6 +26,17 @@ const CHUNK_BITS: usize = 8; const OFFSET_MASK: u64 = CHUNK_NUM_U64 - 1; const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; +const fn generate_allowed_offsets() -> [u64; CHUNK_NUM] { + let mut offsets = [0; CHUNK_NUM]; + let mut i = 0; + while i < CHUNK_NUM { + offsets[i] = i as u64; + i += 1; + } + offsets +} + +const ALLOWED_OFFSETS: [u64; CHUNK_NUM] = generate_allowed_offsets(); const ALLOWED_WIDTHS: [u64; 4] = [1, 2, 4, 8]; const DEFAULT_OFFSET: u64 = 0; const DEFAULT_WIDTH: u64 = 8; @@ -47,12 +58,22 @@ pub struct MemAlignSM { // Computed row information rows: Mutex>>, - num_computed_rows: Mutex, // TODO: DEBUG!!! + #[cfg(feature = "debug_mem_align")] + num_computed_rows: Mutex, // Secondary State machines mem_align_rom_sm: Arc>, } +macro_rules! debug_info { + ($prefix:expr, $($arg:tt)*) => { + #[cfg(feature = "debug_mem_align")] + { + info!(concat!("MemAlign: ",$prefix), $($arg)*); + } + }; +} + impl MemAlignSM { const MY_NAME: &'static str = "MemAlign"; @@ -66,6 +87,7 @@ impl MemAlignSM { std: std.clone(), registered_predecessors: AtomicU32::new(0), rows: Mutex::new(Vec::new()), + #[cfg(feature = "debug_mem_align")] num_computed_rows: Mutex::new(0), mem_align_rom_sm, }; @@ -98,7 +120,7 @@ impl MemAlignSM { let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); let rows_len = rows.len(); - assert!(rows_len <= air_mem_align.num_rows()); + debug_assert!(rows_len <= air_mem_align.num_rows()); let drained_rows = rows.drain(..rows_len).collect::>(); @@ -117,36 +139,39 @@ impl MemAlignSM { mem_values: [u64; 2], phase: usize, ) -> MemAlignResponse { - // Sanity check - // assert!(mem_values.len() == phase + 1); // TODO: Debug mode + debug_assert!( + mem_values.len() == phase + 1, + "The number of mem_values {} is not 
equal to phase + 1 {}", + mem_values.len(), + phase + 1 + ); let addr = input.address; let width = input.width; - let width = if ALLOWED_WIDTHS.contains(&width) { - width as usize - } else { - panic!("Width={} is not allowed. Allowed widths are {:?}", width, ALLOWED_WIDTHS); - }; + + // Compute the width + debug_assert!( + ALLOWED_WIDTHS.contains(&width), + "Width={} is not allowed. Allowed widths are {:?}", + width, + ALLOWED_WIDTHS + ); + let width = width as usize; // Compute the offset let offset = addr & OFFSET_MASK; - let offset = if offset <= usize::MAX as u64 { - offset as usize - } else { - panic!("Offset={} is too large", offset); - }; - - let num_rows = self.num_computed_rows.lock().unwrap(); // TODO: DEBUG!!! + debug_assert!( + ALLOWED_OFFSETS.contains(&offset), + "Offset={} is not allowed. Allowed offsets are {:?}", + offset, + ALLOWED_OFFSETS + ); + let offset = offset as usize; + #[cfg(feature = "debug_mem_align")] + let num_rows = self.num_computed_rows.lock().unwrap(); match (input.is_write, offset + width > CHUNK_NUM) { (false, false) => { - println!("ONE READ"); - println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 1); - drop(num_rows); - println!("INPUT: {:?}", input); - println!("MEM_VALUES: {:?}", mem_values); - println!("PHASE: {:?}", phase); - /* RV with offset=2, width=4 +----+----+====+====+====+====+----+----+ | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | @@ -156,7 +181,7 @@ impl MemAlignSM { | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | +----+----+====+====+====+====+----+----+ */ - assert!(phase == 0); // TODO: Debug mode + debug_assert!(phase == 0); // Unaligned memory op information thrown into the bus let step = input.step; @@ -196,9 +221,6 @@ impl MemAlignSM { ..Default::default() }; - println!("VALUE_READ: {:?}", value_read.to_le_bytes()); - println!("VALUE: {:?}", value.to_le_bytes()); - for i in 0..CHUNK_NUM { read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); if i >= offset && i < offset + width { @@ -212,54 
+234,44 @@ impl MemAlignSM { } } - println!( - "FLAGS READ: {:?}", + #[rustfmt::skip] + debug_info!( + "\nOne Word Read\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Mem Values: {:?}\n\ + Phase: {:?}\n\ + Value Read: {:?}\n\ + Value: {:?}\n\ + Flags Read: {:?}\n\ + Flags Value: {:?}", + [*num_rows, *num_rows + 1], + input, + mem_values, + phase, + value_read.to_le_bytes(), + value.to_le_bytes(), [ - read_row.sel[0], - read_row.sel[1], - read_row.sel[2], - read_row.sel[3], - read_row.sel[4], - read_row.sel[5], - read_row.sel[6], - read_row.sel[7], - read_row.wr, - read_row.reset, - read_row.sel_up_to_down, - read_row.sel_down_to_up - ] - ); - println!( - "FLAGS VALUE: {:?}\n", + read_row.sel[0], read_row.sel[1], read_row.sel[2], read_row.sel[3], + read_row.sel[4], read_row.sel[5], read_row.sel[6], read_row.sel[7], + read_row.wr, read_row.reset, read_row.sel_up_to_down, read_row.sel_down_to_up + ], [ - value_row.sel[0], - value_row.sel[1], - value_row.sel[2], - value_row.sel[3], - value_row.sel[4], - value_row.sel[5], - value_row.sel[6], - value_row.sel[7], - value_row.wr, - value_row.reset, - value_row.sel_up_to_down, - value_row.sel_down_to_up + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up ] ); + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + // Prove the generated rows self.prove(&[read_row, value_row]); MemAlignResponse { more_address: false, step, value: None } } (true, false) => { - println!("ONE WRITE"); - println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 2); - drop(num_rows); - println!("INPUT: {:?}", input); - println!("MEM_VALUES: {:?}", mem_values); - println!("PHASE: {:?}", phase); - /* RWV with offset=3, width=4 +----+----+----+====+====+====+====+----+ | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | @@ -273,7 +285,7 @@ impl MemAlignSM { | V5 | V6 | V7 | V0 | V1 | V2 | V3 
| V4 | +----+----+----+====+====+====+====+----+ */ - assert!(phase == 0); // TODO: Debug mode + debug_assert!(phase == 0); // Unaligned memory op information thrown into the bus let step = input.step; @@ -340,10 +352,6 @@ impl MemAlignSM { ..Default::default() }; - println!("VALUE_READ: {:?}", value_read.to_le_bytes()); - println!("VALUE_WRITE: {:?}", value_write.to_le_bytes()); - println!("VALUE: {:?}", value.to_le_bytes()); - for i in 0..CHUNK_NUM { read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_read, i, 0)); if i < offset || i >= offset + width { @@ -367,24 +375,46 @@ impl MemAlignSM { } } - println!( - "FLAGS: {:?}\n", + #[rustfmt::skip] + debug_info!( + "\nOne Word Write\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Mem Values: {:?}\n\ + Phase: {:?}\n\ + Value Read: {:?}\n\ + Value Write: {:?}\n\ + Value: {:?}\n\ + Flags Read: {:?}\n\ + Flags Write: {:?}\n\ + Flags Value: {:?}", + [*num_rows, *num_rows + 2], + input, + mem_values, + phase, + value_read.to_le_bytes(), + value_write.to_le_bytes(), + value.to_le_bytes(), + [ + read_row.sel[0], read_row.sel[1], read_row.sel[2], read_row.sel[3], + read_row.sel[4], read_row.sel[5], read_row.sel[6], read_row.sel[7], + read_row.wr, read_row.reset, read_row.sel_up_to_down, read_row.sel_down_to_up + ], [ - value_row.sel[0], - value_row.sel[1], - value_row.sel[2], - value_row.sel[3], - value_row.sel[4], - value_row.sel[5], - value_row.sel[6], - value_row.sel[7], - F::from_bool(true), - value_row.reset, - value_row.sel_up_to_down, - value_row.sel_down_to_up + write_row.sel[0], write_row.sel[1], write_row.sel[2], write_row.sel[3], + write_row.sel[4], write_row.sel[5], write_row.sel[6], write_row.sel[7], + write_row.wr, write_row.reset, write_row.sel_up_to_down, write_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, 
value_row.sel_down_to_up ] ); + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + // Prove the generated rows self.prove(&[read_row, write_row, value_row]); @@ -404,7 +434,7 @@ impl MemAlignSM { | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | +====+====+====+====+====+----+----+----+ */ - assert!(phase == 0 || phase == 1); // TODO: Debug mode + debug_assert!(phase == 0 || phase == 1); match phase { // If phase == 0, do nothing, just ask for more @@ -412,14 +442,7 @@ impl MemAlignSM { // Otherwise, do the RVR 1 => { - println!("TWO READS"); - println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 2); - drop(num_rows); - println!("INPUT: {:?}", input); - println!("MEM_VALUES: {:?}", mem_values); - println!("PHASE: {:?}", phase); - - assert!(mem_values.len() == 2); // TODO: Debug mode + debug_assert!(mem_values.len() == 2); // Unaligned memory op information thrown into the bus let step = input.step; @@ -476,10 +499,6 @@ impl MemAlignSM { ..Default::default() }; - println!("VALUE_FIRST_READ: {:?}", value_first_read.to_le_bytes()); - println!("VALUE: {:?}", value.to_le_bytes()); - println!("VALUE_SECOND_READ: {:?}", value_second_read.to_le_bytes()); - for i in 0..CHUNK_NUM { first_read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); @@ -500,24 +519,46 @@ impl MemAlignSM { } } - println!( - "FLAGS: {:?}\n", + #[rustfmt::skip] + debug_info!( + "\nTwo Words Read\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Mem Values: {:?}\n\ + Phase: {:?}\n\ + Value First Read: {:?}\n\ + Value: {:?}\n\ + Value Second Read: {:?}\n\ + Flags First Read: {:?}\n\ + Flags Value: {:?}\n\ + Flags Second Read: {:?}", + [*num_rows, *num_rows + 2], + input, + mem_values, + phase, + value_first_read.to_le_bytes(), + value.to_le_bytes(), + value_second_read.to_le_bytes(), [ - value_row.sel[0], - value_row.sel[1], - value_row.sel[2], - value_row.sel[3], - value_row.sel[4], - value_row.sel[5], - value_row.sel[6], - value_row.sel[7], - F::from_bool(false), - value_row.reset, - 
value_row.sel_up_to_down, - value_row.sel_down_to_up + first_read_row.sel[0], first_read_row.sel[1], first_read_row.sel[2], first_read_row.sel[3], + first_read_row.sel[4], first_read_row.sel[5], first_read_row.sel[6], first_read_row.sel[7], + first_read_row.wr, first_read_row.reset, first_read_row.sel_up_to_down, first_read_row.sel_down_to_up + ], + [ + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ], + [ + second_read_row.sel[0], second_read_row.sel[1], second_read_row.sel[2], second_read_row.sel[3], + second_read_row.sel[4], second_read_row.sel[5], second_read_row.sel[6], second_read_row.sel[7], + second_read_row.wr, second_read_row.reset, second_read_row.sel_up_to_down, second_read_row.sel_down_to_up ] ); + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + // Prove the generated rows self.prove(&[first_read_row, value_row, second_read_row]); @@ -548,12 +589,12 @@ impl MemAlignSM { | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | +====+====+----+----+----+----+----+----+ */ - assert!(phase == 0 || phase == 1); // TODO: Debug mode + debug_assert!(phase == 0 || phase == 1); match phase { // If phase == 0, compute the resulting write value and ask for more 0 => { - // assert!(mem_values.len() == 1); // TODO: Debug mode + debug_assert!(mem_values.len() == 1); // Unaligned memory op information thrown into the bus let value = input.value; @@ -587,35 +628,7 @@ impl MemAlignSM { } // Otherwise, do the RWVRW 1 => { - /* RWVWR with offset=6, width=4 - +----+----+----+----+----+----+====+====+ - | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | - +----+----+----+----+----+----+====+====+ - ⇓ - +----+----+----+----+----+----+====+====+ - | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | - +----+----+----+----+----+----+====+====+ - ⇓ - +====+====+----+----+----+----+====+====+ - | V2 | V3 | V4 | V5 | V6 | V7 
| V0 | V1 | - +====+====+----+----+----+----+====+====+ - ⇓ - +====+====+----+----+----+----+----+----+ - | W0 | W1 | W2 | W3 | W4 | W5 | W6 | W7 | - +====+====+----+----+----+----+----+----+ - ⇓ - +====+====+----+----+----+----+----+----+ - | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | - +====+====+----+----+----+----+----+----+ - */ - println!("TWO WRITES"); - println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 4); - drop(num_rows); - println!("INPUT: {:?}", input); - println!("MEM_VALUES: {:?}", mem_values); - println!("PHASE: {:?}", phase); - - assert!(mem_values.len() == 2); // TODO: Debug mode + debug_assert!(mem_values.len() == 2); // Unaligned memory op information thrown into the bus let step = input.step; @@ -734,11 +747,6 @@ impl MemAlignSM { ..Default::default() }; - println!("VALUE_FIRST_READ: {:?}", value_first_read.to_le_bytes()); - println!("VALUE_FIRST_WRITE: {:?}", value_first_write.to_le_bytes()); - println!("VALUE: {:?}", value.to_le_bytes()); - println!("VALUE_SECOND_WRITE: {:?}", value_second_write.to_le_bytes()); - println!("VALUE_SECOND_READ: {:?}", value_second_read.to_le_bytes()); for i in 0..CHUNK_NUM { first_read_row.reg[i] = F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); @@ -782,92 +790,62 @@ impl MemAlignSM { } } - println!( - "FLAGS FIRST READ: {:?}", + #[rustfmt::skip] + debug_info!( + "\nTwo Words Write\n\ + Num Rows: {:?}\n\ + Input: {:?}\n\ + Mem Values: {:?}\n\ + Phase: {:?}\n\ + Value First Read: {:?}\n\ + Value First Write: {:?}\n\ + Value: {:?}\n\ + Value Second Read: {:?}\n\ + Value Second Write: {:?}\n\ + Flags First Read: {:?}\n\ + Flags First Write: {:?}\n\ + Flags Value: {:?}\n\ + Flags Second Write: {:?}\n\ + Flags Second Read: {:?}", + [*num_rows, *num_rows + 4], + input, + mem_values, + phase, + value_first_read.to_le_bytes(), + value_first_write.to_le_bytes(), + value.to_le_bytes(), + value_second_write.to_le_bytes(), + value_second_read.to_le_bytes(), [ - first_read_row.sel[0], - first_read_row.sel[1], - 
first_read_row.sel[2], - first_read_row.sel[3], - first_read_row.sel[4], - first_read_row.sel[5], - first_read_row.sel[6], - first_read_row.sel[7], - F::from_bool(true), - first_read_row.reset, - first_read_row.sel_up_to_down, - first_read_row.sel_down_to_up - ] - ); - println!( - "FLAGS FIRST WRITE: {:?}", + first_read_row.sel[0], first_read_row.sel[1], first_read_row.sel[2], first_read_row.sel[3], + first_read_row.sel[4], first_read_row.sel[5], first_read_row.sel[6], first_read_row.sel[7], + first_read_row.wr, first_read_row.reset, first_read_row.sel_up_to_down, first_read_row.sel_down_to_up + ], [ - first_write_row.sel[0], - first_write_row.sel[1], - first_write_row.sel[2], - first_write_row.sel[3], - first_write_row.sel[4], - first_write_row.sel[5], - first_write_row.sel[6], - first_write_row.sel[7], - F::from_bool(false), - first_write_row.reset, - first_write_row.sel_up_to_down, - first_write_row.sel_down_to_up - ] - ); - println!( - "FLAGS VALUE: {:?}", + first_write_row.sel[0], first_write_row.sel[1], first_write_row.sel[2], first_write_row.sel[3], + first_write_row.sel[4], first_write_row.sel[5], first_write_row.sel[6], first_write_row.sel[7], + first_write_row.wr, first_write_row.reset, first_write_row.sel_up_to_down, first_write_row.sel_down_to_up + ], [ - value_row.sel[0], - value_row.sel[1], - value_row.sel[2], - value_row.sel[3], - value_row.sel[4], - value_row.sel[5], - value_row.sel[6], - value_row.sel[7], - F::from_bool(false), - value_row.reset, - value_row.sel_up_to_down, - value_row.sel_down_to_up - ] - ); - println!( - "FLAGS SECOND WRITE: {:?}", + value_row.sel[0], value_row.sel[1], value_row.sel[2], value_row.sel[3], + value_row.sel[4], value_row.sel[5], value_row.sel[6], value_row.sel[7], + value_row.wr, value_row.reset, value_row.sel_up_to_down, value_row.sel_down_to_up + ], [ - second_write_row.sel[0], - second_write_row.sel[1], - second_write_row.sel[2], - second_write_row.sel[3], - second_write_row.sel[4], - second_write_row.sel[5], - 
second_write_row.sel[6], - second_write_row.sel[7], - F::from_bool(false), - second_write_row.reset, - second_write_row.sel_up_to_down, - second_write_row.sel_down_to_up - ] - ); - println!( - "FLAGS SECOND READ: {:?}\n", + second_write_row.sel[0], second_write_row.sel[1], second_write_row.sel[2], second_write_row.sel[3], + second_write_row.sel[4], second_write_row.sel[5], second_write_row.sel[6], second_write_row.sel[7], + second_write_row.wr, second_write_row.reset, second_write_row.sel_up_to_down, second_write_row.sel_down_to_up + ], [ - second_read_row.sel[0], - second_read_row.sel[1], - second_read_row.sel[2], - second_read_row.sel[3], - second_read_row.sel[4], - second_read_row.sel[5], - second_read_row.sel[6], - second_read_row.sel[7], - F::from_bool(false), - second_read_row.reset, - second_read_row.sel_up_to_down, - second_read_row.sel_down_to_up + second_read_row.sel[0], second_read_row.sel[1], second_read_row.sel[2], second_read_row.sel[3], + second_read_row.sel[4], second_read_row.sel[5], second_read_row.sel[6], second_read_row.sel[7], + second_read_row.wr, second_read_row.reset, second_read_row.sel_up_to_down, second_read_row.sel_down_to_up ] ); + #[cfg(feature = "debug_mem_align")] + drop(num_rows); + // Prove the generated rows self.prove(&[ first_read_row, @@ -898,8 +876,12 @@ impl MemAlignSM { if let Ok(mut rows) = self.rows.lock() { rows.extend_from_slice(computed_rows); - let mut num_rows = self.num_computed_rows.lock().unwrap(); // TODO: DEBUG!!! 
- *num_rows += computed_rows.len(); + #[cfg(feature = "debug_mem_align")] + { + let mut num_rows = self.num_computed_rows.lock().unwrap(); + *num_rows += computed_rows.len(); + drop(num_rows); + } let pctx = self.wcm.get_pctx(); let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]); @@ -924,7 +906,7 @@ impl MemAlignSM { let rows_len = rows.len(); // You cannot feed to the AIR more rows than it has - assert!(rows_len <= air_mem_align_rows); + debug_assert!(rows_len <= air_mem_align_rows); // Get the execution and setup context let ectx = wcm.get_ectx(); @@ -939,7 +921,7 @@ impl MemAlignSM { MemAlignTrace::::map_buffer(&mut prover_buffer, air_mem_align_rows, offset as usize) .unwrap(); - let mut reg_range_check: HashMap = HashMap::new(); + let mut reg_range_check: HashMap = HashMap::new(); // TODO: HashMap to Vec of size 256 // Add the input rows to the trace for (i, &row) in rows.iter().enumerate() { From cec3d41e717d1b5405a180f454494fdf56a18914 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Fri, 22 Nov 2024 12:13:00 +0000 Subject: [PATCH 40/44] Removing hashmaps and cleaning up stuff a little bit --- state-machines/mem/src/mem_align_rom_sm.rs | 145 ++++++++------------- state-machines/mem/src/mem_align_sm.rs | 30 +++-- 2 files changed, 71 insertions(+), 104 deletions(-) diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 1953c016..61170135 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -22,7 +22,6 @@ pub enum MemOp { TwoWrites, } -const CHUNK_NUM: usize = 8; const OP_SIZES: [u64; 4] = [2, 3, 3, 5]; const ONE_WORD_COMBINATIONS: u64 = 20; // (0..4,[1,2,4]), (5,6,[1,2]), (7,[1]) -> 5*3 + 2*2 + 1*1 = 20 const TWO_WORD_COMBINATIONS: u64 = 11; // (1..4,[8]), (5,6,[4,8]), (7,[2,4,8]) -> 4*1 + 2*2 + 1*3 = 11 @@ -79,93 +78,64 @@ impl MemAlignRomSM { } pub fn calculate_next_pc(&self, opcode: MemOp, offset: usize, 
width: usize) -> u64 { - let row_idxs = Self::get_row_idxs(&self, opcode, offset, width); + // Get the table offset + let (table_offset, one_word) = match opcode { + MemOp::OneRead => { + (1, true) + } + + MemOp::OneWrite => { + (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true) + } - // Update the multiplicity - let ones: Vec = vec![1; row_idxs.len()]; - self.update_multiplicity_by_row_idx(&row_idxs, &ones); + MemOp::TwoReads => { + (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], false) + } + + MemOp::TwoWrites => { + (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1] + TWO_WORD_COMBINATIONS * OP_SIZES[2], false) + } + }; + + // Get the first row index + let first_row_idx = Self::get_first_row_idx(opcode, offset, width, table_offset, one_word); + + // Based on the program size, return the row indices + let opcode_idx = opcode as usize; + let op_size = OP_SIZES[opcode_idx]; + for i in 0..op_size { + let row_idx = first_row_idx + i; + // Check whether the row index is within the bounds + debug_assert!(row_idx < self.num_rows as u64); + // Update the multiplicity + self.update_multiplicity_by_row_idx(row_idx, 1); + } - row_idxs[0] + first_row_idx } - fn get_row_idxs(&self, opcode: MemOp, offset: usize, width: usize) -> Vec { + fn get_first_row_idx(opcode: MemOp, offset: usize, width: usize, table_offset: u64, one_word: bool) -> u64 { let opcode_idx = opcode as usize; let op_size = OP_SIZES[opcode_idx]; - match opcode { - MemOp::OneRead | MemOp::OneWrite => { - // Sanity check - assert!(offset + width <= CHUNK_NUM); - - // Go to the actual operation - let mut value_row = match opcode { - MemOp::OneRead => 1, - MemOp::OneWrite => 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], - _ => unreachable!(), - }; - - // Go to the actual offset - for i in 0..offset { - let possible_widths = Self::calculate_possible_widths(true, i); - value_row += op_size * possible_widths.len() as u64; - } - - // Go to the right width - let 
width_idx = Self::calculate_possible_widths(true, offset) - .iter() - .position(|&w| w == width) - .expect("Invalid width"); - value_row += op_size * width_idx as u64; - - assert!(value_row < self.num_rows as u64); - - match opcode { - MemOp::OneRead => vec![value_row, value_row + 1], - MemOp::OneWrite => vec![value_row, value_row + 1, value_row + 2], - _ => unreachable!(), - } - } - MemOp::TwoReads | MemOp::TwoWrites => { - // Sanity check - assert!(offset + width > CHUNK_NUM); - - // Go to the actual operation - let mut value_row = match opcode { - MemOp::TwoReads => { - 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + - ONE_WORD_COMBINATIONS * OP_SIZES[1] - } - MemOp::TwoWrites => { - 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + - ONE_WORD_COMBINATIONS * OP_SIZES[1] + - TWO_WORD_COMBINATIONS * OP_SIZES[2] - } - _ => unreachable!(), - }; - - // Go to the actual offset - for i in 1..offset { - let possible_widths = Self::calculate_possible_widths(false, i); - value_row += op_size * possible_widths.len() as u64; - } - - assert!(value_row < self.num_rows as u64); - - // Go to the right width - let width_idx = Self::calculate_possible_widths(false, offset) - .iter() - .position(|&w| w == width) - .expect("Invalid width"); - value_row += op_size * width_idx as u64; - - match opcode { - MemOp::TwoReads => vec![value_row, value_row + 1, value_row + 2], - MemOp::TwoWrites => { - vec![value_row, value_row + 1, value_row + 2, value_row + 3, value_row + 4] - } - _ => unreachable!(), - } - } + + // Go to the actual operation + let mut first_row_idx = table_offset; + + // Go to the actual offset + let first_valid_offset = if one_word { 0 } else { 1 }; + for i in first_valid_offset..offset { + let possible_widths = Self::calculate_possible_widths(one_word, i); + first_row_idx += op_size * possible_widths.len() as u64; } + + // Go to the right width + let width_idx = Self::calculate_possible_widths(one_word, offset) + .iter() + .position(|&w| w == width) + .expect("Invalid width"); + 
first_row_idx += op_size * width_idx as u64; + + first_row_idx } fn calculate_possible_widths(one_word: bool, offset: usize) -> Vec { @@ -189,19 +159,12 @@ impl MemAlignRomSM { pub fn update_padding_row(&self, padding_len: u64) { // Update entry at the padding row (pos = 0) with the given padding length - self.update_multiplicity_by_row_idx(&[0], &[padding_len]); + self.update_multiplicity_by_row_idx(0, padding_len); } - pub fn update_multiplicity_by_row_idx(&self, row_idxs: &[u64], muls: &[u64]) { - if row_idxs.len() != muls.len() { - panic!("The number of indices and multiplicities must be the same"); - } - + pub fn update_multiplicity_by_row_idx(&self, row_idx: u64, mul: u64) { let mut multiplicity = self.multiplicity.lock().unwrap(); - - for (i, &idx) in row_idxs.iter().enumerate() { - *multiplicity.entry(idx).or_insert(0) += muls[i]; - } + *multiplicity.entry(row_idx).or_insert(0) += mul; } pub fn create_air_instance(&self) { diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 87f5b9c4..b0f826a1 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -1,14 +1,12 @@ use core::panic; -use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, - }, +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, }; use log::info; use num_bigint::BigInt; +use num_traits::ToPrimitive; use p3_field::PrimeField; use pil_std_lib::Std; use proofman::{WitnessComponent, WitnessManager}; @@ -921,7 +919,7 @@ impl MemAlignSM { MemAlignTrace::::map_buffer(&mut prover_buffer, air_mem_align_rows, offset as usize) .unwrap(); - let mut reg_range_check: HashMap = HashMap::new(); // TODO: HashMap to Vec of size 256 + let mut reg_range_check: Vec = vec![0; 1 << CHUNK_BITS]; // Add the input rows to the trace for (i, &row) in rows.iter().enumerate() { @@ -930,7 +928,9 @@ impl MemAlignSM { // Store the value of all reg columns so that they can be range checked 
for j in 0..CHUNK_NUM { - *reg_range_check.entry(row.reg[j]).or_insert(0) += 1; + let element = + row.reg[j].as_canonical_biguint().to_usize().expect("Cannot convert to usize"); + reg_range_check[element] += 1; } } @@ -943,16 +943,20 @@ impl MemAlignSM { trace_buffer[i] = padding_row; } - // Store the value of all reg columns so that they can be range checked - for j in 0..CHUNK_NUM { - *reg_range_check.entry(padding_row.reg[j]).or_insert(0) += padding_size as u64; + // Store the value of all padding reg columns so that they can be range checked + for _ in 0..CHUNK_NUM { + reg_range_check[0] += padding_size as u64; } // Perform the range checks let std = self.std.clone(); let range_id = std.get_range(BigInt::from(0), BigInt::from(CHUNK_BITS_MASK), None); - for (&value, &multiplicity) in reg_range_check.iter() { - std.range_check(value, F::from_canonical_u64(multiplicity), range_id); + for (value, &multiplicity) in reg_range_check.iter().enumerate() { + std.range_check( + F::from_canonical_usize(value), + F::from_canonical_u64(multiplicity), + range_id, + ); } // Compute the padding multiplicity From 5b37bc298b88641c99ec6d2940df2df5d28e0259 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Fri, 22 Nov 2024 19:15:56 +0000 Subject: [PATCH 41/44] WIP mem - mem_align integration --- Cargo.lock | 9 +- core/src/zisk_required_operation.rs | 1 + emulator/src/emu.rs | 5 + pil/src/pil_helpers/traces.rs | 2 + state-machines/mem/src/mem_align_sm.rs | 66 ++++++------ state-machines/mem/src/mem_helpers.rs | 77 +++++++++++++- state-machines/mem/src/mem_proxy_engine.rs | 112 +++++++++++---------- state-machines/mem/src/mem_sm.rs | 33 ++---- state-machines/mem/src/mem_unmapped.rs | 6 +- 9 files changed, 183 insertions(+), 128 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6fb276c5..99c8b119 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1593,9 +1593,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.91" +version = "1.0.92" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "307e3004becf10f5a6e0d59d20f3cd28231b0e0827a96cd3e0ce6d14bc1e4bb3" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -1627,6 +1627,7 @@ dependencies = [ "env_logger", "log", "p3-field", + "p3-goldilocks", "pilout", "proofman-macros", "proofman-starks-lib-c", @@ -2891,9 +2892,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.6" +version = "0.26.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" dependencies = [ "rustls-pki-types", ] diff --git a/core/src/zisk_required_operation.rs b/core/src/zisk_required_operation.rs index d0d38abf..04410a0f 100644 --- a/core/src/zisk_required_operation.rs +++ b/core/src/zisk_required_operation.rs @@ -13,6 +13,7 @@ pub struct ZiskRequiredMemory { pub address: u32, pub is_write: bool, pub width: u8, + pub step_offset: u8, pub step: u64, pub value: u64, } diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index c7134b1e..0752d793 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -110,6 +110,7 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, + step_offset: 0, is_write: false, address: addr as u32, width: 8, @@ -184,6 +185,7 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, + step_offset: 1, is_write: false, address: addr as u32, width: 8, @@ -203,6 +205,7 @@ impl<'a> Emu<'a> { self.ctx.inst_ctx.b = self.ctx.inst_ctx.mem.read(addr, instruction.ind_width); let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, + step_offset: 1, is_write: false, address: addr as u32, width: instruction.ind_width as u8, @@ -283,6 +286,7 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { 
step: self.ctx.inst_ctx.step, + step_offset: 2, is_write: true, address: addr as u32, width: 8, @@ -305,6 +309,7 @@ impl<'a> Emu<'a> { let required_memory = ZiskRequiredMemory { step: self.ctx.inst_ctx.step, + step_offset: 2, is_write: true, address: addr as u32, width: instruction.ind_width as u8, diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 509b00b8..79bef379 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -58,3 +58,5 @@ trace!(SpecifiedRangesRow, SpecifiedRangesTrace { trace!(U8AirRow, U8AirTrace { mul: F, }); + + diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 1d125b3e..d2ea9f18 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -15,18 +15,17 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use sm_common::create_prover_buffer; -use zisk_core::ZiskRequiredMemory; use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; -use crate::{MemAlignRomSM, MemOp}; +use crate::{MemAlignInput, MemAlignRomSM, MemOp}; const CHUNK_NUM: usize = 8; -const CHUNK_NUM_U64: u64 = CHUNK_NUM as u64; const CHUNK_BITS: usize = 8; -const OFFSET_MASK: u64 = CHUNK_NUM_U64 - 1; +const OFFSET_MASK: u64 = 0x07; +const OFFSET_BITS: u64 = 3; const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; -const ALLOWED_WIDTHS: [u64; 4] = [1, 2, 4, 8]; +const ALLOWED_WIDTHS: [u8; 4] = [1, 2, 4, 8]; const DEFAULT_OFFSET: u64 = 0; const DEFAULT_WIDTH: u64 = 8; @@ -111,25 +110,19 @@ impl MemAlignSM { } #[inline(always)] - pub fn get_mem_op( - &self, - input: &ZiskRequiredMemory, - mem_values: [u64; 2], - phase: usize, - ) -> MemAlignResponse { + pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse { // Sanity check // assert!(mem_values.len() == phase + 1); // TODO: Debug mode - let addr = input.address; - let width = input.width; - let width = if 
ALLOWED_WIDTHS.contains(&width) { - width as usize + let addr = input.address as u64; + let width = if ALLOWED_WIDTHS.contains(&input.width) { + input.width as usize } else { - panic!("Width={} is not allowed. Allowed widths are {:?}", width, ALLOWED_WIDTHS); + panic!("Width={} is not allowed. Allowed widths are {:?}", input.width, ALLOWED_WIDTHS); }; // Compute the offset - let offset = addr as u64 & OFFSET_MASK; + let offset = addr & OFFSET_MASK; let offset = if offset <= usize::MAX as u64 { offset as usize } else { @@ -144,7 +137,7 @@ impl MemAlignSM { println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 1); drop(num_rows); println!("INPUT: {:?}", input); - println!("MEM_VALUES: {:?}", mem_values); + println!("MEM_VALUES: {:?}", input.mem_values); println!("PHASE: {:?}", phase); /* RV with offset=2, width=4 @@ -163,10 +156,10 @@ impl MemAlignSM { let value = input.value; // Get the aligned address - let addr_read = addr >> CHUNK_BITS; + let addr_read = addr >> OFFSET_BITS; // Get the aligned value - let value_read = mem_values[phase]; + let value_read = input.mem_values[phase]; // Get the next pc let next_pc = @@ -249,6 +242,8 @@ impl MemAlignSM { // Prove the generated rows self.prove(&[read_row, value_row]); + println!("MEM_ALIGN_PRE_ROW(R): {:?}", read_row); + println!("MEM_ALIGN_PRE_ROW(V): {:?}", value_row); MemAlignResponse { more_address: false, step, value: None } } @@ -257,7 +252,6 @@ impl MemAlignSM { println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 2); drop(num_rows); println!("INPUT: {:?}", input); - println!("MEM_VALUES: {:?}", mem_values); println!("PHASE: {:?}", phase); /* RWV with offset=3, width=4 @@ -280,10 +274,10 @@ impl MemAlignSM { let value = input.value; // Get the aligned address - let addr_read = addr >> CHUNK_BITS; + let addr_read = addr >> OFFSET_BITS; // Get the aligned value - let value_read = mem_values[phase]; + let value_read = input.mem_values[phase]; // Get the next pc let next_pc = @@ -416,10 +410,9 @@ impl MemAlignSM { 
println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 2); drop(num_rows); println!("INPUT: {:?}", input); - println!("MEM_VALUES: {:?}", mem_values); println!("PHASE: {:?}", phase); - assert!(mem_values.len() == 2); // TODO: Debug mode + assert!(input.mem_values.len() == 2); // TODO: Debug mode // Unaligned memory op information thrown into the bus let step = input.step; @@ -429,12 +422,12 @@ impl MemAlignSM { let rem_bytes = (offset + width) % CHUNK_NUM; // Get the aligned address - let addr_first_read = addr >> CHUNK_BITS; - let addr_second_read = addr >> CHUNK_BITS + CHUNK_BITS; + let addr_first_read = addr >> OFFSET_BITS; + let addr_second_read = addr_first_read + 1; // Get the aligned value - let value_first_read = mem_values[0]; - let value_second_read = mem_values[1]; + let value_first_read = input.mem_values[0]; + let value_second_read = input.mem_values[1]; // Get the next pc let next_pc = @@ -560,7 +553,7 @@ impl MemAlignSM { let step = input.step; // Get the aligned value - let value_first_read = mem_values[0]; + let value_first_read = input.mem_values[0]; // Compute the write value let value_first_write = { @@ -612,10 +605,9 @@ impl MemAlignSM { println!("NUM_ROWS: [{},{}]", num_rows, *num_rows + 4); drop(num_rows); println!("INPUT: {:?}", input); - println!("MEM_VALUES: {:?}", mem_values); println!("PHASE: {:?}", phase); - assert!(mem_values.len() == 2); // TODO: Debug mode + assert!(input.mem_values.len() == 2); // TODO: Debug mode // Unaligned memory op information thrown into the bus let step = input.step; @@ -625,11 +617,11 @@ impl MemAlignSM { let rem_bytes = (offset + width) % CHUNK_NUM; // Get the aligned address - let addr_first_read_write = addr >> CHUNK_BITS; - let addr_second_read_write = addr >> CHUNK_BITS + CHUNK_BITS; + let addr_first_read_write = (addr >> OFFSET_BITS) as u64; + let addr_second_read_write = addr_first_read_write + 1; // Get the first aligned value - let value_first_read = mem_values[0]; + let value_first_read = 
input.mem_values[0]; // Recompute the first write value let value_first_write = { @@ -649,7 +641,7 @@ impl MemAlignSM { }; // Get the second aligned value - let value_second_read = mem_values[1]; + let value_second_read = input.mem_values[1]; // Compute the second write value let value_second_write = { @@ -945,7 +937,7 @@ impl MemAlignSM { for (i, &row) in rows.iter().enumerate() { // Store the entire row trace_buffer[i] = row; - + println!("MEM_ALIGN_ROW: {:?}", row); // Store the value of all reg columns so that they can be range checked for j in 0..CHUNK_NUM { *reg_range_check.entry(row.reg[j]).or_insert(0) += 1; diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs index ac4ca198..3f4db4d9 100644 --- a/state-machines/mem/src/mem_helpers.rs +++ b/state-machines/mem/src/mem_helpers.rs @@ -1,4 +1,4 @@ -use crate::MemAlignResponse; +use crate::{MemAlignResponse, MEM_BYTES}; use std::fmt; use zisk_core::ZiskRequiredMemory; @@ -12,6 +12,73 @@ fn format_u64_hex(value: u64) -> String { .join("_") } +const MAX_MEM_STEP_OFFSET: u64 = 2; +const MAX_MEM_OPS_PER_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * 2; + +#[derive(Debug, Clone)] +pub struct MemAlignInput { + pub address: u32, + pub is_write: bool, + pub width: u8, + pub step: u64, + pub value: u64, + pub mem_values: [u64; 2], +} + +#[derive(Debug, Clone)] +pub struct MemInput { + pub address: u32, + pub is_write: bool, + pub step: u64, + pub value: u64, +} + +impl MemInput { + pub fn new(address: u32, is_write: bool, step: u64, value: u64) -> Self { + MemInput { address, is_write, step, value } + } + pub fn from(mem_op: &ZiskRequiredMemory) -> Self { + // debug_assert_eq!(mem_op.width, MEM_BYTES as u8); + MemInput { + address: mem_op.address, + is_write: mem_op.is_write, + step: MemHelpers::main_step_to_address_step(mem_op.step, mem_op.step_offset), + value: mem_op.value, + } + } +} + +impl MemAlignInput { + pub fn new( + address: u32, + is_write: bool, + width: u8, + step: u64, + 
value: u64, + mem_values: [u64; 2], + ) -> Self { + MemAlignInput { address, is_write, width, step, value, mem_values } + } + pub fn from(mem_op: &MemInput, width: u8, mem_values: &[u64; 2]) -> Self { + MemAlignInput { + address: mem_op.address, + is_write: mem_op.is_write, + step: mem_op.step, + width, + value: mem_op.value, + mem_values: [mem_values[0], mem_values[1]], + } + } +} + +pub struct MemHelpers {} + +impl MemHelpers { + pub fn main_step_to_address_step(step: u64, step_offset: u8) -> u64 { + 1 + MAX_MEM_OPS_PER_MAIN_STEP * step + 2 * step_offset as u64 + } +} + impl fmt::Debug for MemAlignResponse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( @@ -41,8 +108,8 @@ pub fn mem_align_call( more_address: double_address, step: mem_op.step + 1, value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) - | ((mem_op.value & mask) << offset), + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 ^ (mask << offset))) | + ((mem_op.value & mask) << offset), ), } } else { @@ -50,8 +117,8 @@ pub fn mem_align_call( more_address: false, step: mem_op.step + 1, value: Some( - (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width as u32 - 64))) - | ((mem_op.value & mask) >> (128 - (offset + width as u32))), + (mem_value & (0xFFFF_FFFF_FFFF_FFFFu64 << (offset + width as u32 - 64))) | + ((mem_op.value & mask) >> (128 - (offset + width as u32))), ), } } diff --git a/state-machines/mem/src/mem_proxy_engine.rs b/state-machines/mem/src/mem_proxy_engine.rs index bfa8f8f0..50315cdd 100644 --- a/state-machines/mem/src/mem_proxy_engine.rs +++ b/state-machines/mem/src/mem_proxy_engine.rs @@ -1,8 +1,8 @@ use std::{collections::VecDeque, sync::Arc}; use crate::{ - mem_align_call, MemAlignResponse, MemAlignSM, MemUnmapped, MAX_MEM_ADDR, - MAX_MEM_OPS_PER_MAIN_STEP, MAX_MEM_STEP, MEM_ADDR_MASK, MEM_BYTES, + MemAlignInput, MemAlignResponse, MemAlignSM, MemHelpers, MemInput, MemUnmapped, MAX_MEM_ADDR, + MAX_MEM_OPS_PER_MAIN_STEP, MEM_ADDR_MASK, MEM_BYTES, }; use 
log::info; use p3_field::PrimeField; @@ -19,30 +19,24 @@ macro_rules! debug_info { } pub trait MemModule: Send + Sync { - fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]); + fn send_inputs(&self, mem_op: &[MemInput]); fn get_addr_ranges(&self) -> Vec<(u32, u32)>; fn get_flush_input_size(&self) -> u32; } trait MemAlignSm { - fn get_mem_op( - &self, - mem_op: &ZiskRequiredMemory, - mem_values: [u64; 2], - phase: u8, - ) -> MemAlignResponse; + fn get_mem_op(&self, mem_op: &MemInput, phase: u8) -> MemAlignResponse; } struct MemModuleData { pub name: String, - pub inputs: Vec, + pub inputs: Vec, pub flush_input_size: u32, } struct MemAlignOperation { addr: u32, - mem_op: ZiskRequiredMemory, - mem_value: [u64; 2], + input: MemAlignInput, } #[derive(Debug)] @@ -143,18 +137,18 @@ impl MemProxyEngine { // be processed before current mem_op, in this case process all "previous" and after process // the current mem_op. - for mem_op in mem_operations.iter_mut() { + for mem_extern_op in mem_operations.iter_mut() { // self.log_mem_op(mem_op); - + let mem_op = MemInput::from(mem_extern_op); let aligned_mem_addr = Self::to_aligned_addr(mem_op.address); let mem_step = mem_op.step; // Check if there are open mem align operations to be processed in this moment, with // address (or step) less than the aligned of current mem_op. 
- self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step); + self.process_all_previous_open_mem_align_ops(aligned_mem_addr, mem_step, mem_align_sm); // check if we are at end of loop - if self.check_if_end_of_memory_mark(mem_op) { + if self.check_if_end_of_memory_mark(&mem_op) { break; } @@ -162,30 +156,37 @@ impl MemProxyEngine { let mem_value = self.get_mem_value(aligned_mem_addr); // all open mem align operations are processed, check if new mem operation is aligned - if !Self::is_aligned(&mem_op) { + if !Self::is_aligned(&mem_extern_op) { // In this point found non-aligned memory access, phase-0 - let mem_align_op = mem_align_sm.get_mem_op(mem_op, [mem_value, 0], 0); + let mem_align_input = + MemAlignInput::from(&mem_op, mem_extern_op.width, &[mem_value, 0]); + let mem_align_response = mem_align_sm.get_mem_op(&mem_align_input, 0); // if operation applies to two consecutive memory addresses, add the second part // is enqueued to be processed in future when processing next address on phase-1 - if mem_align_op.more_address { - self.push_open_mem_align_op(aligned_mem_addr, mem_value, mem_op); + if mem_align_response.more_address { + self.push_open_mem_align_op(aligned_mem_addr, &mem_align_input); } self.push_mem_align_response_ops( aligned_mem_addr, mem_value, - mem_op, - &mem_align_op, + &mem_align_input, + &mem_align_response, ); } else { - self.push_mem_op(mem_op); + self.push_mem_op(&mem_op); } } self.finish_prove(); Ok(()) } - fn process_all_previous_open_mem_align_ops(&mut self, mem_addr: u32, mem_step: u64) { + fn process_all_previous_open_mem_align_ops( + &mut self, + mem_addr: u32, + mem_step: u64, + mem_align_sm: &MemAlignSM, + ) { // Two possible situations to process open mem align operations: // // 1) the address of open operation is less than the aligned address. @@ -193,12 +194,13 @@ impl MemProxyEngine { // open operation is less than the step of the current operation. 
while self.has_open_mem_align_lt(mem_addr, mem_step) { - let open_op = self.open_mem_align_ops.pop_front().unwrap(); + let mut open_op = self.open_mem_align_ops.pop_front().unwrap(); let mem_value = if open_op.addr == self.last_addr { self.last_addr_value } else { 0 }; // call to mem_align to get information of the aligned memory access needed // to prove the unaligned open operation. - let mem_align_op = mem_align_call(&open_op.mem_op, [mem_value, 0], 1); + open_op.input.mem_values[1] = mem_value; + let mem_align_op = mem_align_sm.get_mem_op(&open_op.input, 1); // remove element from top of queue, because we are on last phase, phase 1. self.open_mem_align_ops.pop_front(); @@ -208,7 +210,7 @@ impl MemProxyEngine { self.push_mem_align_response_ops( open_op.addr, mem_value, - &open_op.mem_op, + &open_op.input, &mem_align_op, ); } @@ -224,19 +226,13 @@ impl MemProxyEngine { let aligned_mem_address = (mem_op.address as u64 & MEM_ADDR_MASK) as u32; aligned_mem_address == mem_op.address && mem_op.width == MEM_BYTES as u8 } - fn push_mem_op(&mut self, mem_op: &ZiskRequiredMemory) { + fn push_mem_op(&mut self, mem_op: &MemInput) { self.push_aligned_op(mem_op.is_write, mem_op.address, mem_op.value, mem_op.step); } fn push_aligned_op(&mut self, is_write: bool, addr: u32, value: u64, step: u64) { self.update_last_addr(addr, value); - let mem_op = ZiskRequiredMemory { - step, - is_write, - address: addr as u32, - width: MEM_BYTES as u8, - value, - }; + let mem_op = MemInput { step, is_write, address: addr as u32, value }; debug_info!( "route ==> {}[{:X}] {} {} #{}", self.current_module, @@ -270,17 +266,21 @@ impl MemProxyEngine { &mut self, mem_addr: u32, mem_value: u64, - mem_op: &ZiskRequiredMemory, - mem_align_op: &MemAlignResponse, + mem_align_input: &MemAlignInput, + mem_align_resp: &MemAlignResponse, ) { - self.push_aligned_read(mem_addr, mem_value, mem_align_op.step); - if mem_op.is_write { - let mem_value = mem_align_op.value.expect("value returned by mem_align"); - 
self.push_aligned_write(mem_addr, mem_value, mem_align_op.step + 1); + self.push_aligned_read(mem_addr, mem_value, mem_align_resp.step); + if mem_align_input.is_write { + // let mem_value = mem_align_resp.value.expect("value returned by mem_align"); + self.push_aligned_write( + mem_addr, + mem_align_resp.value.unwrap(), + mem_align_resp.step + 1, + ); } } - fn create_modules_inputs(&self) -> Vec> { - let mut mem_module_inputs: Vec> = Default::default(); + fn create_modules_inputs(&self) -> Vec> { + let mut mem_module_inputs: Vec> = Default::default(); for _module in self.modules.iter() { mem_module_inputs.push(Vec::new()); } @@ -329,14 +329,15 @@ impl MemProxyEngine { self.open_mem_align_ops.len() > 0 && (self.open_mem_align_ops[0].addr < addr || (self.open_mem_align_ops[0].addr == addr && - self.open_mem_align_ops[0].mem_op.step < step)) + self.open_mem_align_ops[0].input.step < step)) } // method to process open mem align operations, second part of non aligned memory operations // applies to two consecutive memory addresses. 
fn end_of_memory_mark() -> ZiskRequiredMemory { ZiskRequiredMemory { - step: MAX_MEM_STEP, + step: 0, + step_offset: 0, is_write: false, address: MAX_MEM_ADDR as u32, width: MEM_BYTES as u8, @@ -344,8 +345,9 @@ impl MemProxyEngine { } } #[inline(always)] - fn check_if_end_of_memory_mark(&self, mem_op: &ZiskRequiredMemory) -> bool { - if mem_op.step == MAX_MEM_STEP && mem_op.address == MAX_MEM_ADDR as u32 { + fn check_if_end_of_memory_mark(&self, mem_op: &MemInput) -> bool { + // TODO: 0xFFFF_FFFF not valid address + if mem_op.address == MAX_MEM_ADDR as u32 { assert!( self.open_mem_align_ops.len() == 0, "open_mem_align_ops not empty, has {} elements", @@ -366,6 +368,11 @@ impl MemProxyEngine { } fn finish_prove(&self) { for (module_id, module) in self.modules.iter().enumerate() { + debug_info!( + "{}: flush all({}) inputs", + self.modules_data[module_id].name, + self.modules_data[module_id].inputs.len() + ); module.send_inputs(&self.modules_data[module_id].inputs); } } @@ -396,16 +403,11 @@ impl MemProxyEngine { } #[inline(always)] - fn push_open_mem_align_op( - &mut self, - aligned_mem_addr: u32, - mem_value: u64, - mem_op: &ZiskRequiredMemory, - ) { + fn push_open_mem_align_op(&mut self, aligned_mem_addr: u32, input: &MemAlignInput) { + info!("aligned_mem_addr:{:x}", aligned_mem_addr); self.open_mem_align_ops.push_back(MemAlignOperation { addr: aligned_mem_addr + MEM_BYTES as u32, - mem_op: mem_op.clone(), - mem_value: [mem_value, 0], + input: input.clone(), }); } fn log_mem_op(&self, mem_op: &ZiskRequiredMemory) { diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index ce7b9f14..31e45377 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -5,14 +5,13 @@ use std::sync::{ const MEM_INITIAL_ADDRESS: u32 = 0xA0000000; const MEM_FINAL_ADDRESS: u32 = MEM_INITIAL_ADDRESS + 128 * 1024 * 1024; -use crate::MemModule; +use crate::{MemInput, MemModule}; use p3_field::PrimeField; use 
proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use rayon::prelude::*; use sm_common::create_prover_buffer; -use zisk_core::ZiskRequiredMemory; use zisk_pil::{MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct MemSM { @@ -49,7 +48,7 @@ impl MemSM { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 {} } - pub fn prove(&self, mem_accesses: &[ZiskRequiredMemory]) { + pub fn prove(&self, mem_accesses: &[MemInput]) { // Sort the (full) aligned memory accesses let pctx = self.wcm.get_pctx(); @@ -108,11 +107,11 @@ impl MemSM { /// /// # Parameters /// - /// - `mem_inputs`: A slice of all `ZiskRequiredMemory` inputs + /// - `mem_inputs`: A slice of all `MemoryInput` inputs pub fn prove_instance( &self, - mem_ops: &[ZiskRequiredMemory], - mem_first_row: ZiskRequiredMemory, + mem_ops: &[MemInput], + mem_first_row: MemInput, segment_id: usize, is_last_segment: bool, mut prover_buffer: Vec, @@ -165,13 +164,7 @@ impl MemSM { trace[0].sel = F::zero(); trace[0].wr = F::zero(); - let value = match mem_first_row.width { - 1 => mem_first_row.value as u8 as u64, - 2 => mem_first_row.value as u16 as u64, - 4 => mem_first_row.value as u32 as u64, - 8 => mem_first_row.value, - _ => panic!("Invalid width"), - }; + let value = mem_first_row.value; let (low_val, high_val) = self.get_u32_values(value); trace[0].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; trace[0].addr_changes = F::zero(); @@ -186,19 +179,13 @@ impl MemSM { // trace[i].mem_segment = segment_id_field; // trace[i].mem_last_segment = is_last_segment_field; - trace[i].addr = F::from_canonical_u32(mem_op.address); // n-byte address, real address = addr * MEM_BYTES + let mem_addr = mem_op.address >> 3; + trace[i].addr = F::from_canonical_u32(mem_addr); // n-byte address, real address = addr * MEM_BYTES trace[i].step = F::from_canonical_u64(mem_op.step); trace[i].sel = F::one(); trace[i].wr = F::from_bool(mem_op.is_write); - let value = match 
mem_op.width { - 1 => mem_op.value as u8 as u64, - 2 => mem_op.value as u16 as u64, - 4 => mem_op.value as u32 as u64, - 8 => mem_op.value, - _ => panic!("Invalid width"), - }; - let (low_val, high_val) = self.get_u32_values(value); + let (low_val, high_val) = self.get_u32_values(mem_op.value); trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; let addr_changes = trace[i - 1].addr != trace[i].addr; @@ -267,7 +254,7 @@ impl MemSM { } impl MemModule for MemSM { - fn send_inputs(&self, mem_op: &[ZiskRequiredMemory]) { + fn send_inputs(&self, mem_op: &[MemInput]) { self.prove(&mem_op); } fn get_addr_ranges(&self) -> Vec<(u32, u32)> { diff --git a/state-machines/mem/src/mem_unmapped.rs b/state-machines/mem/src/mem_unmapped.rs index 988971d6..f647750a 100644 --- a/state-machines/mem/src/mem_unmapped.rs +++ b/state-machines/mem/src/mem_unmapped.rs @@ -1,10 +1,8 @@ use std::marker::PhantomData; -use crate::MemModule; +use crate::{MemInput, MemModule}; use p3_field::PrimeField; -use zisk_core::ZiskRequiredMemory; - pub struct MemUnmapped { ranges: Vec<(u32, u32)>, __data: PhantomData, @@ -19,7 +17,7 @@ impl MemUnmapped { } } impl MemModule for MemUnmapped { - fn send_inputs(&self, _mem_op: &[ZiskRequiredMemory]) { + fn send_inputs(&self, _mem_op: &[MemInput]) { // panic!("[MemUnmapped] invalid access to addr {:x}", _mem_op[0].addr); } fn get_addr_ranges(&self) -> Vec<(u32, u32)> { From e0ebd21839d558a826b6306054150f4736252606 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Sun, 24 Nov 2024 22:30:01 +0000 Subject: [PATCH 42/44] fix bugs with memory --- emulator/src/emu.rs | 1 + pil/src/pil_helpers/traces.rs | 2 +- state-machines/main/pil/main.pil | 4 +- state-machines/mem/pil/mem_align.pil | 37 ++++++------ state-machines/mem/src/mem_align_sm.rs | 43 ++++++++------ state-machines/mem/src/mem_proxy_engine.rs | 65 ++++++++++++++++++---- 6 files changed, 104 insertions(+), 48 deletions(-) diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs 
index 0752d793..9c82aaff 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -116,6 +116,7 @@ impl<'a> Emu<'a> { width: 8, value: self.ctx.inst_ctx.a, }; + emu_mem.push(required_memory); } SRC_IMM => { diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 79bef379..40dfc7fd 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -16,7 +16,7 @@ trace!(MemRow, MemTrace { }); trace!(MemAlignRow, MemAlignTrace { - addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], sel_prove: F, step: F, + addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, sel_prove: F, }); trace!(MemAlignRomRow, MemAlignRomTrace { diff --git a/state-machines/main/pil/main.pil b/state-machines/main/pil/main.pil index 095a6ac6..7ebcfe3a 100644 --- a/state-machines/main/pil/main.pil +++ b/state-machines/main/pil/main.pil @@ -142,7 +142,7 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope mem_load(sel: sel_mem_b, step: STEP, step_offset: 1, - bytes: ind_width, + bytes: b_src_ind * (ind_width - 8) + 8, addr: addr1, value: b); @@ -155,7 +155,7 @@ airtemplate Main(int N = 2**21, int RC = 2, int stack_enabled = 0, const int ope mem_store(sel: store_mem + store_ind, step: STEP, step_offset: 2, - bytes: ind_width, + bytes: store_ind * (ind_width - 8) + 8, addr: addr2, value: store_value); diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index f5405f0d..0c3cce00 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -35,7 +35,7 @@ require "std_range_check.pil" [V] In the third clock cycle, we restore the demanded value from w 3] Read operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: - w1_0 w1_1 w2_0 w2_1 + w1_0 w1_1 w2_0 w2_1 +---+---+---+---+ +---+===+===+===+ 
+===+===+===+===+ +===+---+---+---+ | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | +---+---+---+---+ +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ @@ -46,7 +46,7 @@ require "std_range_check.pil" [R] In the third clock cycle, we perform an aligned read to w2 4] Write operation that spans two memory words w1 = [w1_0, w1_1] and w2 = [w2_0, w2_1]: - w1_0 w1_1 w2_0 w2_1 + w1_0 w1_1 w2_0 w2_1 +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | +---+===+===+===+ +===+===+===+===+ +===+---+---+---+ +---+---+---+---+ @@ -88,7 +88,7 @@ require "std_range_check.pil" */ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM = 8, const int CHUNK_BITS = 8) { - const int CHUNK_NUM_HALF = CHUNK_NUM / 2; + const int CHUNKS_BY_RC = CHUNK_NUM / RC; col witness addr; // CHUNK_NUM-byte address, real address = addr * CHUNK_NUM col witness offset; // 0..7, position at which the operation starts @@ -100,6 +100,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R) col witness reg[CHUNK_NUM]; // Register values, 1 byte each col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise + col witness step; // Step of memory // 1] Ensure the MemAlign follows the program @@ -107,7 +108,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM // - reg' == reg in transitions R -> V, R -> W, W -> V, // - 'reg == reg in transitions V <- W, W <- R, // in any case, sel_up_to_down,sel_down_to_up are 0 in [V] steps. 
- for (int i = 0; i < CHUNK_NUM; i++) { + for (int i = 0; i < CHUNK_NUM; i++) { range_check(reg[i], 0, 2**CHUNK_BITS-1); (reg[i]' - reg[i]) * sel[i] * sel_up_to_down === 0; @@ -136,7 +137,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM // 2] Assume aligned memory accesses against the Memory component const expr sel_assume = sel_up_to_down + sel_down_to_up; - + // Offset should be 0 in aligned memory accesses, but this is ensured by the rom // Width should be 8 in aligned memory accesses, but this is ensured by the rom @@ -144,8 +145,8 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM expr assume_val[RC]; for (int i = 0; i < RC; i++) { assume_val[i] = 0; - for (int j = 0; j < CHUNK_NUM_HALF; j++) { - assume_val[i] += reg[j + i * CHUNK_NUM_HALF] * 2**j; + for (int j = 0; j < CHUNKS_BY_RC; j++) { + assume_val[i] += reg[j + i * CHUNKS_BY_RC] * 2**(j*8); } } @@ -156,18 +157,22 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM // On prove steps, we reconstruct the value in the correct manner chosen by the selectors expr prove_val[RC]; - for (int i = 0; i < RC; i++) { - prove_val[i] = 0; - for (int j = 0; j < CHUNK_NUM_HALF; j++) { - expr _prove_val = 0; - for (int k = j; k < j + CHUNK_NUM_HALF; k++) { - _prove_val += reg[(k + i * CHUNK_NUM_HALF) % CHUNK_NUM] * 2**(k-j); + for (int rc_index = 0; rc_index < RC; ++rc_index) { + prove_val[rc_index] = 0; + } + for (int _offset = 0; _offset < CHUNK_NUM; _offset++) { + for (int rc_index = 0; rc_index < RC; rc_index++) { + expr _tmp = 0; + int base = 1; + for (int ichunk = 0; ichunk < CHUNKS_BY_RC; ichunk++) { + _tmp += reg[(_offset + rc_index * CHUNKS_BY_RC + ichunk) % CHUNK_NUM] * base; + base = base * 256; } - prove_val[i] += sel[j + i * CHUNK_NUM_HALF] * _prove_val; + prove_val[rc_index] += sel[_offset] * _tmp; } } // We prove and assume with the same permutation check but with disjoint and different sign selectors - col 
witness step; - permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val], sel: sel_prove - sel_assume); + permutation_proves(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val], sel: sel_prove); + permutation_assumes(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...assume_val], sel: sel_assume); } \ No newline at end of file diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index fcd735c7..2124e4b7 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -342,7 +342,7 @@ impl MemAlignSM { addr: F::from_canonical_u32(addr_read), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), - // wr: F::from_bool(false), + wr: F::from_bool(true), pc: F::from_canonical_u64(next_pc + 1), // reset: F::from_bool(false), sel_prove: F::from_bool(true), @@ -503,8 +503,11 @@ impl MemAlignSM { first_read_row.sel[i] = F::from_bool(true); } + // value_row.reg[i] = + // F::from_canonical_u64(Self::get_byte(value, i, offset)); value_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value, i, offset)); + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + if i == offset { value_row.sel[i] = F::from_bool(true); } @@ -713,7 +716,7 @@ impl MemAlignSM { addr: F::from_canonical_u32(addr_first_read_write), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), - // wr: F::from_bool(false), + wr: F::from_bool(true), pc: F::from_canonical_u64(next_pc + 1), // reset: F::from_bool(false), sel_prove: F::from_bool(true), @@ -721,7 +724,7 @@ impl MemAlignSM { }; let mut second_write_row = MemAlignRow:: { - step: F::from_canonical_u64(step), + step: F::from_canonical_u64(step + 1), addr: 
F::from_canonical_u32(addr_second_read_write), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), @@ -733,7 +736,7 @@ impl MemAlignSM { }; let mut second_read_row = MemAlignRow:: { - step: F::from_canonical_u64(step + 1), + step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_second_read_write), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), @@ -757,6 +760,18 @@ impl MemAlignSM { first_write_row.sel[i] = F::from_bool(true); } + second_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); + if i < rem_bytes { + second_write_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i >= rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + value_row.reg[i] = { if i < rem_bytes { second_write_row.reg[i] @@ -773,18 +788,6 @@ impl MemAlignSM { if i == offset { value_row.sel[i] = F::from_bool(true); } - - second_write_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); - if i < rem_bytes { - second_write_row.sel[i] = F::from_bool(true); - } - - second_read_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); - if i >= rem_bytes { - second_read_row.sel[i] = F::from_bool(true); - } } #[rustfmt::skip] @@ -871,8 +874,12 @@ impl MemAlignSM { pub fn prove(&self, computed_rows: &[MemAlignRow]) { if let Ok(mut rows) = self.rows.lock() { + let row_index = rows.len(); rows.extend_from_slice(computed_rows); - + for (index, row) in computed_rows.iter().enumerate() { + let _addr = row.addr.as_canonical_biguint().to_u64().unwrap(); + println!("MEM_ALIGN_ROW_0x{:X} => {:?} row:{}", _addr * 8, row, row_index + index); + } #[cfg(feature = "debug_mem_align")] { let mut num_rows = self.num_computed_rows.lock().unwrap(); diff --git a/state-machines/mem/src/mem_proxy_engine.rs 
b/state-machines/mem/src/mem_proxy_engine.rs index 50315cdd..2b167108 100644 --- a/state-machines/mem/src/mem_proxy_engine.rs +++ b/state-machines/mem/src/mem_proxy_engine.rs @@ -9,6 +9,9 @@ use p3_field::PrimeField; use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; use zisk_core::ZiskRequiredMemory; +#[cfg(feature = "debug_mem_proxy_engine")] +const DEBUG_ADDR: u32 = 0xA0008F10; + macro_rules! debug_info { ($prefix:expr, $($arg:tt)*) => { #[cfg(feature = "debug_mem_proxy_engine")] @@ -138,7 +141,7 @@ impl MemProxyEngine { // the current mem_op. for mem_extern_op in mem_operations.iter_mut() { - // self.log_mem_op(mem_op); + self.log_mem_op(mem_extern_op); let mem_op = MemInput::from(mem_extern_op); let aligned_mem_addr = Self::to_aligned_addr(mem_op.address); let mem_step = mem_op.step; @@ -162,6 +165,21 @@ impl MemProxyEngine { MemAlignInput::from(&mem_op, mem_extern_op.width, &[mem_value, 0]); let mem_align_response = mem_align_sm.get_mem_op(&mem_align_input, 0); + #[cfg(feature = "debug_mem_proxy_engine")] + if mem_align_input.address >= DEBUG_ADDR - 8 && + mem_align_input.address <= DEBUG_ADDR + 8 + { + debug_info!( + "mem_align_input_{:X}: phase: 0 {:?}", + mem_align_input.address, + mem_align_input + ); + debug_info!( + "mem_align_response_{:X}: phase: 0 {:?}", + mem_align_input.address, + mem_align_response + ); + } // if operation applies to two consecutive memory addresses, add the second part // is enqueued to be processed in future when processing next address on phase-1 if mem_align_response.more_address { @@ -195,23 +213,33 @@ impl MemProxyEngine { while self.has_open_mem_align_lt(mem_addr, mem_step) { let mut open_op = self.open_mem_align_ops.pop_front().unwrap(); - let mem_value = if open_op.addr == self.last_addr { self.last_addr_value } else { 0 }; + let mem_value = self.get_mem_value(open_op.addr); // call to mem_align to get information of the aligned memory access needed // to prove the unaligned open operation. 
open_op.input.mem_values[1] = mem_value; - let mem_align_op = mem_align_sm.get_mem_op(&open_op.input, 1); - - // remove element from top of queue, because we are on last phase, phase 1. - self.open_mem_align_ops.pop_front(); - + let mem_align_resp = mem_align_sm.get_mem_op(&open_op.input, 1); + + #[cfg(feature = "debug_mem_proxy_engine")] + if open_op.input.address >= DEBUG_ADDR - 8 && open_op.input.address <= DEBUG_ADDR + 8 { + debug_info!( + "mem_align_input_{:X}: phase:1 {:?}", + open_op.input.address, + open_op.input + ); + debug_info!( + "mem_align_response_{:X}: phase:1 {:?}", + open_op.input.address, + mem_align_resp + ); + } // push the aligned memory operations for current address (read or read+write) and // update last_address and last_value. self.push_mem_align_response_ops( open_op.addr, mem_value, &open_op.input, - &mem_align_op, + &mem_align_resp, ); } } @@ -271,7 +299,21 @@ impl MemProxyEngine { ) { self.push_aligned_read(mem_addr, mem_value, mem_align_resp.step); if mem_align_input.is_write { - // let mem_value = mem_align_resp.value.expect("value returned by mem_align"); + #[cfg(feature = "debug_mem_proxy_engine")] + if mem_addr >= DEBUG_ADDR - 8 && mem_addr <= DEBUG_ADDR - 8 { + debug_info!( + "push_mem_align_response_ops_{:X}-A: value:{} {:?}", + mem_addr, + mem_align_resp.value.unwrap(), + mem_align_resp + ); + debug_info!( + "push_mem_align_response_ops_{:X}-B: mem_value:{} {:?}", + mem_addr, + mem_value, + mem_align_input + ); + } self.push_aligned_write( mem_addr, mem_align_resp.value.unwrap(), @@ -404,7 +446,6 @@ impl MemProxyEngine { #[inline(always)] fn push_open_mem_align_op(&mut self, aligned_mem_addr: u32, input: &MemAlignInput) { - info!("aligned_mem_addr:{:x}", aligned_mem_addr); self.open_mem_align_ops.push_back(MemAlignOperation { addr: aligned_mem_addr + MEM_BYTES as u32, input: input.clone(), @@ -412,12 +453,14 @@ impl MemProxyEngine { } fn log_mem_op(&self, mem_op: &ZiskRequiredMemory) { debug_info!( - "next input [0x{:x}] {} {} 
{}b #{}", + "next input [0x{:x}] {} {} {}b #{} [0x{:x},{}]", mem_op.address, if mem_op.is_write { "W" } else { "R" }, mem_op.value, mem_op.width, mem_op.step, + self.last_addr, + self.last_addr_value ); } #[inline(always)] From 84be6778c38ca983376b64c4ae05881fd9606d7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip?= Date: Mon, 25 Nov 2024 10:57:20 +0000 Subject: [PATCH 43/44] Fixing typos --- state-machines/mem/pil/mem_align_rom.pil | 4 +-- state-machines/mem/src/mem_align_sm.rs | 46 +++++++----------------- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/state-machines/mem/pil/mem_align_rom.pil b/state-machines/mem/pil/mem_align_rom.pil index d1e60c7c..db0d4440 100644 --- a/state-machines/mem/pil/mem_align_rom.pil +++ b/state-machines/mem/pil/mem_align_rom.pil @@ -185,7 +185,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = pc = prev_line; delta_pc = -pc; // delta_addr = 0; - // is_write = 0; + is_write = 1; // reset = 0; for (int j = 0; j < CHUNK_NUM; j++) { if (j == OFFSET[i]) { @@ -271,7 +271,7 @@ airtemplate MemAlignRom(const int N = MEM_ALIGN_ROM_SIZE, const int CHUNK_NUM = pc = prev_line; delta_pc = 1; // delta_addr = 0; - // is_write = 0; + is_write = 1; // reset = 0; for (int j = 0; j < CHUNK_NUM; j++) { if (j == OFFSET[i]) { diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index 2124e4b7..eafbd80d 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -131,9 +131,6 @@ impl MemAlignSM { #[inline(always)] pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse { - // Sanity check - // assert!(mem_values.len() == phase + 1); // TODO: Debug mode - debug_assert!( input.mem_values.len() == phase + 1, "The number of mem_values {} is not equal to phase + 1 {}", @@ -234,7 +231,6 @@ impl MemAlignSM { "\nOne Word Read\n\ Num Rows: {:?}\n\ Input: {:?}\n\ - Mem Values: {:?}\n\ Phase: {:?}\n\ 
Value Read: {:?}\n\ Value: {:?}\n\ @@ -242,7 +238,6 @@ impl MemAlignSM { Flags Value: {:?}", [*num_rows, *num_rows + 1], input, - mem_values, phase, value_read.to_le_bytes(), value.to_le_bytes(), @@ -263,8 +258,6 @@ impl MemAlignSM { // Prove the generated rows self.prove(&[read_row, value_row]); - println!("MEM_ALIGN_PRE_ROW(R): {:?}", read_row); - println!("MEM_ALIGN_PRE_ROW(V): {:?}", value_row); MemAlignResponse { more_address: false, step, value: None } } @@ -377,7 +370,6 @@ impl MemAlignSM { "\nOne Word Write\n\ Num Rows: {:?}\n\ Input: {:?}\n\ - Mem Values: {:?}\n\ Phase: {:?}\n\ Value Read: {:?}\n\ Value Write: {:?}\n\ @@ -387,7 +379,6 @@ impl MemAlignSM { Flags Value: {:?}", [*num_rows, *num_rows + 2], input, - mem_values, phase, value_read.to_le_bytes(), value_write.to_le_bytes(), @@ -503,8 +494,6 @@ impl MemAlignSM { first_read_row.sel[i] = F::from_bool(true); } - // value_row.reg[i] = - // F::from_canonical_u64(Self::get_byte(value, i, offset)); value_row.reg[i] = F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); @@ -524,7 +513,6 @@ impl MemAlignSM { "\nTwo Words Read\n\ Num Rows: {:?}\n\ Input: {:?}\n\ - Mem Values: {:?}\n\ Phase: {:?}\n\ Value First Read: {:?}\n\ Value: {:?}\n\ @@ -534,7 +522,6 @@ impl MemAlignSM { Flags Second Read: {:?}", [*num_rows, *num_rows + 2], input, - mem_values, phase, value_first_read.to_le_bytes(), value.to_le_bytes(), @@ -760,18 +747,6 @@ impl MemAlignSM { first_write_row.sel[i] = F::from_bool(true); } - second_write_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); - if i < rem_bytes { - second_write_row.sel[i] = F::from_bool(true); - } - - second_read_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); - if i >= rem_bytes { - second_read_row.sel[i] = F::from_bool(true); - } - value_row.reg[i] = { if i < rem_bytes { second_write_row.reg[i] @@ -788,6 +763,18 @@ impl MemAlignSM { if i == offset { value_row.sel[i] = F::from_bool(true); } + + 
second_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); + if i < rem_bytes { + second_write_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i >= rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } } #[rustfmt::skip] @@ -795,7 +782,6 @@ impl MemAlignSM { "\nTwo Words Write\n\ Num Rows: {:?}\n\ Input: {:?}\n\ - Mem Values: {:?}\n\ Phase: {:?}\n\ Value First Read: {:?}\n\ Value First Write: {:?}\n\ @@ -809,7 +795,6 @@ impl MemAlignSM { Flags Second Read: {:?}", [*num_rows, *num_rows + 4], input, - mem_values, phase, value_first_read.to_le_bytes(), value_first_write.to_le_bytes(), @@ -874,12 +859,8 @@ impl MemAlignSM { pub fn prove(&self, computed_rows: &[MemAlignRow]) { if let Ok(mut rows) = self.rows.lock() { - let row_index = rows.len(); rows.extend_from_slice(computed_rows); - for (index, row) in computed_rows.iter().enumerate() { - let _addr = row.addr.as_canonical_biguint().to_u64().unwrap(); - println!("MEM_ALIGN_ROW_0x{:X} => {:?} row:{}", _addr * 8, row, row_index + index); - } + #[cfg(feature = "debug_mem_align")] { let mut num_rows = self.num_computed_rows.lock().unwrap(); @@ -931,7 +912,6 @@ impl MemAlignSM { for (i, &row) in rows.iter().enumerate() { // Store the entire row trace_buffer[i] = row; - println!("MEM_ALIGN_ROW: {:?}", row); // Store the value of all reg columns so that they can be range checked for j in 0..CHUNK_NUM { let element = From 5ef0915da705bc78ced21906e8dae395a0425e69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip=20Ardevol?= Date: Mon, 25 Nov 2024 21:00:12 +0100 Subject: [PATCH 44/44] Mem align and mem fixes (#175) --- pil/src/pil_helpers/traces.rs | 6 +- state-machines/mem/pil/mem.pil | 8 +- state-machines/mem/pil/mem_align.pil | 30 ++-- state-machines/mem/src/mem_align_rom_sm.rs | 46 ++--- state-machines/mem/src/mem_align_sm.rs | 87 +++++++-- state-machines/mem/src/mem_proxy.rs | 4 +- 
state-machines/mem/src/mem_sm.rs | 195 +++++++++++++-------- 7 files changed, 255 insertions(+), 121 deletions(-) diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 40dfc7fd..e9631826 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -12,11 +12,11 @@ trace!(RomRow, RomTrace { }); trace!(MemRow, MemTrace { - addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, same_value: F, first_addr_access_is_read: F, + addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, increment: F, same_value: F, first_addr_access_is_read: F, }); trace!(MemAlignRow, MemAlignTrace { - addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, sel_prove: F, + addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, delta_addr: F, sel_prove: F, value: [F; 2], }); trace!(MemAlignRomRow, MemAlignRomTrace { @@ -58,5 +58,3 @@ trace!(SpecifiedRangesRow, SpecifiedRangesTrace { trace!(U8AirRow, U8AirTrace { mul: F, }); - - diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index d5084847..23740730 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -6,12 +6,12 @@ const int MEMORY_CONT_ID = 11; const int MEMORY_LOAD_OP = 1; const int MEMORY_STORE_OP = 2; -const int MEMORY_MAX_DIFF = 2**22; +const int MEMORY_MAX_DIFF = 2**24; const int MAX_MEM_STEP_OFFSET = 2; const int MAX_MEM_OPS_PER_MAIN_STEP = (MAX_MEM_STEP_OFFSET + 1) * 2; -airtemplate Mem(const int N = 2**21, const int RC = 2, const int id = MEMORY_ID, const int MAX_STEP = 2 ** 23, const int MEM_BYTES = 8) { +airtemplate Mem(const int N = 2**21, const int id = MEMORY_ID, const int RC = 2, const int MEM_BYTES = 8, const int INITIAL_ADDRESS = 0xA0000000) { col fixed SEGMENT_L1 = [1,0...]; const expr SEGMENT_LAST = SEGMENT_L1'; @@ -38,7 +38,9 @@ airtemplate Mem(const int N = 
2**21, const int RC = 2, const int id = MEMORY_ID, addr_changes * (1 - addr_changes) === 0; // check increment of memory - range_check(sel: (1 - SEGMENT_L1), colu: addr_changes * (addr - 'addr - step + 'step) + step - 'step, min: 1, max: MEMORY_MAX_DIFF); + col witness increment; + increment === SEGMENT_L1 * (addr - 1 - INITIAL_ADDRESS) + (1 - SEGMENT_L1) * (addr_changes * (addr - 'addr - step + 'step) + step - 'step); + range_check(colu: increment, min: 1, max: MEMORY_MAX_DIFF); // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd = 1, wr = 0 // setting mem_last_segment = 1 diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index 0c3cce00..8a23ab2a 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -100,7 +100,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. 
W -> R) col witness reg[CHUNK_NUM]; // Register values, 1 byte each col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise - col witness step; // Step of memory + col witness step; // Memory step // 1] Ensure the MemAlign follows the program @@ -133,7 +133,12 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM } flags += wr * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); - lookup_assumes(MEM_ALIGN_ROM_ID, [pc, pc'-pc, (addr-'addr)*(1-reset), offset, width, flags]); + // Perform the lookup against the program + expr delta_pc; + col witness delta_addr; // Auxiliary column + delta_pc = pc' - pc; + delta_addr === (addr - 'addr) * (1 - reset); + lookup_assumes(MEM_ALIGN_ROM_ID, [pc, delta_pc, delta_addr, offset, width, flags]); // 2] Assume aligned memory accesses against the Memory component const expr sel_assume = sel_up_to_down + sel_down_to_up; @@ -143,10 +148,12 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM // On assume steps, we reconstruct the value from the registers directly expr assume_val[RC]; - for (int i = 0; i < RC; i++) { - assume_val[i] = 0; - for (int j = 0; j < CHUNKS_BY_RC; j++) { - assume_val[i] += reg[j + i * CHUNKS_BY_RC] * 2**(j*8); + for (int rc_index = 0; rc_index < RC; rc_index++) { + assume_val[rc_index] = 0; + int base = 1; + for (int _offset = 0; _offset < CHUNKS_BY_RC; _offset++) { + assume_val[rc_index] += reg[_offset + rc_index * CHUNKS_BY_RC] * base; + base *= 256; } } @@ -157,7 +164,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM // On prove steps, we reconstruct the value in the correct manner chosen by the selectors expr prove_val[RC]; - for (int rc_index = 0; rc_index < RC; ++rc_index) { + for (int rc_index = 0; rc_index < RC; rc_index++) { prove_val[rc_index] = 0; } for (int _offset = 0; _offset < CHUNK_NUM; _offset++) { @@ -166,13 
+173,16 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM int base = 1; for (int ichunk = 0; ichunk < CHUNKS_BY_RC; ichunk++) { _tmp += reg[(_offset + rc_index * CHUNKS_BY_RC + ichunk) % CHUNK_NUM] * base; - base = base * 256; + base *= 256; } prove_val[rc_index] += sel[_offset] * _tmp; } } // We prove and assume with the same permutation check but with disjoint and different sign selectors - permutation_proves(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val], sel: sel_prove); - permutation_assumes(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...assume_val], sel: sel_assume); + col witness value[RC]; // Auxiliary columns + for (int i = 0; i < RC; i++) { + value[i] === sel_prove * prove_val[i] + sel_assume * assume_val[i]; + } + permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...value], sel: sel_prove - sel_assume); } \ No newline at end of file diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 61170135..df6081e9 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -79,22 +79,22 @@ impl MemAlignRomSM { pub fn calculate_next_pc(&self, opcode: MemOp, offset: usize, width: usize) -> u64 { // Get the table offset - let (table_offset, one_word) = match opcode { - MemOp::OneRead => { - (1, true) - } - - MemOp::OneWrite => { - (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true) - } - - MemOp::TwoReads => { - (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], false) - } - - MemOp::TwoWrites => { - (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1] + TWO_WORD_COMBINATIONS * OP_SIZES[2], false) - } + let (table_offset, one_word) = match opcode { + 
MemOp::OneRead => (1, true), + + MemOp::OneWrite => (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true), + + MemOp::TwoReads => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], + false, + ), + + MemOp::TwoWrites => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] + + TWO_WORD_COMBINATIONS * OP_SIZES[2], + false, + ), }; // Get the first row index @@ -114,7 +114,13 @@ impl MemAlignRomSM { first_row_idx } - fn get_first_row_idx(opcode: MemOp, offset: usize, width: usize, table_offset: u64, one_word: bool) -> u64 { + fn get_first_row_idx( + opcode: MemOp, + offset: usize, + width: usize, + table_offset: u64, + one_word: bool, + ) -> u64 { let opcode_idx = opcode as usize; let op_size = OP_SIZES[opcode_idx]; @@ -203,11 +209,7 @@ impl MemAlignRomSM { } } - info!( - "{}: ··· Creating Mem Align ROM instance [{} rows filled 100%]", - Self::MY_NAME, - self.num_rows, - ); + info!("{}: ··· Creating Mem Align Rom instance", Self::MY_NAME,); let air_instance = AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0], None, prover_buffer); diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index eafbd80d..40e1e148 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -17,8 +17,12 @@ use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::{MemAlignInput, MemAlignRomSM, MemOp}; +const RC: usize = 2; const CHUNK_NUM: usize = 8; +const CHUNKS_BY_RC: usize = CHUNK_NUM / RC; const CHUNK_BITS: usize = 8; +const RC_BITS: u64 = (CHUNKS_BY_RC * CHUNK_BITS) as u64; +const RC_MASK: u64 = (1 << RC_BITS) - 1; const OFFSET_MASK: u32 = 0x07; const OFFSET_BITS: u32 = 3; const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; @@ -131,13 +135,6 @@ impl MemAlignSM { #[inline(always)] pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse { - debug_assert!( - input.mem_values.len() 
== phase + 1, - "The number of mem_values {} is not equal to phase + 1 {}", - input.mem_values.len(), - phase + 1 - ); - let addr = input.address; let width = input.width; @@ -192,6 +189,7 @@ impl MemAlignSM { let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -204,6 +202,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -226,6 +225,15 @@ impl MemAlignSM { } } + let mut _value_read = value_read; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value >>= RC_BITS; + } + #[rustfmt::skip] debug_info!( "\nOne Word Read\n\ @@ -309,6 +317,7 @@ impl MemAlignSM { let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -321,6 +330,7 @@ impl MemAlignSM { let mut write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), @@ -333,6 +343,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), wr: F::from_bool(true), @@ -365,6 
+376,18 @@ impl MemAlignSM { } } + let mut _value_read = value_read; + let mut _value_write = value_write; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + write_row.value[i] = F::from_canonical_u64(_value_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value_write >>= RC_BITS; + _value >>= RC_BITS; + } + #[rustfmt::skip] debug_info!( "\nOne Word Write\n\ @@ -454,6 +477,7 @@ impl MemAlignSM { let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -466,6 +490,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -478,6 +503,7 @@ impl MemAlignSM { let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_second_read), + delta_addr: F::one(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -508,6 +534,20 @@ impl MemAlignSM { } } + let mut _value_first_read = value_first_read; + let mut _value = value; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + #[rustfmt::skip] debug_info!( "\nTwo Words Read\n\ @@ -677,6 +717,7 @@ impl MemAlignSM { let 
mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -689,6 +730,7 @@ impl MemAlignSM { let mut first_write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), @@ -701,6 +743,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), wr: F::from_bool(true), @@ -713,6 +756,7 @@ impl MemAlignSM { let mut second_write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u32(addr_second_read_write), + delta_addr: F::one(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), @@ -725,6 +769,7 @@ impl MemAlignSM { let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_second_read_write), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -777,6 +822,28 @@ impl MemAlignSM { } } + let mut _value_first_read = value_first_read; + let mut _value_first_write = value_first_write; + let mut _value = value; + let mut _value_second_write = value_second_write; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + first_write_row.value[i] = + F::from_canonical_u64(_value_first_write & RC_MASK); + value_row.value[i] = 
F::from_canonical_u64(_value & RC_MASK); + second_write_row.value[i] = + F::from_canonical_u64(_value_second_write & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value_first_write >>= RC_BITS; + _value >>= RC_BITS; + _value_second_write >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + #[rustfmt::skip] debug_info!( "\nTwo Words Write\n\ @@ -945,7 +1012,7 @@ impl MemAlignSM { ); } - // Compute the padding multiplicity + // Compute the program multiplicity let mem_align_rom_sm = self.mem_align_rom_sm.clone(); mem_align_rom_sm.update_padding_row(padding_size as u64); @@ -953,8 +1020,8 @@ impl MemAlignSM { "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", Self::MY_NAME, rows_len, - air_mem_align.num_rows(), - rows_len as f64 / air_mem_align.num_rows() as f64 * 100.0 + air_mem_align_rows, + rows_len as f64 / air_mem_align_rows as f64 * 100.0 ); // Add a new Mem Align instance diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 2d2a1dbb..b54a0d18 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -23,8 +23,8 @@ pub struct MemProxy { impl MemProxy { pub fn new(wcm: Arc>, std: Arc>) -> Arc { let mem_align_rom_sm = MemAlignRomSM::new(wcm.clone()); - let mem_align_sm = MemAlignSM::new(wcm.clone(), std, mem_align_rom_sm.clone()); - let mem_sm = MemSM::new(wcm.clone()); + let mem_align_sm = MemAlignSM::new(wcm.clone(), std.clone(), mem_align_rom_sm.clone()); + let mem_sm = MemSM::new(wcm.clone(), std); let mem_proxy = Self { registered_predecessors: AtomicU32::new(0), diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index 31e45377..66eb8df9 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -3,10 +3,11 @@ use std::sync::{ Arc, Mutex, }; -const MEM_INITIAL_ADDRESS: u32 = 0xA0000000; -const MEM_FINAL_ADDRESS: u32 = 
MEM_INITIAL_ADDRESS + 128 * 1024 * 1024; use crate::{MemInput, MemModule}; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; use p3_field::PrimeField; +use pil_std_lib::Std; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::AirInstance; use rayon::prelude::*; @@ -14,10 +15,17 @@ use rayon::prelude::*; use sm_common::create_prover_buffer; use zisk_pil::{MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; +const MEM_INITIAL_ADDRESS: u32 = 0xA0000000; +const MEM_FINAL_ADDRESS: u32 = MEM_INITIAL_ADDRESS + 128 * 1024 * 1024; +const MEMORY_MAX_DIFF: u32 = 0x1000000; + pub struct MemSM { // Witness computation manager wcm: Arc>, + // STD + std: Arc>, + num_rows: usize, // Count of registered predecessors registered_predecessors: AtomicU32, @@ -25,11 +33,12 @@ pub struct MemSM { #[allow(unused, unused_variables)] impl MemSM { - pub fn new(wcm: Arc>) -> Arc { + pub fn new(wcm: Arc>, std: Arc>) -> Arc { let pctx = wcm.get_pctx(); let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); let mem_sm = Self { wcm: wcm.clone(), + std: std.clone(), num_rows: air.num_rows(), registered_predecessors: AtomicU32::new(0), }; @@ -37,6 +46,8 @@ impl MemSM { wcm.register_component(mem_sm.clone(), Some(ZISK_AIRGROUP_ID), Some(MEM_AIR_IDS)); + std.register_predecessor(); + mem_sm } @@ -45,30 +56,32 @@ impl MemSM { } pub fn unregister_predecessor(&self) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 {} + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + let pctx = self.wcm.get_pctx(); + self.std.unregister_predecessor(pctx, None); + } } - pub fn prove(&self, mem_accesses: &[MemInput]) { - // Sort the (full) aligned memory accesses - - let pctx = self.wcm.get_pctx(); - let ectx = self.wcm.get_ectx(); - let sctx = self.wcm.get_sctx(); + pub fn prove(&self, inputs: &[MemInput]) { + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let ectx = wcm.get_ectx(); + let sctx = wcm.get_sctx(); - let air = 
pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let air_mem = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let air_mem_rows = air_mem.num_rows(); - let num_chunks = (mem_accesses.len() as f64 / (air.num_rows() - 1) as f64).ceil() as usize; + let inputs_len = inputs.len(); + let num_chunks = (inputs_len as f64 / air_mem_rows as f64).ceil() as usize; let mut prover_buffers = Mutex::new(vec![Vec::new(); num_chunks]); let mut offsets = vec![0; num_chunks]; let mut global_idxs = vec![0; num_chunks]; for i in 0..num_chunks { - if let (true, global_idx) = self.wcm.get_ectx().dctx.write().unwrap().add_instance( - ZISK_AIRGROUP_ID, - MEM_AIR_IDS[0], - 1, - ) { + if let (true, global_idx) = + ectx.dctx.write().unwrap().add_instance(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0], 1) + { let (buffer, offset) = create_prover_buffer::(&ectx, &sctx, ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); @@ -78,27 +91,41 @@ impl MemSM { } } - mem_accesses.par_chunks(air.num_rows() - 1).enumerate().for_each( - |(segment_id, mem_ops)| { - let mem_first_row = if segment_id == 0 { - mem_accesses.last().unwrap().clone() - } else { - mem_accesses[segment_id * ((air.num_rows() - 1) - 1)].clone() - }; - - let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); - - self.prove_instance( - mem_ops, - mem_first_row, - segment_id, - segment_id == mem_accesses.len() - 1, - prover_buffer, - offsets[segment_id], - global_idxs[segment_id], - ); - }, - ); + for (segment_id, mem_ops) in inputs.chunks(air_mem_rows).enumerate() { + let is_last_segment = segment_id == num_chunks - 1; + + let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + self.prove_instance( + mem_ops, + segment_id, + is_last_segment, + prover_buffer, + offsets[segment_id], + global_idxs[segment_id], + ); + } + + // TODO: Uncomment when sequential works + // inputs.par_chunks(air_mem_rows - 1).enumerate().for_each(|(segment_id, mem_ops)| { + // let mem_first_row = if segment_id == 0 
{ + // inputs.last().unwrap().clone() + // } else { + // inputs[segment_id * ((air_mem_rows - 1) - 1)].clone() + // }; + + // let prover_buffer = std::mem::take(&mut prover_buffers.lock().unwrap()[segment_id]); + + // self.prove_instance( + // mem_ops, + // mem_first_row, + // segment_id, + // segment_id == inputs.len() - 1, + // prover_buffer, + // offsets[segment_id], + // global_idxs[segment_id], + // ); + // }); } /// Finalizes the witness accumulation process and triggers the proof generation. @@ -111,20 +138,21 @@ impl MemSM { pub fn prove_instance( &self, mem_ops: &[MemInput], - mem_first_row: MemInput, segment_id: usize, is_last_segment: bool, mut prover_buffer: Vec, offset: u64, global_idx: usize, ) -> Result<(), Box> { - let pctx = self.wcm.get_pctx(); - let sctx = self.wcm.get_sctx(); + let wcm = self.wcm.clone(); + let pctx = wcm.get_pctx(); + let sctx = wcm.get_sctx(); // STEP2: Process the memory inputs and convert them to AIR instances - let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let air_mem = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]); + let air_mem_rows = air_mem.num_rows(); - let max_rows_per_segment = air.num_rows() - 1; + let max_rows_per_segment = air_mem_rows - 1; assert!(mem_ops.len() > 0 && mem_ops.len() <= max_rows_per_segment); @@ -142,29 +170,21 @@ impl MemSM { // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows let mut trace = - MemTrace::::map_buffer(&mut prover_buffer, air.num_rows(), offset as usize).unwrap(); - - // STEP1. 
Add the first row to the output vector as equal to the last row of the previous - // segment CASE: last row of segment is read - // - // S[n-1] wr = 0, sel = 1, addr, step, value - // S+1[0] wr = 0, sel = 0, addr, step, value - // - // CASE: last row of segment is write - // - // S[n-1] wr = 1, sel = 1, addr, step, value - // S+1[0] wr = 0, sel = 0, addr, step, value - - // TODO CHECK - // trace[0].mem_segment = segment_id_field; - // trace[0].mem_last_segment = is_last_segment_field; - - trace[0].addr = F::from_canonical_u32(mem_first_row.address); - trace[0].step = F::from_canonical_u64(mem_first_row.step); + MemTrace::::map_buffer(&mut prover_buffer, air_mem_rows, offset as usize).unwrap(); + + let mut range_check_data: Vec = vec![0; MEMORY_MAX_DIFF as usize]; + + // Fill the first row + let first_mem_op = mem_ops.first().unwrap(); + let addr = first_mem_op.address >> 3; + debug_assert!(addr >= MEM_INITIAL_ADDRESS); + + trace[0].addr = F::from_canonical_u32(addr); + trace[0].step = F::from_canonical_u64(first_mem_op.step); trace[0].sel = F::zero(); trace[0].wr = F::zero(); - let value = mem_first_row.value; + let value = first_mem_op.value; let (low_val, high_val) = self.get_u32_values(value); trace[0].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; trace[0].addr_changes = F::zero(); @@ -172,12 +192,16 @@ impl MemSM { trace[0].same_value = F::zero(); trace[0].first_addr_access_is_read = F::zero(); - // STEP2. 
Add all the memory operations to the buffer + let increment = addr - 1 - MEM_INITIAL_ADDRESS; + trace[0].increment = F::from_canonical_u32(increment); + + // Store the value of incremenet so it can be range checked + println!("addr: {:#X}, initial: {:#X}, increment: {:#X}", addr, MEM_INITIAL_ADDRESS, increment); + range_check_data[increment as usize] += 1; // TODO + + // Fill the remaining rows for (idx, mem_op) in mem_ops.iter().enumerate() { let i = idx + 1; - // TODO CHECK - // trace[i].mem_segment = segment_id_field; - // trace[i].mem_last_segment = is_last_segment_field; let mem_addr = mem_op.address >> 3; trace[i].addr = F::from_canonical_u32(mem_addr); // n-byte address, real address = addr * MEM_BYTES @@ -199,17 +223,31 @@ impl MemSM { trace[i].first_addr_access_is_read = if first_addr_access_is_read { F::one() } else { F::zero() }; assert!(trace[i].sel.is_zero() || trace[i].sel.is_one()); + + let increment = if addr_changes { + trace[i].addr - trace[i - 1].addr + } else { + trace[i].step - trace[i - 1].step + }; + trace[i].increment = increment; + + // Store the value of incremenet so it can be range checked + let element = + increment.as_canonical_biguint().to_usize().expect("Cannot convert to usize"); + // range_check_data[element] += 1; // TODO: } // STEP3. 
Add dummy rows to the output vector to fill the remaining rows - //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd + // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd // = 1, wr = 0 let last_row_idx = mem_ops.len(); let addr = trace[last_row_idx].addr; let mut step = trace[last_row_idx].step; let value = trace[last_row_idx].value; - for i in (mem_ops.len() + 1)..air.num_rows() { + let padding_size = air_mem_rows - (mem_ops.len() + 1); + + for i in (mem_ops.len() + 1)..air_mem_rows { step += F::one(); // TODO CHECK @@ -226,10 +264,27 @@ impl MemSM { trace[i].addr_changes = F::zero(); trace[i].same_value = F::one(); trace[i].first_addr_access_is_read = F::zero(); + + // Set increment to the minimum value so the range check passes + trace[i].increment = F::one(); } + // Store the value of trivial increment so that they can be range checked + range_check_data[1] += padding_size as u64; + + // TODO: Perform the range checks + // let std = self.std.clone(); + // let range_id = std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + // for (value, &multiplicity) in range_check_data.iter().enumerate() { + // std.range_check( + // F::from_canonical_usize(value), + // F::from_canonical_u64(multiplicity), + // range_id, + // ); + // } + let mut air_instance = AirInstance::new( - self.wcm.get_sctx(), + sctx.clone(), ZISK_AIRGROUP_ID, MEM_AIR_IDS[0], Some(segment_id),