diff --git a/.cargo/config.toml b/.cargo/config.toml deleted file mode 100644 index 862be18186..0000000000 --- a/.cargo/config.toml +++ /dev/null @@ -1,5 +0,0 @@ -[source.crates-io] -replace-with = "vendored-sources" - -[source.vendored-sources] -directory = "proof-systems-vendors" \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d64fec6209..09b91f2ab9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -95,7 +95,7 @@ jobs: with: rust_toolchain_version: ${{ matrix.rust_toolchain_version }} - - name: Rust Cache + - name: Apply the Rust smart caching uses: Swatinem/rust-cache@v2 - name: Use shared OCaml setting up steps @@ -134,7 +134,7 @@ jobs: - name: Build cargo docs run: | eval $(opam env) - RUSTDOCFLAGS="-D warnings" cargo doc --all-features --no-deps + make generate-doc # # Coding guidelines # @@ -174,24 +174,20 @@ jobs: eval $(opam env) make nextest - - name: Run heavy tests without the code coverage + - name: Run non-heavy tests with the code coverage if: ${{ matrix.rust_toolchain_version == env.RUST_TOOLCHAIN_COVERAGE_VERSION }} run: | eval $(opam env) - make nextest-heavy - # TODO: Re-enable this for tests with coverage data gathering. - # TODO: We intentionally disable the tests with the code coverage data gathering because - # TODO: Rust's code coverage instrumentation is not compatible with some of the optimizations that occur during release builds. - # make nextest-with-coverage - # make test-doc-with-coverage - # make generate-test-coverage-report - - # - name: Use shared code coverage summary - # if: ${{ matrix.rust_toolchain_version == env.RUST_TOOLCHAIN_COVERAGE_VERSION }} - # uses: ./.github/actions/coverage-summary-shared - - # - name: Use shared Codecov reporting steps - # if: ${{ matrix.rust_toolchain_version == env.RUST_TOOLCHAIN_COVERAGE_VERSION }} - # uses: ./.github/actions/codecov-shared - # with: - # token: ${{ secrets.CODECOV_TOKEN }} + make nextest-with-coverage + make test-doc-with-coverage + make generate-test-coverage-report + + - name: Use shared code coverage summary + if: ${{ matrix.rust_toolchain_version == env.RUST_TOOLCHAIN_COVERAGE_VERSION }} + uses: ./.github/actions/coverage-summary-shared + + - name: Use shared Codecov reporting steps + if: ${{ matrix.rust_toolchain_version == env.RUST_TOOLCHAIN_COVERAGE_VERSION }} + uses: ./.github/actions/codecov-shared + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/gh-page.yml b/.github/workflows/gh-page.yml index 74655d4392..6e9b3386b9 100644 --- a/.github/workflows/gh-page.yml +++ b/.github/workflows/gh-page.yml @@ -2,12 +2,13 @@ name: Deploy Specifications & Docs to GitHub Pages on: push: - branches: - - master + branches: ["master"] env: OCAML_VERSION: "4.14.0" - RUST_TOOLCHAIN_VERSION: "nightly" + # This version has been chosen arbitrarily. It seems that with nightly 2023-11-16 it is + # broken: the compiler crashes. Feel free to pick any newer working version.
+ RUST_TOOLCHAIN_VERSION: "1.72" jobs: release: @@ -30,10 +31,11 @@ jobs: with: ocaml_version: ${{ env.OCAML_VERSION }} + # This must be the same as in the section "Generate rustdoc locally" in the README.md - name: Build Rust Documentation run: | eval $(opam env) - RUSTDOCFLAGS="--enable-index-page -Zunstable-options" cargo +nightly doc --all --no-deps + RUSTDOCFLAGS="--enable-index-page -Zunstable-options -D warnings" cargo doc --workspace --all-features --no-deps - name: Build the mdbook run: | @@ -47,7 +49,8 @@ mv ./target/doc ./book/book/html/rustdoc - name: Deploy - uses: peaceiris/actions-gh-pages@v3 + uses: peaceiris/actions-gh-pages@v4 + if: github.ref == 'refs/heads/master' with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./book/book/html diff --git a/.github/workflows/o1vm-ci.yml b/.github/workflows/o1vm-ci.yml new file mode 100644 index 0000000000..f657d3c879 --- /dev/null +++ b/.github/workflows/o1vm-ci.yml @@ -0,0 +1,92 @@ +name: o1vm CI + +on: + workflow_dispatch: + pull_request: + push: + branches: + - master + +env: + # https://doc.rust-lang.org/cargo/reference/profiles.html#release + # Disable for the time being since it fails with the "attempt to multiply with overflow" error. + # Known issue yet to be fixed. + # RUSTFLAGS: -Coverflow-checks=y -Cdebug-assertions=y + # https://doc.rust-lang.org/cargo/reference/profiles.html#incremental + CARGO_INCREMENTAL: 1 + # https://nexte.st/book/pre-built-binaries.html#using-nextest-in-github-actions + CARGO_TERM_COLOR: always + # 30 MB of stack for Keccak tests + RUST_MIN_STACK: 31457280 + +jobs: + run_o1vm_with_cached_data: + name: Run o1vm with cached data + # We run only one of the matrix options on the toffee `hetzner-1` self-hosted GitHub runner. + # Only in this configuration do we enable tests with the code coverage data gathering. + runs-on: ["ubuntu-latest"] + strategy: + matrix: + rust_toolchain_version: ["1.74"] + # FIXME: currently not available for 5.0.0. + # It might be related to the boxroot dependency, and we would need to bump + # up the ocaml-rs dependency + ocaml_version: ["4.14"] + services: + o1vm-e2e-testing-cache: + image: o1labs/proof-systems:o1vm-e2e-testing-cache + volumes: + - /tmp:/tmp/cache + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Use shared Rust toolchain setting up steps + uses: ./.github/actions/toolchain-shared + with: + rust_toolchain_version: ${{ matrix.rust_toolchain_version }} + + - name: Apply the Rust smart caching + uses: Swatinem/rust-cache@v2 + + - name: Use shared OCaml setting up steps + uses: ./.github/actions/ocaml-shared + with: + ocaml_version: ${{ matrix.ocaml_version }} + + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + check-latest: true + + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.21.0" + + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + + - name: Build the OP program + run: | + cd o1vm + make -C ./ethereum-optimism/op-program op-program + cd ..
+ + - name: Start the local HTTP server + run: | + python -m http.server 8765 & + + # + # Tests + # + + - name: Execute o1vm in E2E flavor using cached data + run: | + eval $(opam env) + cd o1vm + unzip -q -o /tmp/o1vm-e2e-testing-cache.zip -d ./ + RUN_WITH_CACHED_DATA="y" FILENAME="env-for-latest-l2-block.sh" O1VM_FLAVOR="pickles" STOP_AT="=3000000" ./run-code.sh diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml deleted file mode 100644 index 6dd1fdf361..0000000000 --- a/.github/workflows/stale.yml +++ /dev/null @@ -1,35 +0,0 @@ -# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. -# -# You can adjust the behavior by modifying this file. -# For more information, see: -# https://github.com/actions/stale -name: Mark stale issues and pull requests - -on: - schedule: -# ┌───────────── minute (0 - 59) -# │ ┌───────────── hour (0 - 23) -# │ │ ┌───────────── day of the month (1 - 31) -# │ │ │ ┌───────────── month (1 - 12) -# │ │ │ │ ┌───────────── day of the week (0 - 6) -# │ │ │ │ │ -# │ │ │ │ │ -# │ │ │ │ │ - - cron: '19 7 * * *' - -jobs: - stale: - - runs-on: ubuntu-latest - permissions: - issues: write - pull-requests: write - - steps: - - uses: actions/stale@v3 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: 'Stale issue message' - stale-pr-message: 'Stale pull request message' - stale-issue-label: 'no-issue-activity' - stale-pr-label: 'no-pr-activity' diff --git a/.gitignore b/.gitignore index 43461ced65..8fb024dbfc 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,13 @@ build tools/srs .ignore +o1vm/env-for-* +o1vm/cpu.pprof +o1vm/snapshot-state* +o1vm/rpcs.sh +o1vm/meta.json +o1vm/out.json +o1vm/op-program-db* +o1vm/state.json +meta.json +state.json diff --git a/.gitmodules b/.gitmodules index 3b6a5edc6a..916b6be0fb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ -[submodule "proof-systems-vendors"] - path = proof-systems-vendors - url = https://github.com/o1-labs/proof-systems-vendors.git +[submodule "o1vm/ethereum-optimism"] + path = o1vm/ethereum-optimism + url = https://github.com/ethereum-optimism/optimism.git diff --git a/.rustfmt.toml b/.rustfmt.toml index be0b59b240..bf2e23a2e4 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,3 +1,3 @@ indent_style = "Block" imports_granularity = "Crate" -reorder_imports = true \ No newline at end of file +reorder_imports = true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3f47c5c106..fa5c79bad0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,8 +14,8 @@ We have a list of easy task to start contributing. [Start over there](https://gi ## Setting up the project -Make sure you have the GNU `make` utility installed since we use it to streamline various tasks. -Windows users may need to use the `WSL` to run `make` commands. +Make sure you have the GNU `make` utility installed since we use it to streamline various tasks. +Windows users may need to use the `WSL` to run `make` commands. For the complete list of `make` targets, please refer to the [Makefile](Makefile). After the repository being cloned, run: @@ -75,6 +75,12 @@ BIN_EXTRA_ARGS="-p poly-commitment" make nextest-all-with-coverage ``` Note: In example above we run tests for the `poly-commitment` package only. +You can also use the environment variable `BIN_EXTRA_ARGS` to select a specific +test to run. For instance: +``` +BIN_EXTRA_ARGS="test_opening_proof" make nextest +``` +will only run the tests containing `test_opening_proof`. 
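+ +The two variables can also be combined. For instance, the following should run only the tests matching `test_opening_proof` within the `poly-commitment` package (a sketch built from the two examples above; it assumes nextest accepts the package selector together with a filter string): +``` +BIN_EXTRA_ARGS="-p poly-commitment test_opening_proof" make nextest +```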
We build and run tests in `--release` mode, because otherwise tests execution can last for a long time. @@ -96,7 +102,7 @@ Note: cargo can automatically fix some lints. To do so, add `--fix` to the `CARG CARGO_EXTRA_ARGS="--fix" make lint ``` -Formatting and lints are enforced by GitHub PR checks, so please be sure to have any errors produced by the above tools fixed before pushing the code to your pull request branch. +Formatting and lints are enforced by GitHub PR checks, so please be sure to have any errors produced by the above tools fixed before pushing the code to your pull request branch. Please refer to [CI](.github/workflows/ci.yml) workflow to see all PR checks. ## Branching policy diff --git a/Cargo.lock b/Cargo.lock index ff2f89b002..8295764350 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler32" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" + [[package]] name = "ahash" version = "0.8.7" @@ -53,6 +59,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "ansi_term" version = "0.12.1" @@ -62,6 +74,54 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "anstream" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e1ebcb11de5c03c67de28a7df593d32191b44939c482e97702baaaa6ab6a5" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2faccea4cc4ab4a667ce676a30e8ec13922a692c99bb8f5b11f1502c72e04220" + +[[package]] +name = "anstyle-parse" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] + [[package]] name = "ark-algebra-test-templates" version = "0.4.2" @@ -105,7 +165,7 @@ dependencies = [ "ark-std", "derivative", "hashbrown 0.13.2", - "itertools", + "itertools 0.10.5", "num-traits", "rayon", "zeroize", @@ -123,7 +183,7 @@ dependencies = [ "ark-std", "derivative", "digest", - "itertools", + "itertools 0.10.5", "num-bigint", "num-traits", "paste", @@ -214,6 +274,34 @@ dependencies = [ "ark-std", ] +[[package]] +name = "arrabiata" +version = "0.1.0" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-poly", + "clap 4.4.18", + "env_logger", + "groupmap", + "itertools 0.12.1", + "kimchi", + "log", + "mina-curves", + "mina-poseidon", + "mvpoly", + "num-bigint", + "num-integer", + "o1-utils", + 
"once_cell", + "poly-commitment", + "rand", + "rayon", + "serde", + "strum", + "strum_macros", +] + [[package]] name = "askama" version = "0.11.1" @@ -268,7 +356,7 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" dependencies = [ - "hermit-abi", + "hermit-abi 0.1.19", "libc", "winapi 0.3.9", ] @@ -454,6 +542,33 @@ dependencies = [ "windows-targets 0.52.0", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half 2.3.1", +] + [[package]] name = "clap" version = "2.34.0" @@ -478,7 +593,7 @@ dependencies = [ "atty", "bitflags 1.3.2", "clap_derive", - "clap_lex", + "clap_lex 0.2.4", "indexmap 1.9.3", "once_cell", "strsim 0.10.0", @@ -486,6 +601,27 @@ dependencies = [ "textwrap 0.16.0", ] +[[package]] +name = "clap" +version = "4.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" +dependencies = [ + "anstream", + "anstyle", + "clap_lex 0.6.0", + "strsim 0.10.0", +] + [[package]] name = "clap_derive" version = "3.2.25" @@ -508,6 +644,18 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "clap_lex" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" + +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + [[package]] name = "colored" version = "2.1.0" @@ -518,6 +666,16 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "command-fds" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb11bd1378bf3731b182997b40cefe00aba6a6cc74042c8318c1b271d3badf7" +dependencies = [ + "nix", + "thiserror", +] + [[package]] name = "comrak" version = "0.13.2" @@ -573,6 +731,15 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "core2" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +dependencies = [ + "memchr", +] + [[package]] name = "cpufeatures" version = "0.2.12" @@ -600,9 +767,9 @@ dependencies = [ "atty", "cast", "clap 2.34.0", - "criterion-plot", + "criterion-plot 0.4.5", "csv", - "itertools", + "itertools 0.10.5", "lazy_static", "num-traits", "oorandom", 
@@ -617,6 +784,32 @@ dependencies = [ "walkdir", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap 4.4.18", + "criterion-plot 0.5.0", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + [[package]] name = "criterion-plot" version = "0.4.5" @@ -624,7 +817,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", ] [[package]] @@ -730,6 +933,12 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "dary_heap" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" + [[package]] name = "deranged" version = "0.3.11" @@ -762,24 +971,47 @@ dependencies = [ "subtle", ] -[[package]] -name = "disjoint-set" -version = "0.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d102f1a462fdcdddce88d6d46c06c074a2d2749b262230333726b06c52bb7585" - [[package]] name = "either" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "elf" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4445909572dbd556c457c849c4ca58623d84b27c8fff1e74b0b4227d8b90d17b" + [[package]] name = "entities" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5320ae4c3782150d900b79807611a59a99fc9a1d61d686faafc24b93fc8d7ca" +[[package]] +name = "env_filter" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05e7cf40684ae96ade6232ed84582f40ce0a66efcd43a5117aef610534f8e0b8" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -857,6 +1089,38 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "folding" +version = "0.1.0" +dependencies = [ + "ark-bn254", + "ark-ec", + "ark-ff", + "ark-poly", + "ark-serialize", + "derivative", + "groupmap", + "itertools 0.12.1", + "kimchi", + "log", + "mina-curves", + "mina-poseidon", + "num-bigint", + "num-integer", + "num-traits", + "o1-utils", + "poly-commitment", + "rand", + "rayon", + "rmp-serde", + "serde", + "serde_json", + "serde_with", + "strum", + "strum_macros", + "thiserror", +] + [[package]] name = "fsevent" version = "0.4.0" @@ -941,6 +1205,16 @@ version = "1.8.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403" +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "cfg-if 1.0.0", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -977,6 +1251,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" + [[package]] name = "hex" version = "0.4.3" @@ -992,6 +1272,12 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02296996cb8796d7c6e3bc2d9211b7802812d36999a51bb754123ead7d37d026" +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iai" version = "0.1.1" @@ -1088,6 +1374,17 @@ dependencies = [ "libc", ] +[[package]] +name = "is-terminal" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" +dependencies = [ + "hermit-abi 0.3.4", + "rustix", + "windows-sys 0.52.0", +] + [[package]] name = "is_ci" version = "1.1.1" @@ -1103,21 +1400,64 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "ivc" +version = "0.1.0" +dependencies = [ + "ark-bn254", + "ark-ec", + "ark-ff", + "ark-poly", + "folding", + "itertools 0.12.1", + "kimchi", + "kimchi-msm", + "mina-poseidon", + "num-bigint", + "num-integer", + "o1-utils", + "once_cell", + "poly-commitment", + "rand", + "rayon", + "strum", + "strum_macros", + "thiserror", +] + [[package]] name = "js-sys" -version = "0.3.64" +version = "0.3.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" dependencies = [ "wasm-bindgen", ] +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + [[package]] name = "kernel32-sys" version = "0.2.2" @@ -1139,13 +1479,13 @@ dependencies = [ "ark-serialize", "blake2", "colored", - "criterion", - "disjoint-set", + "criterion 0.5.1", "groupmap", "hex", "iai", "internal-tracing", - "itertools", + "itertools 0.12.1", + "log", "mina-curves", "mina-poseidon", "num-bigint", @@ -1174,6 +1514,39 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "kimchi-msm" +version = "0.1.0" +dependencies = [ + "ark-bn254", + "ark-ec", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std", + "folding", + "groupmap", + "itertools 0.12.1", + "kimchi", + "log", + "mina-curves", + "mina-poseidon", + 
"num-bigint", + "num-integer", + "num-traits", + "o1-utils", + "poly-commitment", + "rand", + "rayon", + "rmp-serde", + "serde", + "serde_json", + "serde_with", + "strum", + "strum_macros", + "thiserror", +] + [[package]] name = "kimchi-visu" version = "0.1.0" @@ -1209,6 +1582,30 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libflate" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7d5654ae1795afc7ff76f4365c2c8791b0feb18e8996a96adad8ffd7c3b2bf" +dependencies = [ + "adler32", + "core2", + "crc32fast", + "dary_heap", + "libflate_lz77", +] + +[[package]] +name = "libflate_lz77" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be5f52fb8c451576ec6b79d3f4deb327398bc05bbdbd99021a6e77a4c855d524" +dependencies = [ + "core2", + "hashbrown 0.13.2", + "rle-decode-fast", +] + [[package]] name = "libm" version = "0.2.8" @@ -1314,6 +1711,7 @@ dependencies = [ "ark-serialize", "ark-std", "ark-test-curves", + "num-bigint", "rand", ] @@ -1338,7 +1736,7 @@ dependencies = [ "ark-ff", "ark-poly", "ark-serialize", - "criterion", + "criterion 0.3.6", "hex", "mina-curves", "o1-utils", @@ -1428,6 +1826,20 @@ dependencies = [ "ws2_32-sys", ] +[[package]] +name = "mvpoly" +version = "0.1.0" +dependencies = [ + "ark-ff", + "criterion 0.5.1", + "kimchi", + "log", + "mina-curves", + "num-integer", + "o1-utils", + "rand", +] + [[package]] name = "net2" version = "0.2.39" @@ -1439,6 +1851,17 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "nix" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" +dependencies = [ + "bitflags 2.4.2", + "cfg-if 1.0.0", + "libc", +] + [[package]] name = "nom" version = "7.1.3" @@ -1536,6 +1959,47 @@ dependencies = [ "thiserror", ] +[[package]] +name = "o1vm" +version = "0.1.0" +dependencies = [ + "ark-bn254", + "ark-ec", + "ark-ff", + "ark-poly", + "base64", + "clap 4.4.18", + "command-fds", + "elf", + "env_logger", + "folding", + "groupmap", + "hex", + "itertools 0.12.1", + "kimchi", + "kimchi-msm", + "libc", + "libflate", + "log", + "mina-curves", + "mina-poseidon", + "o1-utils", + "os_pipe", + "poly-commitment", + "rand", + "rayon", + "regex", + "rmp-serde", + "serde", + "serde_json", + "serde_with", + "sha3", + "stacker", + "strum", + "strum_macros", + "thiserror", +] + [[package]] name = "object" version = "0.32.2" @@ -1656,6 +2120,16 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "os_pipe" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57119c3b893986491ec9aa85056780d3a0f3cf4da7cc09dd3650dbd6c6738fb9" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "os_str_bytes" version = "6.6.1" @@ -1733,9 +2207,9 @@ checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" [[package]] name = "plist" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdc0001cfea3db57a2e24bc0d818e9e20e554b5f97fabb9bc231dc240269ae06" +checksum = "9a4a0cfc5fb21a09dc6af4bf834cf10d4a32fccd9e2ea468c4b1751a097487aa" dependencies = [ "base64", "indexmap 1.9.3", @@ 
-1784,9 +2258,9 @@ dependencies = [ "ark-serialize", "blake2", "colored", - "criterion", + "criterion 0.5.1", "groupmap", - "itertools", + "itertools 0.12.1", "mina-curves", "mina-poseidon", "o1-utils", @@ -1879,6 +2353,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "psm" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5787f7cda34e3033a72192c018bc5883100330f362ef279a8cbccfce8bb4e874" +dependencies = [ + "cc", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -1887,9 +2370,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quick-xml" -version = "0.29.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81b9228215d82c7b61490fec1de287136b5de6f5700f6e58ea9ad61a7964ca51" +checksum = "eff6510e86862b57b210fd8cbe8ed3f0d7d600b9c2863cd4549a2e033c66e956" dependencies = [ "memchr", ] @@ -2012,6 +2495,12 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "rle-decode-fast" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" + [[package]] name = "rmp" version = "0.8.12" @@ -2140,7 +2629,7 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" dependencies = [ - "half", + "half 1.8.3", "serde", ] @@ -2206,6 +2695,16 @@ dependencies = [ "digest", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "shell-words" version = "1.1.0" @@ -2227,6 +2726,19 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" +[[package]] +name = "stacker" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce" +dependencies = [ + "cc", + "cfg-if 1.0.0", + "libc", + "psm", + "winapi 0.3.9", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -2592,6 +3104,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "vec_map" version = "0.8.2" @@ -2631,9 +3149,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.87" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" dependencies = [ "cfg-if 1.0.0", "wasm-bindgen-macro", @@ -2641,9 +3159,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.87" +version = "0.2.90" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" dependencies = [ "bumpalo", "log", @@ -2656,9 +3174,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.87" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2666,9 +3184,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.87" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ "proc-macro2", "quote", @@ -2679,15 +3197,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.87" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" [[package]] name = "web-sys" -version = "0.3.64" +version = "0.3.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +checksum = "58cd2333b6e0be7a39605f0e255892fd7418a682d8da8fe042fe25128794d2ed" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index bd607426b9..fb82149618 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,18 +1,24 @@ [workspace] members = [ + "arrabiata", "book", "turshi", "curves", "groupmap", "hasher", "kimchi", + "msm", + "o1vm", "poseidon", "poseidon/export_test_vectors", "poly-commitment", "signer", + "mvpoly", "tools/kimchi-visu", "utils", "internal-tracing", + "ivc", + "folding" ] resolver = "2" @@ -34,12 +40,12 @@ clap = "4.4.6" colored = "2.0.0" command-fds = "0.3" convert_case = "0.6.0" -criterion = "0.3.6" +criterion = "0.5" elf = "0.7.2" env_logger = "0.11.1" hex = { version = "0.4", features = ["serde"] } iai = "0.1" -itertools = "0.10.5" +itertools = "0.12.1" libc = "0.2.62" libflate = "2" log = "0.4.20" @@ -62,25 +68,32 @@ rayon = "1.5.0" regex = "1.10.2" rmp-serde = "1.1.2" secp256k1 = "0.28.2" -serde = { version = "1.0.130", features = ["derive"] } +serde = { version = "1.0.130", features = ["derive", "rc"] } serde_json = "1.0.113" serde_with = "3.6.0" -sha2 = "0.10.2" +sha2 = "0.10.0" +sha3 = "0.10.8" strum = "0.26.1" strum_macros = "0.26.1" syn = { version = "1.0.109", features = ["full"] } thiserror = "1.0.30" tinytemplate = "1.1" -wasm-bindgen = "=0.2.87" +wasm-bindgen = "=0.2.90" + +arrabiata = { path = "./arrabiata", version = "0.1.0" } +folding = { path = "./folding", version = "0.1.0" } groupmap = { path = "./groupmap", version = "0.1.0" } internal-tracing = { path = "./internal-tracing", version = "0.1.0" } kimchi = { path = "./kimchi", version = "0.1.0", features = ["bn254"] } kimchi-visu = { path = "./tools/kimchi-visu", version = "0.1.0" } +kimchi-msm = { path = "./msm", version = "0.1.0" } mina-curves = { path = "./curves", version = "0.1.0" } mina-hasher = { path = "./hasher", version = "0.1.0" } mina-poseidon = { path = "./poseidon", version = "0.1.0" } +mvpoly = 
{ path = "./mvpoly", version = "0.1.0" } o1-utils = { path = "./utils", version = "0.1.0" } +o1vm = { path = "./o1vm", version = "0.1.0" } optimism = { path = "./optimism", version = "0.1.0" } poly-commitment = { path = "./poly-commitment", version = "0.1.0" } signer = { path = "./signer", version = "0.1.0" } diff --git a/Makefile b/Makefile index 6063509911..53253a15e0 100644 --- a/Makefile +++ b/Makefile @@ -21,11 +21,10 @@ setup: @echo "Git submodules synced." @echo "" -# Install test dependencies # https://nexte.st/book/pre-built-binaries.html#using-nextest-in-github-actions # FIXME: update to 0.9.68 when we get rid of 1.71 and 1.72. # FIXME: latest 0.8.19+ requires rustc 1.74+ -install-test-deps: +install-test-deps: ## Install test dependencies @echo "" @echo "Installing the test dependencies." @echo "" @@ -36,73 +35,73 @@ install-test-deps: @echo "Test dependencies installed." @echo "" -# Clean the project -clean: + +clean: ## Clean the project cargo clean -# Build the project -build: + +build: ## Build the project cargo build --all-targets --all-features -# Build the project in release mode -release: + +release: ## Build the project in release mode cargo build --release --all-targets --all-features -# Test the project's docs comments -test-doc: + +test-doc: ## Test the project's docs comments cargo test --all-features --release --doc test-doc-with-coverage: $(COVERAGE_ENV) $(MAKE) test-doc -# Test the project with non-heavy tests and using native cargo test runner -test: + +test: ## Test the project with non-heavy tests and using native cargo test runner cargo test --all-features --release $(CARGO_EXTRA_ARGS) -- --nocapture --skip heavy $(BIN_EXTRA_ARGS) test-with-coverage: $(COVERAGE_ENV) CARGO_EXTRA_ARGS="$(CARGO_EXTRA_ARGS)" BIN_EXTRA_ARGS="$(BIN_EXTRA_ARGS)" $(MAKE) test -# Test the project with heavy tests and using native cargo test runner -test-heavy: + +test-heavy: ## Test the project with heavy tests and using native cargo test runner cargo test --all-features --release $(CARGO_EXTRA_ARGS) -- --nocapture heavy $(BIN_EXTRA_ARGS) test-heavy-with-coverage: $(COVERAGE_ENV) CARGO_EXTRA_ARGS="$(CARGO_EXTRA_ARGS)" BIN_EXTRA_ARGS="$(BIN_EXTRA_ARGS)" $(MAKE) test-heavy -# Test the project with all tests and using native cargo test runner -test-all: + +test-all: ## Test the project with all tests and using native cargo test runner cargo test --all-features --release $(CARGO_EXTRA_ARGS) -- --nocapture $(BIN_EXTRA_ARGS) test-all-with-coverage: $(COVERAGE_ENV) CARGO_EXTRA_ARGS="$(CARGO_EXTRA_ARGS)" BIN_EXTRA_ARGS="$(BIN_EXTRA_ARGS)" $(MAKE) test-all -# Test the project with non-heavy tests and using nextest test runner -nextest: + +nextest: ## Test the project with non-heavy tests and using nextest test runner cargo nextest run --all-features --release --profile ci -E "not test(heavy)" $(BIN_EXTRA_ARGS) nextest-with-coverage: $(COVERAGE_ENV) BIN_EXTRA_ARGS="$(BIN_EXTRA_ARGS)" $(MAKE) nextest -# Test the project with heavy tests and using nextest test runner -nextest-heavy: + +nextest-heavy: ## Test the project with heavy tests and using nextest test runner cargo nextest run --all-features --release --profile ci -E "test(heavy)" $(BIN_EXTRA_ARGS) nextest-heavy-with-coverage: $(COVERAGE_ENV) BIN_EXTRA_ARGS="$(BIN_EXTRA_ARGS)" $(MAKE) nextest-heavy -# Test the project with all tests and using nextest test runner -nextest-all: + +nextest-all: ## Test the project with all tests and using nextest test runner cargo nextest run --all-features --release --profile ci $(BIN_EXTRA_ARGS) 
nextest-all-with-coverage: $(COVERAGE_ENV) BIN_EXTRA_ARGS="$(BIN_EXTRA_ARGS)" $(MAKE) nextest-all - -# Format the code -format: + +format: ## Format the code cargo +nightly fmt -- --check - -# Lint the code -lint: + +lint: ## Lint the code cargo clippy --all-features --all-targets --tests $(CARGO_EXTRA_ARGS) -- -W clippy::all -D warnings generate-test-coverage-report: @@ -119,4 +118,17 @@ generate-test-coverage-report: @echo "The test coverage report is available at: ./target/coverage" @echo "" -.PHONY: all setup install-test-deps clean build release test-doc test-doc-with-coverage test test-with-coverage test-heavy test-heavy-with-coverage test-all test-all-with-coverage nextest nextest-with-coverage nextest-heavy nextest-heavy-with-coverage nextest-all nextest-all-with-coverage format lint generate-test-coverage-report +generate-doc: + @echo "" + @echo "Generating the documentation." + @echo "" + RUSTDOCFLAGS="-D warnings" cargo doc --all-features --no-deps + @echo "" + @echo "The documentation is available at: ./target/doc" + @echo "" + +help: ## Ask for help! + @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + + +.PHONY: all setup install-test-deps clean build release test-doc test-doc-with-coverage test test-with-coverage test-heavy test-heavy-with-coverage test-all test-all-with-coverage nextest nextest-with-coverage nextest-heavy nextest-heavy-with-coverage nextest-all nextest-all-with-coverage format lint generate-test-coverage-report generate-doc help diff --git a/README.md b/README.md index 9f17533e5d..68f0191f5f 100644 --- a/README.md +++ b/README.md @@ -80,11 +80,15 @@ You can visualize the documentation by opening the file `target/doc/index.html`. -- [CI](.github/workflows/ci.yml). +- [CI](.github/workflows/ci.yml). This workflow ensures that the entire project builds correctly, adheres to guidelines, and passes all necessary tests. -- [Nightly tests with the code coverage](.github/workflows/ci-nightly.yml). +- [Nightly tests with the code coverage](.github/workflows/ci-nightly.yml). This workflow runs all the tests on a schedule or on demand, and generates and attaches the code coverage report to the job's execution results. -- [Benchmarks](.github/workflows/benches.yml). +- [Benchmarks](.github/workflows/benches.yml). This workflow runs benchmarks when a pull request is labeled with "benchmark." It sets up the Rust and OCaml environments, installs necessary tools, and executes cargo criterion benchmarks on the kimchi crate. The benchmark results are then posted as a comment on the pull request for review. -- [Deploy Specifications & Docs to GitHub Pages](.github/workflows/gh-page.yml). +- [Deploy Specifications & Docs to GitHub Pages](.github/workflows/gh-page.yml). When CI passes on master, the documentation built from the Rust code will be available at this [link](https://o1-labs.github.io/proof-systems/rustdoc) and the book will be available at this [link](https://o1-labs.github.io/proof-systems). + +## Nix for Dependencies (WIP) + +If you have `nix` installed and, in particular, `flakes` enabled, you can install the dependencies for these projects using nix. Simply run `nix develop .` inside this directory to bring `rustup`, `opam`, and `go` (along with a few other tools) into scope. In the current iteration, you will have to manage the toolchains yourself using `rustup` and `opam`.
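+ +For instance, a minimal session could look like this (a sketch; it assumes a flakes-enabled `nix` and uses the stock `rustup`/`opam` commands): +``` +nix develop . +# inside the development shell, manage the toolchains yourself: +rustup show +opam init +```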
diff --git a/arrabiata/Cargo.toml b/arrabiata/Cargo.toml new file mode 100644 index 0000000000..599b627ec9 --- /dev/null +++ b/arrabiata/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "arrabiata" +version = "0.1.0" +repository = "https://github.com/o1-labs/proof-systems" +homepage = "https://o1-labs.github.io/proof-systems/" +documentation = "https://o1-labs.github.io/proof-systems/rustdoc/" +edition = "2021" +license = "Apache-2.0" + +[[bin]] +name = "arrabiata" +path = "src/main.rs" + +[dependencies] +ark-ec.workspace = true +ark-ff.workspace = true +ark-poly.workspace = true +clap.workspace = true +env_logger.workspace = true +groupmap.workspace = true +kimchi.workspace = true +itertools.workspace = true +log.workspace = true +mina-curves.workspace = true +mina-poseidon.workspace = true +mvpoly.workspace = true +num-bigint.workspace = true +num-integer.workspace = true +o1-utils.workspace = true +once_cell.workspace = true +poly-commitment.workspace = true +rand.workspace = true +rayon.workspace = true +serde.workspace = true +strum.workspace = true +strum_macros.workspace = true \ No newline at end of file diff --git a/arrabiata/README.md b/arrabiata/README.md new file mode 100644 index 0000000000..81b68bd077 --- /dev/null +++ b/arrabiata/README.md @@ -0,0 +1,93 @@ +## Arrabiata - a generic recursive zero-knowledge argument implementation based on folding schemes + +### Motivation + +This library provides an implementation of a generic recursive zero-knowledge +argument based on folding schemes (initially defined in +[Nova](https://eprint.iacr.org/2021/370)), over the +[pasta](https://github.com/zcash/pasta_curves) curves and using the IPA +polynomial commitment. + +The end goal of this repository is to implement a Nova-like prover, and have a zkApp +on the Mina blockchain verifying the recursive proof*. This way, any user can run +arbitrarily large computations on their machine, make a proof, and rely on the +SNARK workers to verify that the proof is correct and include it on-chain. + +The first iteration of the project will allow folding a polynomial-time function +`f` (which could, in the near future, be a zkApp) of degree 2**. No generic lookup +argument will be implemented in the first version, even though a "runtime" +lookup/permutation argument will be required +for cross-cell referencing. The implementation leverages different +constructions and ideas described in papers like +[ProtoGalaxy](https://eprint.iacr.org/2023/1106), +[ProtoStar](https://eprint.iacr.org/2023/620), etc. + +*This might require changes to the Mina blockchain, with a possible new hardfork. +It is not even certain this is possible right now. + +**This will change. We might go up to degree 6 or 7, as we're building the +different gadgets (EC addition, EC scalar multiplication, Poseidon). + +### Implementation details + +We provide a binary to run arbitrarily large computations. +The implementation of the circuits will follow the one used by the o1vm +interpreter. An interpreter will be implemented over an abstract environment. + +The environment used for the witness will contain all the information required +to make an IVC proof, i.e. the current witness in addition to the oracle, the +folding accumulator, etc. The implementation attempts to keep the memory +footprint as small as possible by optimising the use of the CPU cache and using +references as much as it can.
It prioritizes +allocations on the stack when possible, and tries to keep as much data as +required in the CPU cache for as long as possible to avoid indirection. + +The witness interpreter attempts to perform most of the algebraic operations +using integers, not field elements. The results are reduced into the field +only when necessary. + +While building the witness, the cross terms are also computed on the fly, to be +used in the next iteration. This way, the prover only pays the price of the +activated gate on each row. + +### Examples + +Different built-in examples are provided. For instance: +``` +cargo run --bin arrabiata --release -- square-root --n 10 --srs-size 16 +``` + +will run 10 full folding iterations of the polynomial-time function `f(X, Y) = +X^2 - Y`, generating random values for each witness, and produce an IVC proof at +the end. + +You can also activate logging, which includes benchmarking information, by +setting the environment variable `RUST_LOG=debug`. + +### Run tests + +``` +cargo nextest run --all-features --release --no-capture -p arrabiata +``` + +### Registry of zkApps + +A registry of zkApps is already preconfigured. +To write a zkApp, check TODO. + +The zkApp registry can be found in TODO + +### References + +- The name is not only used as a reference to Kimchi and Pickles, but also to + the mathematician [Aryabhata](https://en.wikipedia.org/wiki/Aryabhata). diff --git a/arrabiata/notes.md b/arrabiata/notes.md new file mode 100644 index 0000000000..a347f8109a --- /dev/null +++ b/arrabiata/notes.md @@ -0,0 +1,44 @@ +## Notes to be removed later + +We can go up to around 30 columns if we want to stay around 2^13 rows for the IVC. + +Currently, with 17 columns, we have to perform 20 + 3 (degree-3 folding) MSMs, which cost 256 rows +each. Therefore, we need 23 * 256 = 5888 rows, i.e. between 2^12 and 2^13 rows. + +This leaves only: +- SRS = 2^16: 2^16 - 2^13 = 57k rows for computations +- SRS = 2^17: 2^17 - 2^13 = 122k rows for computations + +Note that we also need lookups. Lookups require, per instance, `m`, `t`, +"accumulator φ" and some lookup partial sums. These must also be folded. + +If we use a "horizontal layout" (i.e. IVC next to APP), we could use the spare +rows to build some non-uniform circuits like SuperNova does by selecting an +instruction, and have more than one accumulator. + +### Cross terms computation + +Based on the idea of computing on the fly: we might need to keep, for each row, +the partial evaluations, and define a polynomial in the second homogeneous value +`u''`. + +See the PDF in the drive for the formula, but we can recursively populate a contiguous +array with the individual powers, and keep an index to get the appropriate power. +With this method, we avoid recomputing the powers of a certain value. Repeated +multiplications can also be avoided. +The contiguous array is passed as a parameter on the stack to the function. The +CPU cache should be big enough to handle it for low-degree folding schemes (like 5 +or 6). + +### Next steps + +- [x] Method to compute cross-terms (2502) +- [ ] Change EC scalar multiplication to use the permutation argument later. + -> 1 day +- [ ] Permutation argument -> 2-3 days + - We don't necessarily need it in folding because it is linear.
+- [ ] Coining challenges + in-circuit verification +- [ ] Compute cross-terms with the existing expressions, commit to them and + pass them as public input +- [ ] Compute challenge r diff --git a/arrabiata/src/column_env.rs b/arrabiata/src/column_env.rs new file mode 100644 index 0000000000..bd647b3257 --- /dev/null +++ b/arrabiata/src/column_env.rs @@ -0,0 +1 @@ +//! This module will be used by the prover to evaluate the polynomials at a given point. diff --git a/arrabiata/src/columns.rs b/arrabiata/src/columns.rs new file mode 100644 index 0000000000..3efc6e8616 --- /dev/null +++ b/arrabiata/src/columns.rs @@ -0,0 +1,135 @@ +use ark_ff::Field; +use kimchi::circuits::expr::{AlphaChallengeTerm, CacheId, ConstantExpr, Expr, FormattedOutput}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashMap, + fmt::{Display, Formatter, Result}, + ops::Index, +}; +use strum_macros::{EnumCount as EnumCountMacro, EnumIter}; + +/// This enum represents the different gadgets that can be used in the circuit. +/// The selectors are defined at setup time, can take only the values `0` or +/// `1` and are public. +// IMPROVEME: should we merge it with the Instruction enum? +// It might not be that obvious to do so, as the Instruction enum could be +// defining operations that are not "fixed" in the circuit, but rather +// depend on runtime values (e.g. in a zero-knowledge virtual machine). +#[derive(Debug, Clone, Copy, PartialEq, EnumCountMacro, EnumIter)] +pub enum Gadget { + App, + // Elliptic curve related gadgets + EllipticCurveAddition, + EllipticCurveScaling, + /// This gadget implements the Poseidon hash instance described in the + /// top-level documentation. This implementation uses the "next row" + /// to allow the computation of one additional round per row. In the current + /// setup, with [crate::NUMBER_OF_COLUMNS] columns, we can compute 5 full + /// rounds per row.
+ Poseidon, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Column { + Selector(Gadget), + PublicInput(usize), + X(usize), +} + +pub struct Challenges<F> { + /// Challenge used to aggregate the constraints + pub alpha: F, + + /// Both challenges used in the permutation argument + pub beta: F, + pub gamma: F, + + /// Challenge to homogenize the constraints + pub homogenous_challenge: F, + + /// Random coin used to aggregate witnesses while folding + pub r: F, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum ChallengeTerm { + /// Challenge used to aggregate the constraints + Alpha, + /// Both challenges used in the permutation argument + Beta, + Gamma, + /// Challenge to homogenize the constraints + HomogenousChallenge, + /// Random coin used to aggregate witnesses while folding + R, +} + +impl Display for ChallengeTerm { + fn fmt(&self, f: &mut Formatter) -> Result { + match self { + ChallengeTerm::Alpha => write!(f, "alpha"), + ChallengeTerm::Beta => write!(f, "beta"), + ChallengeTerm::Gamma => write!(f, "gamma"), + ChallengeTerm::HomogenousChallenge => write!(f, "u"), + ChallengeTerm::R => write!(f, "r"), + } + } +} +impl<F> Index<ChallengeTerm> for Challenges<F> { + type Output = F; + + fn index(&self, term: ChallengeTerm) -> &Self::Output { + match term { + ChallengeTerm::Alpha => &self.alpha, + ChallengeTerm::Beta => &self.beta, + ChallengeTerm::Gamma => &self.gamma, + ChallengeTerm::HomogenousChallenge => &self.homogenous_challenge, + ChallengeTerm::R => &self.r, + } + } +} + +impl<'a> AlphaChallengeTerm<'a> for ChallengeTerm { + const ALPHA: Self = Self::Alpha; +} + +pub type E<F> = Expr<ConstantExpr<F, ChallengeTerm>, Column>; + +// Code to allow for pretty printing of the expressions +impl FormattedOutput for Column { + fn latex(&self, _cache: &mut HashMap<CacheId, String>) -> String { + match self { + Column::Selector(sel) => match sel { + Gadget::App => "q_app".to_string(), + Gadget::EllipticCurveAddition => "q_ec_add".to_string(), + Gadget::EllipticCurveScaling => "q_ec_mul".to_string(), + Gadget::Poseidon => "q_pos".to_string(), + }, + Column::PublicInput(i) => format!("pi_{{{i}}}").to_string(), + Column::X(i) => format!("x_{{{i}}}").to_string(), + } + } + + fn text(&self, _cache: &mut HashMap<CacheId, String>) -> String { + match self { + Column::Selector(sel) => match sel { + Gadget::App => "q_app".to_string(), + Gadget::EllipticCurveAddition => "q_ec_add".to_string(), + Gadget::EllipticCurveScaling => "q_ec_mul".to_string(), + Gadget::Poseidon => "q_pos_next_row".to_string(), + }, + Column::PublicInput(i) => format!("pi[{i}]"), + Column::X(i) => format!("x[{i}]"), + } + } + + fn ocaml(&self, _cache: &mut HashMap<CacheId, String>) -> String { + // FIXME + unimplemented!("Not used at the moment") + } + + fn is_alpha(&self) -> bool { + // FIXME + unimplemented!("Not used at the moment") + } +} diff --git a/arrabiata/src/constraints.rs b/arrabiata/src/constraints.rs new file mode 100644 index 0000000000..d3a22655de --- /dev/null +++ b/arrabiata/src/constraints.rs @@ -0,0 +1,372 @@ +use super::{columns::Column, interpreter::InterpreterEnv}; +use crate::{ + columns::{Gadget, E}, + interpreter::{self, Instruction, Side}, + MAX_DEGREE, NUMBER_OF_COLUMNS, NUMBER_OF_PUBLIC_INPUTS, +}; +use ark_ff::{Field, PrimeField}; +use kimchi::circuits::{ + expr::{ConstantTerm::Literal, Expr, ExprInner, Operations, Variable}, + gate::CurrOrNext, +}; +use log::debug; +use num_bigint::BigInt; +use o1_utils::FieldHelpers; + +#[derive(Clone, Debug)] +pub struct Env<Fp: Field> { + pub poseidon_mds: Vec<Vec<Fp>>, + /// The parameter a is the coefficient of the elliptic curve in
affine + /// coordinates. + // FIXME: this is ugly. Let us use the curve as a parameter. Only being lazy for now. + pub a: BigInt, + pub idx_var: usize, + pub idx_var_next_row: usize, + pub idx_var_pi: usize, + pub constraints: Vec<E<Fp>>, + pub activated_gadget: Option<Gadget>, +} + +impl<Fp: PrimeField> Env<Fp> { + pub fn new(poseidon_mds: Vec<Vec<Fp>>, a: BigInt) -> Self { + // This check might not be useful + assert!(a < Fp::modulus_biguint().into(), "a is too large"); + Self { + poseidon_mds, + a, + idx_var: 0, + idx_var_next_row: 0, + idx_var_pi: 0, + constraints: Vec::new(), + activated_gadget: None, + } + } +} + +/// An environment to build constraints. +/// The constraint environment is mostly useful when we want to perform a Nova +/// proof. +/// The constraint environment must be instantiated only once, at the last step +/// of the computation. +impl<Fp: PrimeField> InterpreterEnv for Env<Fp> { + type Position = (Column, CurrOrNext); + + type Variable = E<Fp>; + + fn allocate(&mut self) -> Self::Position { + assert!(self.idx_var < NUMBER_OF_COLUMNS, "Maximum number of columns reached ({NUMBER_OF_COLUMNS}), increase the number of columns"); + let pos = Column::X(self.idx_var); + self.idx_var += 1; + (pos, CurrOrNext::Curr) + } + + fn allocate_next_row(&mut self) -> Self::Position { + assert!(self.idx_var_next_row < NUMBER_OF_COLUMNS, "Maximum number of columns reached ({NUMBER_OF_COLUMNS}), increase the number of columns"); + let pos = Column::X(self.idx_var_next_row); + self.idx_var_next_row += 1; + (pos, CurrOrNext::Next) + } + + fn read_position(&self, pos: Self::Position) -> Self::Variable { + let (col, row) = pos; + Expr::Atom(ExprInner::Cell(Variable { col, row })) + } + + fn allocate_public_input(&mut self) -> Self::Position { + assert!(self.idx_var_pi < NUMBER_OF_PUBLIC_INPUTS, "Maximum number of public inputs reached ({NUMBER_OF_PUBLIC_INPUTS}), increase the number of public inputs"); + let pos = Column::PublicInput(self.idx_var_pi); + self.idx_var_pi += 1; + (pos, CurrOrNext::Curr) + } + + fn constant(&self, value: BigInt) -> Self::Variable { + let v = value.to_biguint().unwrap(); + let v = Fp::from_biguint(&v).unwrap(); + let v_inner = Operations::from(Literal(v)); + Self::Variable::constant(v_inner) + } + + /// Return the corresponding expression regarding the selected public input + fn write_public_input(&mut self, pos: Self::Position, _v: BigInt) -> Self::Variable { + self.read_position(pos) + } + + /// Return the corresponding expression regarding the selected column + fn write_column(&mut self, pos: Self::Position, v: Self::Variable) -> Self::Variable { + let (col, row) = pos; + let res = Expr::Atom(ExprInner::Cell(Variable { col, row })); + self.assert_equal(res.clone(), v); + res + } + + fn activate_gadget(&mut self, gadget: Gadget) { + self.activated_gadget = Some(gadget); + } + + fn add_constraint(&mut self, constraint: Self::Variable) { + let degree = constraint.degree(1, 0); + debug!("Adding constraint of degree {degree}: {:}", constraint); + assert!(degree <= MAX_DEGREE, "degree is too high: {}.
The folding scheme currently used allows constraints up to degree {}", degree, MAX_DEGREE); + self.constraints.push(constraint); + } + + fn constrain_boolean(&mut self, x: Self::Variable) { + let one = self.one(); + let c = x.clone() * (x.clone() - one); + self.constraints.push(c) + } + fn assert_zero(&mut self, x: Self::Variable) { + self.add_constraint(x); + } + + fn assert_equal(&mut self, x: Self::Variable, y: Self::Variable) { + self.add_constraint(x - y); + } + + unsafe fn bitmask_be( + &mut self, + _x: &Self::Variable, + _highest_bit: u32, + _lowest_bit: u32, + pos: Self::Position, + ) -> Self::Variable { + self.read_position(pos) + } + + fn square(&mut self, pos: Self::Position, x: Self::Variable) -> Self::Variable { + let v = self.read_position(pos); + let x = x.square(); + self.add_constraint(x - v.clone()); + v + } + + // This is witness-only. We simply return the corresponding expression to + // use later in constraints + fn fetch_input(&mut self, pos: Self::Position) -> Self::Variable { + self.read_position(pos) + } + + fn reset(&mut self) { + self.idx_var = 0; + self.idx_var_next_row = 0; + self.idx_var_pi = 0; + self.constraints.clear(); + self.activated_gadget = None; + } + + fn coin_folding_combiner(&mut self, pos: Self::Position) -> Self::Variable { + self.read_position(pos) + } + + fn load_poseidon_state(&mut self, pos: Self::Position, _i: usize) -> Self::Variable { + self.read_position(pos) + } + + // Witness-only + unsafe fn save_poseidon_state(&mut self, _x: Self::Variable, _i: usize) {} + + fn get_poseidon_round_constant( + &mut self, + pos: Self::Position, + _round: usize, + _i: usize, + ) -> Self::Variable { + let (col, row) = pos; + match col { + Column::PublicInput(_) => (), + _ => panic!("Only public inputs can be used as round constants"), + }; + Expr::Atom(ExprInner::Cell(Variable { col, row })) + } + + fn get_poseidon_mds_matrix(&mut self, i: usize, j: usize) -> Self::Variable { + let v = self.poseidon_mds[i][j]; + let v_inner = Operations::from(Literal(v)); + Self::Variable::constant(v_inner) + } + + unsafe fn fetch_value_to_absorb( + &mut self, + pos: Self::Position, + _curr_round: usize, + ) -> Self::Variable { + self.read_position(pos) + } + + unsafe fn load_temporary_accumulators( + &mut self, + pos_x: Self::Position, + pos_y: Self::Position, + _side: Side, + ) -> (Self::Variable, Self::Variable) { + let x = self.read_position(pos_x); + let y = self.read_position(pos_y); + (x, y) + } + + // witness only + unsafe fn save_temporary_accumulators( + &mut self, + _x: Self::Variable, + _y: Self::Variable, + _side: Side, + ) { + } + + /// Inverse of a variable + /// + /// # Safety + /// + /// Zero is not allowed as an input. + unsafe fn inverse(&mut self, pos: Self::Position, x: Self::Variable) -> Self::Variable { + let v = self.read_position(pos); + let res = v.clone() * x.clone(); + self.assert_equal(res.clone(), self.one()); + v + } + + unsafe fn is_same_ec_point( + &mut self, + pos: Self::Position, + _x1: Self::Variable, + _y1: Self::Variable, + _x2: Self::Variable, + _y2: Self::Variable, + ) -> Self::Variable { + self.read_position(pos) + } + + fn zero(&self) -> Self::Variable { + self.constant(BigInt::from(0_usize)) + } + + fn one(&self) -> Self::Variable { + self.constant(BigInt::from(1_usize)) + } + + /// Double the elliptic curve point given by the affine coordinates + /// `(x1, y1)` and save the result in the registers `pos_x` and `pos_y`.
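+ /// + /// A sketch of what the gadget enforces (matching the inline comments in the + /// body below): one column is allocated for the slope λ, and the degree-2 + /// constraints λ * 2y1 = 3x1^2 + a, x3 = λ^2 - 2x1 and y3 = λ(x1 - x3) - y1 + /// are added.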
+    fn double_ec_point(
+        &mut self,
+        pos_x: Self::Position,
+        pos_y: Self::Position,
+        x1: Self::Variable,
+        y1: Self::Variable,
+    ) -> (Self::Variable, Self::Variable) {
+        let lambda = {
+            let pos = self.allocate();
+            self.read_position(pos)
+        };
+        let x3 = self.read_position(pos_x);
+        let y3 = self.read_position(pos_y);
+
+        // λ * 2y1 = 3 * x1^2 + a
+        let x1_square = x1.clone() * x1.clone();
+        let two_x1_square = x1_square.clone() + x1_square.clone();
+        let three_x1_square = two_x1_square.clone() + x1_square.clone();
+        let two_y1 = y1.clone() + y1.clone();
+        let res = lambda.clone() * two_y1 - (three_x1_square + self.constant(self.a.clone()));
+        self.assert_zero(res);
+        // x3 = λ^2 - 2 x1
+        self.assert_equal(
+            x3.clone(),
+            lambda.clone() * lambda.clone() - x1.clone() - x1.clone(),
+        );
+        // y3 = λ(x1 - x3) - y1
+        self.assert_equal(
+            y3.clone(),
+            lambda.clone() * (x1.clone() - x3.clone()) - y1.clone(),
+        );
+        (x3, y3.clone())
+    }
+
+    fn compute_lambda(
+        &mut self,
+        pos: Self::Position,
+        is_same_point: Self::Variable,
+        x1: Self::Variable,
+        y1: Self::Variable,
+        x2: Self::Variable,
+        y2: Self::Variable,
+    ) -> Self::Variable {
+        let lambda = self.read_position(pos);
+        let lhs = lambda.clone() * (x1.clone() - x2.clone()) - (y1.clone() - y2.clone());
+        let rhs = {
+            let x1_square = x1.clone() * x1.clone();
+            let two_x1_square = x1_square.clone() + x1_square.clone();
+            let three_x1_square = two_x1_square.clone() + x1_square.clone();
+            let two_y1 = y1.clone() + y1.clone();
+            lambda.clone() * two_y1 - (three_x1_square + self.constant(self.a.clone()))
+        };
+        let res = is_same_point.clone() * lhs + (self.one() - is_same_point.clone()) * rhs;
+        self.assert_zero(res);
+        lambda
+    }
+}
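+
+// A sketch, not part of the module: how the environment above could be used to
+// collect the IVC constraints. `poseidon_mds` and `a` are placeholders for the
+// Poseidon MDS matrix and the curve coefficient:
+//
+//     let env = Env::new(poseidon_mds, a);
+//     let constraints = env.get_all_constraints_for_ivc();
+//     // Every constraint stays within the degree bound of the folding scheme.
+//     assert!(constraints.iter().all(|c| c.degree(1, 0) <= MAX_DEGREE));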
+
+impl<Fp: PrimeField> Env<Fp> {
+    /// Get all the constraints for the IVC circuit, only.
+    ///
+    /// The following gadgets are used in the IVC circuit:
+    /// - [Instruction::Poseidon] to verify the challenges and the public
+    ///   IO
+    /// - [Instruction::EllipticCurveScaling] and
+    ///   [Instruction::EllipticCurveAddition] to accumulate the commitments
+    // FIXME: the IVC circuit might not be complete, yet. For instance, we might
+    // need to accumulate the challenges and add a row to verify the output of
+    // the computation of the challenges.
+    // FIXME: add a test checking that, whatever value is given as parameter to
+    // the gadget, the constraints are the same
+    pub fn get_all_constraints_for_ivc(&self) -> Vec<E<Fp>> {
+        // Copying the instance we got as parameter, and making it mutable to
+        // avoid modifying the original instance.
+        let mut env = self.clone();
+        // Resetting before running anything
+        env.reset();
+
+        let mut constraints = vec![];
+
+        // Poseidon constraints
+        // The constraints are the same for all values given as parameter,
+        // therefore we pick 0
+        interpreter::run_ivc(&mut env, Instruction::Poseidon(0));
+        constraints.extend(env.constraints.clone());
+        env.reset();
+
+        // EC scaling
+        // The constraints are the same whatever the values given as parameter,
+        // therefore we pick 0, 0
+        interpreter::run_ivc(&mut env, Instruction::EllipticCurveScaling(0, 0));
+        constraints.extend(env.constraints.clone());
+        env.reset();
+
+        // EC addition
+        // The constraints are the same whatever the value given as parameter,
+        // therefore we pick 0
+        interpreter::run_ivc(&mut env, Instruction::EllipticCurveAddition(0));
+        constraints.extend(env.constraints.clone());
+        env.reset();
+
+        constraints
+    }
+
+    /// Get all the constraints for the IVC circuit and the application.
+    // FIXME: the application should be given as an argument to handle Rust
+    // zkApps. It is only for the PoC.
+    // FIXME: the selectors are not added for now.
+    pub fn get_all_constraints(&self) -> Vec<E<Fp>> {
+        let mut constraints = self.get_all_constraints_for_ivc();
+
+        // Copying the instance we got as parameter, and making it mutable to
+        // avoid modifying the original instance.
+        let mut env = self.clone();
+        // Resetting before running anything
+        env.reset();
+
+        // Get the constraints for the application
+        interpreter::run_app(&mut env);
+        constraints.extend(env.constraints.clone());
+
+        constraints
+    }
+}
diff --git a/arrabiata/src/interpreter.rs b/arrabiata/src/interpreter.rs
new file mode 100644
index 0000000000..025b2ac0a2
--- /dev/null
+++ b/arrabiata/src/interpreter.rs
@@ -0,0 +1,935 @@
+//! This module contains the implementation of the IVC scheme in addition to
+//! running an arbitrary function that can use up to [crate::NUMBER_OF_COLUMNS]
+//! columns.
+//! At the moment, all constraints must be of maximum degree
+//! [crate::MAX_DEGREE], but it might change in the future.
+//!
+//! The implementation relies on a representation of the circuit as a 2D array
+//! of "data points" the interpreter can use.
+//!
+//! An interpreter defines what a "position" is in the circuit and allows
+//! performing operations using these positions.
+//! Some of these positions will be considered as public inputs and might be
+//! fixed at setup time while making a proof, while others will be considered
+//! as private inputs.
+//!
+//! On top of this abstraction, gadgets are implemented.
+//! For the Nova-like IVC schemes, we describe below the different gadgets and
+//! how they are implemented with this abstraction.
+//!
+//! **Table of contents**:
+//! - [Gadgets implemented](#gadgets-implemented)
+//!   - [Elliptic curve addition](#elliptic-curve-addition)
+//!     - [Gadget layout](#gadget-layout)
+//!   - [Hash - Poseidon](#hash---poseidon)
+//!     - [Gadget layout](#gadget-layout-1)
+//!   - [Elliptic curve scalar multiplication](#elliptic-curve-scalar-multiplication)
+//!     - [Gadget layout](#gadget-layout-2)
+//! - [Handle the combination of constraints](#handle-the-combination-of-constraints)
+//! - [Permutation argument](#permutation-argument)
+//! - [Fiat-Shamir challenges](#fiat-shamir-challenges)
+//! - [Folding](#folding)
+//!
+//! ## Gadgets implemented
+//!
+//! ### Elliptic curve addition
+//!
+//! The Nova augmented circuit requires performing elliptic curve operations,
+//! in particular additions and scalar multiplications.
+//!
+//! To reduce the number of operations, we consider the affine coordinates.
+//! As a reminder, here are the equations to compute the addition of two
+//! different points `P1 = (X1, Y1)` and `P2 = (X2, Y2)`. Let us define
+//! `P3 = (X3, Y3) = P1 + P2`.
+//!
+//! ```text
+//! - λ = (Y1 - Y2) / (X1 - X2)
+//! - X3 = λ^2 - X1 - X2
+//! - Y3 = λ (X1 - X3) - Y1
+//! ```
+//!
+//! Therefore, the addition of elliptic curve points can be computed using the
+//! following degree-2 constraints
+//!
+//! ```text
+//! - Constraint 1: λ (X1 - X2) - Y1 + Y2 = 0
+//! - Constraint 2: X3 + X1 + X2 - λ^2 = 0
+//! - Constraint 3: Y3 - λ (X1 - X3) + Y1 = 0
+//! ```
+//!
+//! If the points are the same, λ is computed as follows:
+//!
+//! ```text
+//! - λ = (3 X1^2 + a) / (2Y1)
+//! ```
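+//!
+//! As an illustration, the three constraints above could be expressed against
+//! the `InterpreterEnv` trait defined below; `constrain_ec_add` is an
+//! illustrative helper, not part of the trait:
+//!
+//! ```ignore
+//! fn constrain_ec_add<E: InterpreterEnv>(
+//!     env: &mut E,
+//!     lambda: E::Variable,
+//!     (x1, y1): (E::Variable, E::Variable),
+//!     (x2, y2): (E::Variable, E::Variable),
+//!     (x3, y3): (E::Variable, E::Variable),
+//! ) {
+//!     // Constraint 1: λ (X1 - X2) - Y1 + Y2 = 0
+//!     env.assert_zero(lambda.clone() * (x1.clone() - x2.clone()) - (y1.clone() - y2));
+//!     // Constraint 2: X3 + X1 + X2 - λ^2 = 0
+//!     env.assert_zero(x3.clone() + x1.clone() + x2 - lambda.clone() * lambda.clone());
+//!     // Constraint 3: Y3 - λ (X1 - X3) + Y1 = 0
+//!     env.assert_zero(y3 - (lambda * (x1 - x3) - y1));
+//! }
+//! ```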
+//!
+//! #### Gadget layout
+//!
+//! For given inputs (x1, y1) and (x2, y2), the layout will be as follows:
+//!
+//! ```text
+//! | C1 | C2 | C3 | C4 | C5 | C6 | C7 | C8 | C9 | C10 | C11 | C12 | C13 | C14 | C15 | C16 | C17 |
+//! | -- | -- | -- | -- | -- | -- | -- | -- | -- | --- | --- | --- | --- | --- | --- | --- | --- |
+//! | x1 | y1 | x2 | y2 | b0 | λ  | x3 | y3 |    |     |     |     |     |     |     |     |     |
+//! ```
+//!
+//! where `b0` is equal to `1` if the points are the same, and `0` otherwise.
+//!
+//! TBD/FIXME: support negation and the point at infinity.
+//!
+//! TBD/FIXME: the gadget layout might change when we implement the
+//! permutation argument. The values `(x1, y1)` can be public inputs.
+//! The values `(x2, y2)` can be fetched from the permutation argument, and must
+//! be the output of the elliptic curve scaling.
+//!
+//! The gadget therefore requires 7 columns.
+//!
+//! ### Hash - Poseidon
+//!
+//! Hashing is a crucial part of the IVC scheme. The hash function the
+//! interpreter uses for the moment is an instance of the Poseidon hash
+//! function with a fixed state size of [POSEIDON_STATE_SIZE]. Increasing the
+//! state size can be considered, as it would potentially optimize the
+//! number of rounds and allow hashing more data on one row. We leave this for
+//! future work.
+//!
+//! A direct optimisation would be to use
+//! [Poseidon2](https://eprint.iacr.org/2023/323) as its performance on CPU is
+//! better, for the same security level and the same cost in circuit. We leave
+//! this for future work.
+//!
+//! For a first version, we consider an instance of the Poseidon hash function
+//! that is suitable for curves whose field size is around 256 bits.
+//! A security analysis for these curves gives us a recommendation of 60 full
+//! rounds if we consider a 128-bit security level and a low-degree
+//! exponentiation of `5`, with only full rounds.
+//! In the near future, we will consider the partial rounds strategy to reduce
+//! the CPU cost. For a first version, we keep the full rounds strategy to keep
+//! the design simple.
+//!
+//! When applying the full/partial round strategy, an optimisation can be used,
+//! see [New Optimization techniques for PlonK's
+//! arithmetisation](https://eprint.iacr.org/2022/462). The techniques described
+//! in the paper can also be generalized to other constraints used in the
+//! interpreter, but we leave this for future work.
+//!
+//! #### Gadget layout
+//!
+//! We start with the assumption that 17 columns are available for the whole
+//! circuit, and we can support constraints up to degree 5.
+//! Therefore, we can compute 4 full rounds per row if we rely on the
+//! permutation argument, or 5 full rounds per row if we use the "next row".
+//!
+//! We provide two implementations of the Poseidon hash function. The first one
+//! does not use the "next row" and is limited to 4 full rounds per row. The
+//! second one uses the "next row" and can compute 5 full rounds per row.
+//! The second implementation is more efficient as it allows computing one
+//! additional round per row.
+//! For the second implementation, the permutation argument will only be
+//! activated on the first and last group of 5 rounds.
+//!
+//! The layout for the one not using the "next row" is as follows (4 full rounds):
+//!
+//! ```text
+//! | C1 | C2 | C3 | C4 | C5 | C6 | C7 | C8 | C9 | C10 | C11 | C12 | C13 | C14 | C15 |
+//! | -- | -- | -- | -- | -- | -- | -- | -- | -- | --- | --- | --- | --- | --- | --- |
+//! | x  | y  | z  | a1 | a2 | a3 | b1 | b2 | b3 | c1  | c2  | c3  | o1  | o2  | o3  |
+//! ```
+//!
+//! where (x, y, z) is the input of the current step, (o1, o2, o3) is the
+//! output, and the other values are intermediary values. And we have the
+//! following equalities:
+//!
+//! ```text
+//! (a1, a2, a3) = PoseidonRound(x, y, z)
+//! (b1, b2, b3) = PoseidonRound(a1, a2, a3)
+//! (c1, c2, c3) = PoseidonRound(b1, b2, b3)
+//! (o1, o2, o3) = PoseidonRound(c1, c2, c3)
+//! ```
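+//!
+//! As a sketch, one full round as used in the equalities above could look as
+//! follows. `poseidon_round` is illustrative and assumes, like the gadget, that
+//! the s-box is applied first, then the MDS multiplication, then the
+//! round-constant addition:
+//!
+//! ```ignore
+//! use ark_ff::Field;
+//!
+//! fn poseidon_round<F: Field>(mds: &[[F; 3]; 3], rcs: &[F; 3], state: [F; 3]) -> [F; 3] {
+//!     // S-box: raise every element of the state to the power 5
+//!     let sboxed: Vec<F> = state.iter().map(|x| x.pow([5u64])).collect();
+//!     // MDS multiplication, followed by the round-constant addition
+//!     std::array::from_fn(|i| {
+//!         (0..3).fold(F::zero(), |acc, j| acc + mds[i][j] * sboxed[j]) + rcs[i]
+//!     })
+//! }
+//! ```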
+//!
+//! The layout for the one using the "next row" is as follows (5 full rounds):
+//!
+//! ```text
+//! | C1 | C2 | C3 | C4 | C5 | C6 | C7 | C8 | C9 | C10 | C11 | C12 | C13 | C14 | C15 |
+//! | -- | -- | -- | -- | -- | -- | -- | -- | -- | --- | --- | --- | --- | --- | --- |
+//! | x  | y  | z  | a1 | a2 | a3 | b1 | b2 | b3 | c1  | c2  | c3  | d1  | d2  | d3  |
+//! | o1 | o2 | o3 |
+//! ```
+//!
+//! where (x, y, z) is the input of the current step, (o1, o2, o3) is the
+//! output, and the other values are intermediary values. And we have the
+//! following equalities:
+//!
+//! ```text
+//! (a1, a2, a3) = PoseidonRound(x, y, z)
+//! (b1, b2, b3) = PoseidonRound(a1, a2, a3)
+//! (c1, c2, c3) = PoseidonRound(b1, b2, b3)
+//! (d1, d2, d3) = PoseidonRound(c1, c2, c3)
+//! (o1, o2, o3) = PoseidonRound(d1, d2, d3)
+//! ```
+//!
+//! For both implementations, round constants are passed as public inputs. As a
+//! reminder, public inputs are simply additional columns known by the prover
+//! and verifier.
+//! Also, the elements to absorb are added to the initial state at the beginning
+//! of the Poseidon full hash call. The elements to absorb are supposed
+//! to be passed as public inputs.
+//!
+//! ### Elliptic curve scalar multiplication
+//!
+//! The Nova-based IVC schemes require performing scalar multiplications on
+//! elliptic curve points. The scalar multiplication is computed using the
+//! double-and-add algorithm.
+//!
+//! We will consider a basic implementation using the "next row". The
+//! accumulators will be saved on the "next row". The decomposition of the
+//! scalar will be done incrementally, one bit per row.
+//! The scalar used for the scalar multiplication will be fetched using the
+//! permutation argument (FIXME: to be implemented).
+//! More than one bit could be decomposed at the same time to reduce the number
+//! of rows; we leave this for future work.
+//!
+//! #### Gadget layout
+//!
+//! For a (x, y) point and a scalar, we apply the double-and-add algorithm, one
+//! step per row. Therefore, we need 255 rows to compute the scalar
+//! multiplication.
+//! For a given step `i`, we have the following values:
+//! - `tmp_x`, `tmp_y`: the temporary values used to keep the double.
+//! - `res_x`, `res_y`: the result of the scalar multiplication, i.e. the accumulator.
+//! - `b`: the i-th bit of the scalar.
+//! - `r_i` and `r_(i+1)`: scalars such that `r_i = b + 2 * r_(i+1)`.
+//! - `λ` and `λ'`: the slopes used for the doubling and the conditional
+//!   addition.
+//! - `o'_x` and `o'_y`: equal to `res_plus_tmp_x` and `res_plus_tmp_y` if
+//!   `b == 1`, otherwise equal to `o_x` and `o_y`.
+//!
+//! We have the following layout:
+//!
+//! ```text
+//! | C1   | C2   | C3            | C4            | C5      | C6 | C7             | C8             | C9 | C10 | C11 | C12 | C13 | C14 | C15 | C16 | C17 |
+//! | ---- | ---- | ------------- | ------------- | ------- | -- | -------------- | -------------- | -- | --- | --- | --- | --- | --- | --- | --- | --- |
+//! | o_x  | o_y  | double_tmp_x  | double_tmp_y  | r_i     | λ  | res_plus_tmp_x | res_plus_tmp_y | λ' | b   |
+//! | o'_x | o'_y | double_tmp'_x | double_tmp'_y | r_(i+1) |
+//! ```
+//!
+//! FIXME: an optimisation can be implemented using a "bucket" style algorithm,
+//! as described in [Efficient MSMs in Kimchi
+//! Circuits](https://github.com/o1-labs/rfcs/blob/main/0013-efficient-msms-for-non-native-pickles-verification.md).
+//! We leave this for future work.
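+//!
+//! As a toy model of the rows above, where the curve group is replaced by the
+//! integers under addition (so "doubling a point" is doubling an integer), one
+//! loop iteration corresponds to one row of the gadget:
+//!
+//! ```ignore
+//! fn double_and_add(scalar: u64, p: i64) -> i64 {
+//!     let mut res = 0; // in the gadget, this starts at the blinder
+//!     let mut tmp = p;
+//!     let mut r = scalar;
+//!     while r != 0 {
+//!         let b = r & 1; // the bit b, constrained to be boolean in-circuit
+//!         if b == 1 {
+//!             res += tmp; // res = res + tmp
+//!         }
+//!         tmp += tmp; // tmp = 2 * tmp
+//!         r >>= 1; // r_(i+1) = (r_i - b) / 2
+//!     }
+//!     res // equals scalar * p
+//! }
+//! ```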
+//!
+//! ## Handle the combination of constraints
+//!
+//! The prover will have to combine the constraints to generate the
+//! full circuit at the end. The constraints will be combined using a
+//! challenge (often called α) that will be generated in the verifier circuit by
+//! simulating the Fiat-Shamir transformation: the individual constraints `C_i`
+//! are combined into a single constraint `Σ α^i C_i`.
+//! The challenges will then be accumulated over time using the random coin used
+//! by the folding argument.
+//! The verifier circuit must be carefully implemented to ensure that all the
+//! messages that the prover would have sent before coining the random combiner
+//! for the constraints have been absorbed properly in the verifier circuit.
+//!
+//! Using this technique requires a folding scheme that handles degree
+//! `5 + 1` constraints, as the challenge will be considered as a variable.
+//! The reader can refer to the folding library available in this monorepo for
+//! more context.
+//!
+//! ## Permutation argument
+//!
+//! Communication between rows must be done using a permutation argument. The
+//! argument we use will be a generalisation of the one used in the [PlonK
+//! paper](https://eprint.iacr.org/2019/953).
+//!
+//! The construction of the permutations will be done using the methods prefixed
+//! `save` and `load`. The index of the current row and the index of the
+//! time the value has been written will be used to generate the permutation on
+//! the fly.
+//!
+//! The permutation argument described in the PlonK paper is a kind of "inverse
+//! lookup" protocol, like Plookup. The polynomials are defined as follows:
+//!
+//! ```text
+//!            Can be seen as T[f(X)] = X
+//!          --------
+//!          |      |
+//! f'(X) = f(X) + β X + γ
+//!                 |--- Can be seen as the evaluation point.
+//!                 |
+//!                 |
+//! g'(X) = g(X) + β σ(X) + γ
+//!          |      |
+//!          --------
+//!            Can be seen as T[g(X)] = σ(X)
+//! ```
+//!
+//! And from this, we build an accumulator, like for Plookup.
+//! The accumulator requires coining two challenges, β and γ, and it must be
+//! done after the commitments to the columns have been absorbed.
+//! The verifier at the next step will verify that the challenges have been
+//! correctly computed.
+//! In the implementation, the accumulator will be computed after the challenges
+//! and the commitments. Note that the accumulator must also be aggregated, and
+//! the aggregation must be performed by the verifier at the next step.
+//!
+//! The methods `save` and `load` will accept as arguments only a column that is
+//! included in the permutation argument. For instance, `save_poseidon_state`
+//! will only accept columns with index 3, 4 and 5, whereas
+//! `load_poseidon_state` will only accept columns with index 0, 1 and 2.
+//!
+//! The permutation values will be saved as public values, and will contain the
+//! index of the row. The permutation values will be encoded as 32-bit values
+//! (u32), as we can suppose a negligible probability that a user will use more
+//! than 2^32 rows.
+//!
+//! The permutation argument also generates constraints that will be
+//! homogenized with the gadget constraints.
+//!
+//! Note that not all rows require the permutation argument. Therefore, a
+//! selector will be added to activate/deactivate the permutation argument.
+//! When a method calls `save` or `load`, the selector will be activated. By
+//! default, the selector will be deactivated.
+//!
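+//! As a reminder, the grand-product accumulator from the PlonK paper that the
+//! two polynomials above feed into is built as follows, ω being the generator
+//! of the evaluation domain and z being constrained to wrap around to 1:
+//!
+//! ```text
+//! z_0 = 1
+//! z_(i+1) = z_i * (f(ω^i) + β ω^i + γ) / (g(ω^i) + β σ(ω^i) + γ)
+//! ```
+//!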
+//! TBD:
+//! - number of columns
+//! - accumulator column
+//! - folding of the permutation argument
+//!
+//! TBD/FIXME: do we use an additive permutation argument to increase the number
+//! of columns we can perform the permutation on?
+//!
+//! TBD/FIXME: We can have more than one permutation argument. For instance, we
+//! can have a permutation argument for the columns 0, 1, 2, 3 and one for the
+//! columns 4, 5, 6, 7. It can help to decrease the degree.
+//!
+//! ## Fiat-Shamir challenges
+//!
+//! The challenges sent by the verifier must also be simulated by the IVC
+//! circuit.
+//!
+//! For a step `i + 1`, the challenges of the step `i` must be computed by the
+//! verifier, and checked against the ones received as public inputs.
+//!
+//! TBD/FIXME: specify. Might require foreign field arithmetic.
+//!
+//! TBD/FIXME: do we need to aggregate them for the end?
+//!
+//! ## Folding
+//!
+//! Constraints must be homogenized for the folding scheme.
+//! Homogenizing a constraint means that we add a new variable (called "U" in
+//! Nova for instance) that will be used to homogenize the degree of the monomials
+//! forming the constraint.
+//! Next to this, additional information, like the cross-terms and the error
+//! terms, must be computed.
+//!
+//! This computation depends on the constraints, and in particular on the
+//! monomials describing the constraints.
+//! The computation of the cross-terms and the error terms happens after the
+//! witness has been built and the different arguments like the permutation or
+//! lookup have been done. Therefore, the interpreter must provide a method to
+//! compute them, and the constraints should be passed as an argument.
+//!
+//! When computing the cross-terms, we must compute the contribution of each
+//! monomial to it.
+//!
+//! The implementation works as follows:
+//! - Split the constraint into monomials.
+//! - For the monomials of degree `d`, compute the contribution when
+//!   homogenizing to degree `d'`.
+//! - Sum all the contributions.
+//!
+//! The library [mvpoly] can be used to compute the cross-terms and to
+//! homogenize the constraints. The constraints can be converted into a type
+//! implementing the trait [MVPoly](mvpoly::MVPoly) and the method
+//! [compute_cross_terms](mvpoly::MVPoly::compute_cross_terms) can be used from
+//! there.
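+//!
+//! As a small worked example of homogenization, take the degree-2 constraint
+//! `C(x, y) = x·y + y` and a homogenizing variable `u`: every monomial is
+//! padded with powers of `u` up to the maximum degree, giving
+//!
+//! ```text
+//! C_h(x, y, u) = x·y + y·u      with C_h(x, y, 1) = C(x, y)
+//! ```
+//!
+//! Evaluating `C_h` on a linear combination `w1 + r·w2` of two witnesses yields
+//! `C_h(w1)` and `C_h(w2)` at the extreme powers of `r`; the coefficients of
+//! the intermediate powers of `r` are exactly the cross-terms mentioned above.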
+
+use crate::{
+    columns::Gadget, MAXIMUM_FIELD_SIZE_IN_BITS, NUMBER_OF_COLUMNS, POSEIDON_ROUNDS_FULL,
+    POSEIDON_STATE_SIZE,
+};
+use ark_ff::{One, Zero};
+use log::debug;
+use num_bigint::BigInt;
+
+/// A list of instructions/gadgets implemented in the interpreter.
+/// The control flow can be managed by implementing a function
+/// `fetch_next_instruction` and `fetch_instruction` on a witness environment.
+/// See the [Witness environment](crate::witness::Env) for more details.
+///
+/// Mostly, the instructions will be used to build the IVC circuit, but they
+/// can be generalized.
+///
+/// When the circuit is predefined, the instructions can be accompanied by a
+/// public selector. When implementing a virtual machine, where instructions are
+/// unknown at compile time, other methods can be used. We leave this for future
+/// work.
+///
+/// For the moment, the type is not parametrized, on purpose, to keep it simple
+/// (KISS method). However, IO could be encoded in the type to describe a
+/// typed control flow. We leave this for future work.
+#[derive(Copy, Clone, Debug)]
+pub enum Instruction {
+    /// This gadget implements the Poseidon hash instance described in the
+    /// top-level documentation. Compared to the previous one (that might be
+    /// deprecated in the future), this implementation does use the "next row"
+    /// to allow the computation of one additional round per row. In the current
+    /// setup, with [NUMBER_OF_COLUMNS] columns, we can compute 5 full rounds
+    /// per row.
+    Poseidon(usize),
+    EllipticCurveScaling(usize, u64),
+    EllipticCurveAddition(usize),
+    // The NoOp will simply do nothing
+    NoOp,
+}
+
+/// Define the side of the temporary accumulator.
+/// When computing G1 + G2, the interpreter will load G1 and after that G2.
+/// This enum is used to decide which side is currently being fetched into the
+/// cells. In the near future, it can be replaced by an index.
+pub enum Side {
+    Left,
+    Right,
+}
+
+/// An abstract interpreter that provides some functionality on the circuit. The
+/// interpreter should be seen as a state machine with some built-in
+/// functionality whose state is a matrix, and whose transitions are described
+/// by polynomial functions.
+pub trait InterpreterEnv {
+    type Position: Clone + Copy;
+
+    /// The variable should be seen as a certain object that can be built by
+    /// multiplying and adding, i.e. the variable can be seen as a solution
+    /// to a polynomial.
+    /// When instantiating as expressions - "constraints" - it defines
+    /// multivariate polynomials.
+    type Variable: Clone
+        + std::ops::Add<Self::Variable, Output = Self::Variable>
+        + std::ops::Sub<Self::Variable, Output = Self::Variable>
+        + std::ops::Mul<Self::Variable, Output = Self::Variable>
+        + std::fmt::Debug
+        + Zero
+        + One;
+
+    /// Allocate a new variable in the circuit for the current row
+    fn allocate(&mut self) -> Self::Position;
+
+    /// Allocate a new variable in the circuit for the next row
+    fn allocate_next_row(&mut self) -> Self::Position;
+
+    /// Return the corresponding variable at the given position
+    fn read_position(&self, pos: Self::Position) -> Self::Variable;
+
+    fn allocate_public_input(&mut self) -> Self::Position;
+
+    /// Set the value of the variable at the given position for the current row
+    fn write_column(&mut self, col: Self::Position, v: Self::Variable) -> Self::Variable;
+
+    /// Write the corresponding public inputs.
+    // FIXME: This design might not be the best. Feel free to come up with a
+    // better solution. The PI should be static for all witnesses
+    fn write_public_input(&mut self, x: Self::Position, v: BigInt) -> Self::Variable;
+
+    /// Activate the gadget for the row.
+    fn activate_gadget(&mut self, gadget: Gadget);
+
+    /// Build the constant zero
+    fn zero(&self) -> Self::Variable;
+
+    /// Build the constant one
+    fn one(&self) -> Self::Variable;
+
+    fn constant(&self, v: BigInt) -> Self::Variable;
+
+    /// Assert that the variable is zero
+    fn assert_zero(&mut self, x: Self::Variable);
+
+    /// Assert that the two variables are equal
+    fn assert_equal(&mut self, x: Self::Variable, y: Self::Variable);
+
+    fn add_constraint(&mut self, x: Self::Variable);
+
+    fn constrain_boolean(&mut self, x: Self::Variable);
+
+    /// Compute the square of a field element
+    fn square(&mut self, res: Self::Position, x: Self::Variable) -> Self::Variable;
+
+    /// Flagged as unsafe as it does require an additional range check
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and that the returned value fits in `highest_bit - lowest_bit`
+    /// bits.
+    unsafe fn bitmask_be(
+        &mut self,
+        x: &Self::Variable,
+        highest_bit: u32,
+        lowest_bit: u32,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Fetch an input of the application
+    // Witness-only
+    fn fetch_input(&mut self, res: Self::Position) -> Self::Variable;
+
+    /// Reset the environment to build the next row
+    fn reset(&mut self);
+
+    /// Return the folding combiner
+    fn coin_folding_combiner(&mut self, pos: Self::Position) -> Self::Variable;
+
+    /// Compute x^5 of the given variable
+    fn compute_x5(&self, x: Self::Variable) -> Self::Variable {
+        let x_square = x.clone() * x.clone();
+        let x_cubed = x_square.clone() * x.clone();
+        x_cubed * x_square.clone()
+    }
+
+    // ---- Poseidon gadget -----
+    /// Load the state of the Poseidon hash function into the environment
+    fn load_poseidon_state(&mut self, pos: Self::Position, i: usize) -> Self::Variable;
+
+    /// Save the state of Poseidon into the environment
+    ///
+    /// # Safety
+    ///
+    /// It does not have any effect on the constraints
+    unsafe fn save_poseidon_state(&mut self, v: Self::Variable, i: usize);
+
+    fn get_poseidon_round_constant(
+        &mut self,
+        pos: Self::Position,
+        round: usize,
+        i: usize,
+    ) -> Self::Variable;
+
+    /// Return the requested MDS matrix coefficient
+    fn get_poseidon_mds_matrix(&mut self, i: usize, j: usize) -> Self::Variable;
+
+    /// Load the public value to absorb at the current step.
+    /// The position should be a public column.
+    ///
+    /// IMPROVEME: we could have a heterogeneous typed list in the environment,
+    /// and pop values call after call. However, we try to keep the
+    /// interpreter simple.
+    ///
+    /// # Safety
+    ///
+    /// No constraint is added. It should be used with caution.
+    unsafe fn fetch_value_to_absorb(
+        &mut self,
+        pos: Self::Position,
+        curr_round: usize,
+    ) -> Self::Variable;
+    // -------------------------
+
+    /// Check if the points given by (x1, y1) and (x2, y2) are equal.
+    ///
+    /// # Safety
+    ///
+    /// No constraint is added. It should be used with caution.
+    unsafe fn is_same_ec_point(
+        &mut self,
+        pos: Self::Position,
+        x1: Self::Variable,
+        y1: Self::Variable,
+        x2: Self::Variable,
+        y2: Self::Variable,
+    ) -> Self::Variable;
+
+    /// Inverse of a variable
+    ///
+    /// # Safety
+    ///
+    /// Zero is not allowed as an input.
+    /// Witness only
+    // IMPROVEME: when computing the witness, we could have an additional column
+    // that would stand for the inverse, and compute all inverses at the end
+    // using a batch inversion.
+    unsafe fn inverse(&mut self, pos: Self::Position, x: Self::Variable) -> Self::Variable;
+
+    /// Compute the coefficient λ used in the elliptic curve addition.
+    /// If the two points are the same, λ is computed as follows:
+    /// - λ = (3 X1^2 + a) / (2Y1)
+    ///
+    /// Otherwise, λ is computed as follows:
+    /// - λ = (Y1 - Y2) / (X1 - X2)
+    fn compute_lambda(
+        &mut self,
+        pos: Self::Position,
+        is_same_point: Self::Variable,
+        x1: Self::Variable,
+        y1: Self::Variable,
+        x2: Self::Variable,
+        y2: Self::Variable,
+    ) -> Self::Variable;
+
+    /// Double the elliptic curve point given by the affine coordinates
+    /// `(x1, y1)` and save the result in the registers `pos_x` and `pos_y`.
+    fn double_ec_point(
+        &mut self,
+        pos_x: Self::Position,
+        pos_y: Self::Position,
+        x1: Self::Variable,
+        y1: Self::Variable,
+    ) -> (Self::Variable, Self::Variable);
+
+    /// Load the affine coordinates of the elliptic curve point currently saved
+    /// in the temporary accumulators. Temporary accumulators could be seen as
+    /// a CPU cache, an intermediate storage between the RAM (random access
+    /// memory) and the CPU registers (memory cells that are constrained).
+    ///
+    /// For now, it can only be used to load affine coordinates of elliptic
+    /// curve points given in the short Weierstrass form.
+    ///
+    /// Temporary accumulators could also be seen as return values of a function.
+    ///
+    /// # Safety
+    ///
+    /// No constraints are enforced. It is also not enforced that the
+    /// accumulators have been cleaned between two different gadgets.
+    unsafe fn load_temporary_accumulators(
+        &mut self,
+        pos_x: Self::Position,
+        pos_y: Self::Position,
+        side: Side,
+    ) -> (Self::Variable, Self::Variable);
+
+    /// Save temporary accumulators into the environment
+    ///
+    /// # Safety
+    ///
+    /// It does not have any effect on the constraints.
+    unsafe fn save_temporary_accumulators(
+        &mut self,
+        _v1: Self::Variable,
+        _v2: Self::Variable,
+        _side: Side,
+    );
+}
+
+/// Run the application
+pub fn run_app<E: InterpreterEnv>(env: &mut E) {
+    let x1 = {
+        let pos = env.allocate();
+        env.fetch_input(pos)
+    };
+    let _x1_square = {
+        let res = env.allocate();
+        env.square(res, x1.clone())
+    };
+}
+
+/// Run an iteration of the IVC scheme
+///
+/// It consists of the following steps:
+/// 1. Compute the hash of the public input.
+/// 2. Compute the elliptic curve addition.
+/// 3. Run the polynomial-time function.
+/// 4. Compute the hash of the output.
+///
+/// The environment is updated over time.
+/// When the environment is the one described in the [Witness
+/// environment](crate::witness::Env), the structure will be updated
+/// with the new accumulator, the new public input, etc. The public output will
+/// be in the structure also. The user can simply rerun the function for the
+/// next iteration.
+/// A row must be created to generate a challenge to combine the constraints
+/// later. The challenge will also be accumulated over time.
+///
+/// FIXME: the resulting constraints do not include the selectors, yet. The
+/// resulting constraints must be multiplied by the corresponding selectors.
+pub fn run_ivc<E: InterpreterEnv>(env: &mut E, instr: Instruction) {
+    match instr {
+        Instruction::EllipticCurveScaling(i_comm, processing_bit) => {
+            assert!(processing_bit < MAXIMUM_FIELD_SIZE_IN_BITS, "Invalid bit index. The fields are at most {MAXIMUM_FIELD_SIZE_IN_BITS} bits, therefore we cannot process bit {processing_bit}");
+            assert!(i_comm < NUMBER_OF_COLUMNS, "Invalid index. We do only support the scaling of the commitments to the columns, for now. We must additionally support the scaling of cross-terms and error terms");
+            debug!("Processing scaling of commitment {i_comm}, bit {processing_bit}");
+            env.activate_gadget(Gadget::EllipticCurveScaling);
+            // When processing the first bit, we must load the scalar, and it
+            // comes from a previous computation.
+            // The first two columns are supposed to be used for the output.
+            // It will be used to write the result of the scalar multiplication
+            // in the next row.
+            let res_col_x = env.allocate();
+            let res_col_y = env.allocate();
+            let tmp_col_x = env.allocate();
+            let tmp_col_y = env.allocate();
+            let scalar_col = env.allocate();
+            let next_row_res_col_x = env.allocate_next_row();
+            let next_row_res_col_y = env.allocate_next_row();
+            let next_row_tmp_col_x = env.allocate_next_row();
+            let next_row_tmp_col_y = env.allocate_next_row();
+            let next_row_scalar_col = env.allocate_next_row();
+
+            // For the bit, we do have two cases. If we are processing the first
+            // bit, we must load the bit from a previously computed value. In
+            // the case of folding, it will be the output of the Poseidon hash.
+            // Therefore we do need the permutation argument.
+            // If it is not the first bit, we suppose the previous value has
+            // been written in the previous step in the current row.
+            let scalar = if processing_bit == 0 {
+                env.coin_folding_combiner(scalar_col)
+            } else {
+                env.read_position(scalar_col)
+            };
+            // FIXME: we do add the blinder. We must subtract it at the end.
+            // Perform the following algorithm (double-and-add):
+            // res = O <-- blinder
+            // tmp = P
+            // for i in 0..256:
+            //     if r[i] == 1:
+            //         res = res + tmp
+            //     tmp = tmp + tmp
+            //
+            // - `processing_bit` is the i-th bit of the scalar, i.e. i in the
+            //   loop described above.
+            // - `i_comm` is used to fetch the commitment.
+            //
+            // If it is the first bit, we must load the commitment using the
+            // permutation argument.
+            // FIXME: the initial value of the result should be a non-zero
+            // point. However, it must be a public value otherwise the prover
+            // might lie about the initial value.
+            // We do have 15 public inputs on each row, we could use some of them.
+            let (res_x, res_y) = if processing_bit == 0 {
+                // Load the commitment
+                unsafe { env.load_temporary_accumulators(res_col_x, res_col_y, Side::Right) }
+            } else {
+                // Otherwise, the previous step has written the previous result
+                // in its "next row", i.e. this row.
+                let res_x = { env.read_position(res_col_x) };
+                let res_y = { env.read_position(res_col_y) };
+                (res_x, res_y)
+            };
+            // Same for the accumulated temporary value
+            let (tmp_x, tmp_y) = if processing_bit == 0 {
+                unsafe { env.load_temporary_accumulators(tmp_col_x, tmp_col_y, Side::Left) }
+            } else {
+                (env.read_position(tmp_col_x), env.read_position(tmp_col_y))
+            };
+            // Conditional addition:
+            // if bit == 1, then res = tmp + res
+            // else res = res
+            // First we compute tmp + res
+            // FIXME: we do suppose that res != tmp -> no doubling and no check
+            // if they are the same
+            // IMPROVEME: reuse elliptic curve addition
+            let (res_plus_tmp_x, res_plus_tmp_y) = {
+                let lambda = {
+                    let pos = env.allocate();
+                    env.compute_lambda(
+                        pos,
+                        env.zero(),
+                        tmp_x.clone(),
+                        tmp_y.clone(),
+                        res_x.clone(),
+                        res_y.clone(),
+                    )
+                };
+                // x3 = λ^2 - x1 - x2
+                let x3 = {
+                    let pos = env.allocate();
+                    let lambda_square = lambda.clone() * lambda.clone();
+                    let res = lambda_square.clone() - tmp_x.clone() - res_x.clone();
+                    env.write_column(pos, res)
+                };
+                // y3 = λ (x1 - x3) - y1
+                let y3 = {
+                    let pos = env.allocate();
+                    let x1_minus_x3 = tmp_x.clone() - x3.clone();
+                    let res = lambda.clone() * x1_minus_x3.clone() - tmp_y.clone();
+                    env.write_column(pos, res)
+                };
+                (x3, y3)
+            };
+            // tmp = tmp + tmp
+            // Compute the double of the temporary value
+            // The slope is saved in a column created in the call to
+            // `double_ec_point`
+            // We ignore the result as it will be used at the next step only.
+            let (_double_tmp_x, _double_tmp_y) = {
+                env.double_ec_point(
+                    next_row_tmp_col_x,
+                    next_row_tmp_col_y,
+                    tmp_x.clone(),
+                    tmp_y.clone(),
+                )
+            };
+            let bit = {
+                let pos = env.allocate();
+                unsafe { env.bitmask_be(&scalar, 1, 0, pos) }
+            };
+            // Checking it is a boolean -> degree 2
+            env.constrain_boolean(bit.clone());
+            let next_scalar = {
+                unsafe {
+                    env.bitmask_be(
+                        &scalar,
+                        MAXIMUM_FIELD_SIZE_IN_BITS.try_into().unwrap(),
+                        1,
+                        next_row_scalar_col,
+                    )
+                }
+            };
+            // Degree 1
+            env.assert_equal(
+                scalar.clone(),
+                bit.clone() + env.constant(BigInt::from(2)) * next_scalar.clone(),
+            );
+            let _x3 = {
+                let res = bit.clone() * res_plus_tmp_x.clone()
+                    + (env.one() - bit.clone()) * res_x.clone();
+                env.write_column(next_row_res_col_x, res)
+            };
+            let _y3 = {
+                let res = bit.clone() * res_plus_tmp_y.clone()
+                    + (env.one() - bit.clone()) * res_y.clone();
+                env.write_column(next_row_res_col_y, res)
+            };
+        }
+        Instruction::EllipticCurveAddition(i_comm) => {
+            env.activate_gadget(Gadget::EllipticCurveAddition);
+            assert!(i_comm < NUMBER_OF_COLUMNS, "Invalid index. We do only support the addition of the commitments to the columns, for now. We must additionally support the scaling of cross-terms and error terms");
+            let (x1, y1) = {
+                let x1 = env.allocate();
+                let y1 = env.allocate();
+                unsafe { env.load_temporary_accumulators(x1, y1, Side::Left) }
+            };
+            let (x2, y2) = {
+                let x2 = env.allocate();
+                let y2 = env.allocate();
+                unsafe { env.load_temporary_accumulators(x2, y2, Side::Right) }
+            };
+            let is_same_point = {
+                let pos = env.allocate();
+                unsafe { env.is_same_ec_point(pos, x1.clone(), y1.clone(), x2.clone(), y2.clone()) }
+            };
+            let lambda = {
+                let pos = env.allocate();
+                env.compute_lambda(
+                    pos,
+                    is_same_point,
+                    x1.clone(),
+                    y1.clone(),
+                    x2.clone(),
+                    y2.clone(),
+                )
+            };
+            // x3 = λ^2 - x1 - x2
+            let x3 = {
+                let pos = env.allocate();
+                let lambda_square = lambda.clone() * lambda.clone();
+                let res = lambda_square.clone() - x1.clone() - x2.clone();
+                env.write_column(pos, res)
+            };
+            // y3 = λ (x1 - x3) - y1
+            {
+                let pos = env.allocate();
+                let x1_minus_x3 = x1.clone() - x3.clone();
+                let res = lambda.clone() * x1_minus_x3.clone() - y1.clone();
+                env.write_column(pos, res)
+            };
+        }
+        Instruction::Poseidon(curr_round) => {
+            env.activate_gadget(Gadget::Poseidon);
+            debug!("Executing instruction Poseidon({curr_round})");
+            if curr_round < POSEIDON_ROUNDS_FULL {
+                // Values to be absorbed are 0 when the round is not zero,
+                // i.e. when we are processing the rounds.
+                let values_to_absorb: Vec<E::Variable> = (0..POSEIDON_STATE_SIZE - 1)
+                    .map(|_i| {
+                        let pos = env.allocate_public_input();
+                        // fetch_value_to_absorb is supposed to return 0 if curr_round != 0.
+                        unsafe { env.fetch_value_to_absorb(pos, curr_round) }
+                    })
+                    .collect();
+                let round_input_positions: Vec<E::Position> =
+                    (0..POSEIDON_STATE_SIZE).map(|_i| env.allocate()).collect();
+                let round_output_positions: Vec<E::Position> = (0..POSEIDON_STATE_SIZE)
+                    .map(|_i| env.allocate_next_row())
+                    .collect();
+                // If we are at the first round, we load the state from the
+                // environment.
+                // The permutation argument is used to load the state, as the
+                // current call to Poseidon might be a succession of Poseidon
+                // calls, like when we need to hash the public inputs, and the
+                // state might come from a previous place in the execution trace.
+                let state: Vec<E::Variable> = if curr_round == 0 {
+                    round_input_positions
+                        .iter()
+                        .enumerate()
+                        .map(|(i, pos)| {
+                            let res = env.load_poseidon_state(*pos, i);
+                            // Absorb value. The capacity is POSEIDON_STATE_SIZE - 1
+                            if i < POSEIDON_STATE_SIZE - 1 {
+                                res + values_to_absorb[i].clone()
+                            } else {
+                                res
+                            }
+                        })
+                        .collect()
+                } else {
+                    // Otherwise, as we do use the "next row" trick, the current
+                    // state has been loaded in the "next_row" state during the
+                    // previous call, and we can simply load it. No permutation
+                    // argument needed.
+                    round_input_positions
+                        .iter()
+                        .map(|pos| env.read_position(*pos))
+                        .collect()
+                };
+
+                // 5 is the number of rounds we treat per row
+                (0..5).fold(state, |state, idx_round| {
+                    let state: Vec<E::Variable> =
+                        state.iter().map(|x| env.compute_x5(x.clone())).collect();
+
+                    let round = curr_round + idx_round;
+
+                    let rcs: Vec<E::Variable> = (0..POSEIDON_STATE_SIZE)
+                        .map(|i| {
+                            let pos = env.allocate_public_input();
+                            env.get_poseidon_round_constant(pos, round, i)
+                        })
+                        .collect();
+
+                    let state: Vec<E::Variable> = rcs
+                        .iter()
+                        .enumerate()
+                        .map(|(i, rc)| {
+                            let acc: E::Variable =
+                                state.iter().enumerate().fold(env.zero(), |acc, (j, x)| {
+                                    acc + env.get_poseidon_mds_matrix(i, j) * x.clone()
+                                });
+                            // The last iteration is written on the next row.
+                            if idx_round == 4 {
+                                env.write_column(round_output_positions[i], acc + rc.clone())
+                            } else {
+                                // Otherwise, we simply allocate a new position
+                                // in the circuit.
+                                let pos = env.allocate();
+                                env.write_column(pos, acc + rc.clone())
+                            }
+                        })
+                        .collect();
+                    // If we are at the last round, we save the state in the
+                    // environment.
+                    // FIXME/IMPROVEME: we might want to execute several Poseidon
+                    // full hashes sequentially, and only then save one row. For
+                    // now, we will save the state at the end of the last round
+                    // and reload it at the beginning of the next Poseidon full
+                    // hash.
+                    if round == POSEIDON_ROUNDS_FULL - 1 {
+                        state.iter().enumerate().for_each(|(i, x)| {
+                            unsafe { env.save_poseidon_state(x.clone(), i) };
+                        });
+                        env.reset();
+                    };
+                    state
+                });
+            } else {
+                panic!("Invalid index: it is supposed to be less than {POSEIDON_ROUNDS_FULL}");
+            }
+        }
+        Instruction::NoOp => {}
+    }
+
+    // Compute the hash of the public input
+    // FIXME: add the verification key. We should have a hash of it.
+}
diff --git a/arrabiata/src/lib.rs b/arrabiata/src/lib.rs
new file mode 100644
index 0000000000..e43a1b2fa0
--- /dev/null
+++ b/arrabiata/src/lib.rs
@@ -0,0 +1,67 @@
+use strum::EnumCount as _;
+
+pub mod column_env;
+pub mod columns;
+pub mod constraints;
+pub mod interpreter;
+pub mod logup;
+pub mod poseidon_3_60_0_5_5_fp;
+pub mod poseidon_3_60_0_5_5_fq;
+pub mod proof;
+pub mod prover;
+pub mod verifier;
+pub mod witness;
+
+/// The maximum degree of the polynomial that can be represented by the
+/// polynomial-time function the library supports.
+pub const MAX_DEGREE: u64 = 5;
+
+/// The minimum SRS size required to use Nova, in base 2.
+/// We require at least 2^16 to perform 16-bit range checks.
+pub const MIN_SRS_LOG2_SIZE: usize = 16;
+
+/// The number of rows the IVC circuit requires.
+// FIXME: that might change. We use a vertical layout for now.
+pub const IVC_CIRCUIT_SIZE: usize = 1 << 13;
+
+/// The maximum number of columns that can be used in the circuit.
+pub const NUMBER_OF_COLUMNS: usize = 15;
+
+/// The maximum number of public inputs the circuit can use per row.
+/// We do have 15 + 2 for now, as we want to compute 5 rounds of Poseidon per
+/// row using the gadget [crate::columns::Gadget::Poseidon]. In addition to
+/// the 15 public inputs required for the round constants, we add 2 more for
+/// the values to absorb.
+pub const NUMBER_OF_PUBLIC_INPUTS: usize = 15 + 2;
+
+/// The low-exponentiation value used by the Poseidon hash function for the
+/// substitution box.
+///
+/// The value is used to perform initialisation checks with the fields.
+pub const POSEIDON_ALPHA: u64 = 5;
+
+/// The number of full rounds in the Poseidon hash function.
+pub const POSEIDON_ROUNDS_FULL: usize = 60;
+
+/// The number of elements in the state of the Poseidon hash function.
+pub const POSEIDON_STATE_SIZE: usize = 3;
+
+/// The maximum number of bits the fields can have.
+/// It is critical, as we have some assumptions for the gadgets describing the
+/// IVC.
+pub const MAXIMUM_FIELD_SIZE_IN_BITS: u64 = 255;
+
+/// Define the number of values we must absorb when computing the hash of the
+/// public IO.
+///
+/// FIXME:
+/// For now, it is the number of columns, as we are only absorbing the
+/// accumulators, which consist of 2 native field elements. However, it doesn't
+/// make the protocol sound. In addition to that, we must absorb the index and
+/// the application inputs/outputs.
+/// It is left for the future as, at this time, we're still sketching the IVC
+/// circuit.
+pub const NUMBER_OF_VALUES_TO_ABSORB_PUBLIC_IO: usize = NUMBER_OF_COLUMNS * 2;
+
+/// The number of selectors used in the circuit.
+pub const NUMBER_OF_SELECTORS: usize = columns::Gadget::COUNT;
diff --git a/arrabiata/src/logup.rs b/arrabiata/src/logup.rs
new file mode 100644
index 0000000000..2b8e3d6f05
--- /dev/null
+++ b/arrabiata/src/logup.rs
@@ -0,0 +1 @@
+//! This file will implement a logup argument to allow users to perform lookups in their circuits.
diff --git a/arrabiata/src/main.rs b/arrabiata/src/main.rs
new file mode 100644
index 0000000000..5d68a9ff6b
--- /dev/null
+++ b/arrabiata/src/main.rs
@@ -0,0 +1,140 @@
+use arrabiata::{
+    interpreter::{self, InterpreterEnv},
+    witness::Env,
+    IVC_CIRCUIT_SIZE, MIN_SRS_LOG2_SIZE, POSEIDON_STATE_SIZE,
+};
+use log::{debug, info};
+use mina_curves::pasta::{Fp, Fq, Pallas, Vesta};
+use num_bigint::BigInt;
+use std::time::Instant;
+
+pub fn main() {
+    // See https://github.com/rust-lang/log
+    env_logger::init();
+
+    let arg_n =
+        clap::arg!(--"n" <n> "Number of iterations").value_parser(clap::value_parser!(u64));
+
+    let arg_srs_size = clap::arg!(--"srs-size" <size> "Size of the SRS in base 2")
+        .value_parser(clap::value_parser!(usize));
+
+    let cmd = clap::Command::new("cargo")
+        .bin_name("cargo")
+        .subcommand_required(true)
+        .subcommand(
+            clap::Command::new("square-root")
+                .arg(arg_n)
+                .arg(arg_srs_size)
+                .arg_required_else_help(true),
+        );
+    let matches = cmd.get_matches();
+    let matches = match matches.subcommand() {
+        Some(("square-root", matches)) => matches,
+        _ => unreachable!("clap should ensure we don't get here"),
+    };
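+
+    // For reference, a hypothetical invocation of this subcommand (package and
+    // flag names as defined above; adjust to the workspace layout):
+    //
+    //     cargo run --release -p arrabiata -- square-root --n 10 --srs-size 16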
+    let n_iteration = matches.get_one::<u64>("n").unwrap();
+    let srs_log2_size = matches
+        .get_one::<usize>("srs-size")
+        .unwrap_or(&MIN_SRS_LOG2_SIZE);
+
+    assert!(
+        *srs_log2_size >= MIN_SRS_LOG2_SIZE,
+        "SRS size must be at least 2^{MIN_SRS_LOG2_SIZE} to support IVC"
+    );
+
+    info!("Instantiating environment to execute square-root {n_iteration} times with SRS of size 2^{srs_log2_size}");
+
+    let domain_size = 1 << srs_log2_size;
+
+    // FIXME: set up the initial sponge state correctly
+    let sponge_e1: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| BigInt::from(42u64));
+    // FIXME: make a setup phase to build the selectors
+    let mut env = Env::<Fp, Fq, Vesta, Pallas>::new(
+        *srs_log2_size,
+        BigInt::from(1u64),
+        sponge_e1.clone(),
+        sponge_e1.clone(),
+    );
+
+    let n_iteration_per_fold = domain_size - IVC_CIRCUIT_SIZE;
+
+    while env.current_iteration < *n_iteration {
+        let start_iteration = Instant::now();
+
+        info!("Run iteration: {}/{}", env.current_iteration, n_iteration);
+
+        // Build the application circuit
+        info!("Running N iterations of the application circuit");
+        for _i in 0..n_iteration_per_fold {
+            interpreter::run_app(&mut env);
+            env.reset();
+        }
+
+        info!("Building the IVC circuit");
+        // Build the IVC circuit
+        for _i in 0..IVC_CIRCUIT_SIZE {
+            let instr = env.fetch_instruction();
+            interpreter::run_ivc(&mut env, instr);
+            env.current_instruction = env.fetch_next_instruction();
+            env.reset();
+        }
+
+        debug!(
+            "Witness for iteration {i} computed in {elapsed} μs",
+            i = env.current_iteration,
+            elapsed = start_iteration.elapsed().as_micros()
+        );
+
+        // FIXME:
+        // Update the current instance with the previous "next" commitments
+        // (i.e. env.next_commitments).
+        // Update the next instance with the current commitments.
+        // FIXME: Check twice the updated commitments
+        env.compute_and_update_previous_commitments();
+
+        // FIXME:
+        // Absorb all commitments in the sponge.
+
+        // FIXME:
+        // Coin challenges β and γ for the permutation argument
+
+        // FIXME:
+        // Compute the accumulator for the permutation argument
+
+        // FIXME:
+        // Compute the cross-terms
+
+        // FIXME:
+        // Absorb the cross-terms
+
+        // FIXME:
+        // Coin challenge r to fold
+
+        // FIXME:
+        // Compute the accumulated witness
+
+        // FIXME:
+        // Compute the accumulation of the commitments to the witness columns
+
+        // FIXME:
+        // Compute the accumulation of the challenges
+
+        // FIXME:
+        // Compute the accumulation of the public inputs/selectors
+
+        // FIXME:
+        // Compute the accumulation of the blinders for the PCS
+
+        // FIXME:
+        // Compute the accumulated error
+
+        debug!(
+            "Iteration {i} fully proven in {elapsed} μs",
+            i = env.current_iteration,
+            elapsed = start_iteration.elapsed().as_micros()
+        );
+
+        env.reset_for_next_iteration();
+        env.current_iteration += 1;
+    }
+}
diff --git a/arrabiata/src/poseidon_3_60_0_5_5_fp.rs b/arrabiata/src/poseidon_3_60_0_5_5_fp.rs
new file mode 100644
index 0000000000..296302b1e1
--- /dev/null
+++ b/arrabiata/src/poseidon_3_60_0_5_5_fp.rs
@@ -0,0 +1,903 @@
+use mina_curves::pasta::Fp;
+use mina_poseidon::poseidon::ArithmeticSpongeParams;
+use once_cell::sync::Lazy;
+
+/* Generated by params.sage */
+
+use std::str::FromStr;
+
+fn params() -> ArithmeticSpongeParams<Fp> {
+    ArithmeticSpongeParams {
+        mds: vec![
+            vec![
+                Fp::from_str(
+                    "17388788707812278340106653511601894605475912579070132834621611278702208069948",
+                )
+                .unwrap(),
+                Fp::from_str(
+                    "17484584120788581687009266825661802806046812681695930884887739179865612965127",
+                )
+                .unwrap(),
+                Fp::from_str(
+                    "20222273432119919392686413983240325343169175238980369682703494201925192338899",
+                )
+                .unwrap(),
+            ],
+            vec![
+                Fp::from_str(
+                    "14793820945145615522977558374530426960607001062183458732338387735995846367929",
+                )
+                .unwrap(),
+                Fp::from_str(
+                    "24787707239600295030700826184349996599183995839090051611212698450493462645188",
+                )
+                .unwrap(),
+                Fp::from_str(
+                    "11771148817784101747527405739120145967567678238800527722783086805798857719651",
+                )
+                .unwrap(),
+            ],
+            vec![
+                Fp::from_str(
+                    "16508716375116042997058036529670318392110322034848187961172707099352209518970",
+                )
+                .unwrap(),
+                Fp::from_str(
+                    "15739116651395406608562600713426271836308703424634565895791987909508466043243",
+                )
+                .unwrap(),
+                Fp::from_str(
+                    "28623605191971738918745661811750738408184674297522199948691317659919248886550",
+                )
+                .unwrap(),
+            ],
+        ],
+        
round_constants: vec![ + vec![ + Fp::from_str( + "1903562405400753576949243515269615285116054102783580876258645693850149702661", + ) + .unwrap(), + Fp::from_str( + "12145461980598517942900142457759733122151931961026478643201020186654724736651", + ) + .unwrap(), + Fp::from_str( + "14420796192651222913546080713404272156939957726765855389858995038789820001804", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "21517370503810233496366468547482074588221727989678744286247330878704544316383", + ) + .unwrap(), + Fp::from_str( + "15225969289722373268482465496656569076419025051705782716566389303532217084020", + ) + .unwrap(), + Fp::from_str( + "8846000659801689151731760047733045550865254271798157362719412362663614394491", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "18935097269079670283337692690687684251434263151180779624497759807708536797425", + ) + .unwrap(), + Fp::from_str( + "11880535611248806296873834922322963716995235558587915595960665294503164390775", + ) + .unwrap(), + Fp::from_str( + "4966952290863524466749992685720848103165088080154965644116548363728505924898", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "5695601123099377308535230161063260357462084030710977307989288026703494805867", + ) + .unwrap(), + Fp::from_str( + "19358923941951882694765531912678559430342446235206141930368077372251272883095", + ) + .unwrap(), + Fp::from_str( + "23755683990903955379102881242285720263800457546329246011571293009879382037324", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "4614529323886042491467258281398251320543001635932884238869706056505001561622", + ) + .unwrap(), + Fp::from_str( + "18178687842435997804040286122580760454724076239014435481522060920560363791313", + ) + .unwrap(), + Fp::from_str( + "28707081058841709750305029757116229949752540572908874123074300632976262675560", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "26513451653917043248828666298783032866650283517004393405152662808871544852578", + ) + .unwrap(), + Fp::from_str( + "637856924048776362122688101401335178593824113300752902740192957438857482880", + ) + .unwrap(), + Fp::from_str( + "27061302593749942872613984145427359044280931099730672299593145016397559316854", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "17086541175064333656429923353437749638367197672441131421818336471076792274417", + ) + .unwrap(), + Fp::from_str( + "22378823921446696593158648054751403544895402815949246849930871622932273897262", + ) + .unwrap(), + Fp::from_str( + "7471606467058823404243598955192476956109765276588361711436869278415715685555", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "3928557571078601542767907007995998918860813953786788104700668054202369027920", + ) + .unwrap(), + Fp::from_str( + "1622978988680538412469464393021497104205344379488036513691059738730485529384", + ) + .unwrap(), + Fp::from_str( + "25348298722272181188944604553143133888199723971524975712225477416979636153809", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "20022907445604144480168608448347765614919496369114903389791611894930600639457", + ) + .unwrap(), + Fp::from_str( + "3481377079329113165496017180069328360206210042097594787790118481438395375457", + ) + .unwrap(), + Fp::from_str( + "21166609738526710164626635462319727216177558537946938784205031450271659242855", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "8455925034386545871718277080905034740388604003477849955924129416748089646116", + ) + .unwrap(), + Fp::from_str( + "26773441169691515057846574967585859574582109639578915243928785569726140606553", + ) + .unwrap(), + Fp::from_str( + 
"2913159669517815269065793193229148371685518752834788468451989637044410562890", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "13733055659842873083929035789685705127937709073081728761593630538541660386631", + ) + .unwrap(), + Fp::from_str( + "7094841152246788883784693079790542928345429024501849515338143672220402116718", + ) + .unwrap(), + Fp::from_str( + "23335458498693263537039637624930590026964628612174283981996523395533970491788", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "8323603019831115078884312497677083774960255504588814695508394138083910737764", + ) + .unwrap(), + Fp::from_str( + "10889037464529438880594665188490627457838577370305053321795090351671849480580", + ) + .unwrap(), + Fp::from_str( + "19741077166253609781036624985094621933484635487186919619811795992181317229193", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "8667395531522799146020035900615581833701541904060363209924245474697787860767", + ) + .unwrap(), + Fp::from_str( + "23480515713546465861699304270258279318295610326421313287700256087428145719346", + ) + .unwrap(), + Fp::from_str( + "27665947217093168220707439172668384413890100764912309430345254226832078720314", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "4639795971462048070973072285594933544611389371813424479637597322647960287645", + ) + .unwrap(), + Fp::from_str( + "18948400171613058934992563839850391566413075429250683972497145999530094050320", + ) + .unwrap(), + Fp::from_str( + "28279109981622486378191729952830852562479044052286399608321497750394545699625", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "2092038810157687026205598249212363846030222883362654794505008774922109661877", + ) + .unwrap(), + Fp::from_str( + "1794105537573874633098411080527826917947521521262259715828763077454147743552", + ) + .unwrap(), + Fp::from_str( + "28398830565043683871884394691026125204053425054457912681111511801642616450943", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "1473064465334694881231726328414560915880003310324779851459193409409641831424", + ) + .unwrap(), + Fp::from_str( + "13921517838484129412715278616815809758789091105166071016229326715660827585214", + ) + .unwrap(), + Fp::from_str( + "4389330634029172310373488087578956450630142427385288782494595113779386643702", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "27122272032648674069752689649732648302114648582621981437921265216529832360222", + ) + .unwrap(), + Fp::from_str( + "3288061546803754191885978583730966733886381991277253711650695644030239520860", + ) + .unwrap(), + Fp::from_str( + "21341688199305617719670692377730177476913828575425542077698540153025069415671", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "8218182644930435745666848816651777405577055711696681086133242403956009509291", + ) + .unwrap(), + Fp::from_str( + "12459324410302708241422860670946935589758584065477332528893821612381418188320", + ) + .unwrap(), + Fp::from_str( + "20727739665318638529185321321383253409564642610127652500624694963112552024769", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "23892261917575116777187905858809323351280923458863739792394583728503471961761", + ) + .unwrap(), + Fp::from_str( + "17683465006569325745513034545571120190469151492225906086935445748635807917924", + ) + .unwrap(), + Fp::from_str( + "21870378779613474342041808692670827012911667706358950547255098894249690046288", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "12506368118150834468930194117868897693463472763327986261447720462304069000570", + ) + .unwrap(), + Fp::from_str( + 
"16363070805461420960870543541744870758454731928475945586355394093468060430399", + ) + .unwrap(), + Fp::from_str( + "1943516193735799009739607039743724089512930723510180083535989879731754316273", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "7129998627870285953220701383518862449554741396274635204771810951731053784770", + ) + .unwrap(), + Fp::from_str( + "15664058584028571395004373641450989214161646789967183904969389720357082283250", + ) + .unwrap(), + Fp::from_str( + "20552066943271163327442451767659532197126635966059277605534778781484138633139", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "837384561060063770502319677871011758112768694414166390079328653082931386489", + ) + .unwrap(), + Fp::from_str( + "3704263561344746110064913824500321149484804257881233835698811910531316541086", + ) + .unwrap(), + Fp::from_str( + "6393900199557481719092294199323547003374571791762662902261426564763652369401", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "15279200860211060844541067991654350499760515939535039833307994223016429031915", + ) + .unwrap(), + Fp::from_str( + "26985260264072804233457332565625452265819057118175712295039599663785867528686", + ) + .unwrap(), + Fp::from_str( + "23762344035136554592534657492259843646585644687247859154099686371337251180558", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "28629412652946773219137169933363498238539857584436327956640824482910132247750", + ) + .unwrap(), + Fp::from_str( + "28534296835085804577921789342106129297870492252602947746397766207330777800687", + ) + .unwrap(), + Fp::from_str( + "7010171840415685376889067428592526835996702763760124252243310579212924906336", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "15115877507769307865100373437455075067963362433824296860517549375604509199077", + ) + .unwrap(), + Fp::from_str( + "27322038485974585090796976492935612725270747730646246142287058395919972668833", + ) + .unwrap(), + Fp::from_str( + "12545303079405384360051990119898298947201962285928865565849006166887510681847", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "22402229074453220589136435699079429777640817397671403381476785296372580549858", + ) + .unwrap(), + Fp::from_str( + "9901365148685926116427078767190546933876050085091579016922595079995750903884", + ) + .unwrap(), + Fp::from_str( + "24337498889894620917237382701526353977081850869701385679164276459683148226638", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "12082441852836765995894129255249968009151939521235506465636962125135332440553", + ) + .unwrap(), + Fp::from_str( + "13591394560787149077761220569421990892747904559677202394178297433750401517295", + ) + .unwrap(), + Fp::from_str( + "820346578826664446967541661807671907060311534721583510850758219254789063525", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "24215695518459878742143068844160846527323444039971532300823464386645604087000", + ) + .unwrap(), + Fp::from_str( + "13101964695384096763097324862644947302557977695504072497333846057745440467108", + ) + .unwrap(), + Fp::from_str( + "24051759082465526600028122824371590042247611846854441985088350490205244474164", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "14134815818472985968740149708977389980339449561590100111958140232669752667785", + ) + .unwrap(), + Fp::from_str( + "15403703790366432688624212690576313146300536172390325431587338971565850239268", + ) + .unwrap(), + Fp::from_str( + "3860940122072688789678493869688816836883700739879926728493744350258318418970", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + 
"25455283782061227089561784456070939721429904550086416675926941481988389611408", + ) + .unwrap(), + Fp::from_str( + "26977357393229872205773968079088763115741561597073371811262615331984270947870", + ) + .unwrap(), + Fp::from_str( + "24834730534015581468734043243902112857058318830375336906511622322355136736596", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "27652398859855980013005741036309435387669937475210024717304846274535246837370", + ) + .unwrap(), + Fp::from_str( + "16434196447540864286943770626416181484651278178598979685376365171138289731486", + ) + .unwrap(), + Fp::from_str( + "7922731384656857738011738126442917481655592952969291616357780117919476015055", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "15709070462558549336491348992414374370973707720054659786530259503570602246228", + ) + .unwrap(), + Fp::from_str( + "28337292776941413201036572795861511102372227244242970132891465942851528836861", + ) + .unwrap(), + Fp::from_str( + "25426460061895780663299560424657402107131480564650073175026211502391078663882", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "2625713925773366883469271869689397545542254691907461055484428506278119836700", + ) + .unwrap(), + Fp::from_str( + "18893574847900839380547868476930946750088960440844419973918601764038727814681", + ) + .unwrap(), + Fp::from_str( + "766631282190587378177611951477031021550773640264111386671837572118633719809", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "13137601490246132986137466920796584607886655745019808761526435041805205717317", + ) + .unwrap(), + Fp::from_str( + "18259286816093077983468656656237071231169868820742465348544118978069177543570", + ) + .unwrap(), + Fp::from_str( + "26757597187175444885193670853080212204940693683428672422895009364630049028206", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "15653769708758819121178811259026600787053993280353635479680029218091331593417", + ) + .unwrap(), + Fp::from_str( + "13419246943959211329987723815853581344931700894862789561848126275923663717128", + ) + .unwrap(), + Fp::from_str( + "8246136206295335525843735689248256487311631616664947617631250251488088775292", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "4648299921946646012251054417595794331655655293458746215060084503543519513637", + ) + .unwrap(), + Fp::from_str( + "9743757880994843874328260661480407066979773294827260780036126204020122446689", + ) + .unwrap(), + Fp::from_str( + "28848155389668727551633588553511280530765255045419419435004024289027914544446", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "10709971516224634369084875632690920087873994425819305220522050064040270914692", + ) + .unwrap(), + Fp::from_str( + "6406328864104874055055834491697206998962575444374922213291036274060410244732", + ) + .unwrap(), + Fp::from_str( + "8609244662654888009367200669994278401918186336049196594307540095574503942361", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "7941068229552817686752437248744552041703010226226513022642595220249577488302", + ) + .unwrap(), + Fp::from_str( + "14600142520093224163990036885240402551895887296526851496250352494620360954106", + ) + .unwrap(), + Fp::from_str( + "23081105723211054332615217306201375022785160149300992361133036123950858272161", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "25145386407108334893663917883742050070744522677607944262416596312572418856110", + ) + .unwrap(), + Fp::from_str( + "22415032089702588589943876844173210430005217493899276012184671097877647559690", + ) + .unwrap(), + Fp::from_str( + "25794838706252143228475932981269875142662580183382355816219518241697349965126", + 
) + .unwrap(), + ], + vec![ + Fp::from_str( + "28280271973628642728204240619577482766879646721263904790314880789683284786153", + ) + .unwrap(), + Fp::from_str( + "22931262827100816124759692095131386101929469695656151272643079155600267780221", + ) + .unwrap(), + Fp::from_str( + "1565584453094019973482028442040272288669276571682674054524209013569923578826", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "6176119092619641909755543614365503575755823782036505145363923576352647502270", + ) + .unwrap(), + Fp::from_str( + "14718283097048302858008219596481145523858456878760570559346085512848406401803", + ) + .unwrap(), + Fp::from_str( + "14026593842054434070610926143877468309529731518748953977793091018486749904942", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "28204851003893693047702499213921311944599169390101794939028138808109049040591", + ) + .unwrap(), + Fp::from_str( + "5016753140288828986502233151126294971871174105161700925754095820220266223892", + ) + .unwrap(), + Fp::from_str( + "7973403363737358149457257297277923114511617066406497494746386837087163674642", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "19072245656175440040557406435572739595903535117323585374950538753629903559860", + ) + .unwrap(), + Fp::from_str( + "5895955315480535531190434880350785096314979368725672362470113474967721384904", + ) + .unwrap(), + Fp::from_str( + "15361544953620210040531548559776865378252632698543079305928467170054695183446", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "651504417665181682948041922247515404791725230006382748879704277725664764202", + ) + .unwrap(), + Fp::from_str( + "10424974975578578024653985758412888912364191430425514312351050189523438935518", + ) + .unwrap(), + Fp::from_str( + "2704357630034530476415155083392819826297328990599197368573629854494696107510", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "9967165781104268851556312839513912991882679874086243203063983754479587202683", + ) + .unwrap(), + Fp::from_str( + "4521829700958821910734707941695050304452742079173589559101469846854042608545", + ) + .unwrap(), + Fp::from_str( + "2998903411973640806805465273398154562735583964536850253848556867756388980285", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "26429477232785835851298220245960275967938738628597310655820557830242665904855", + ) + .unwrap(), + Fp::from_str( + "21678436092012596248378956354730602656351869767989635487409918503988290184483", + ) + .unwrap(), + Fp::from_str( + "24063665784074951477368898210440817717884477863516236933384693405846596412482", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "25430143753711785859579965731150092218728133630080931796591554990359469220658", + ) + .unwrap(), + Fp::from_str( + "26815274039242379661454506747661568656069040162204136102861214910454990923904", + ) + .unwrap(), + Fp::from_str( + "3430118456678344681897851157335683628162555912328852095953534308373439660104", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "2981274969203350321903730794067603535094244032208803309667813208952049092596", + ) + .unwrap(), + Fp::from_str( + "19651230574896516182887662893630354490594265491835883047458795623880643881162", + ) + .unwrap(), + Fp::from_str( + "16094332558630876779073727324702015599602304069749665753669556130114772373413", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "19601749824025963954313915092665025360520289856101185686338038789042709250125", + ) + .unwrap(), + Fp::from_str( + "23327888594672849870327674644151087157611742523569944071966728815429929652938", + ) + .unwrap(), + Fp::from_str( + 
"8116745397777693260858870181009677972224760693455022816265297486527432945045", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "4137571281002932521130058362546115360865156001608021813820779022979489660619", + ) + .unwrap(), + Fp::from_str( + "11761622932288686989130933541784121202686598493467608591293488578255939501435", + ) + .unwrap(), + Fp::from_str( + "1477471281573077920490967814842546296206517201272066092463513135414201737665", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "13213470188615826068959233674301632840131326862016240364178432423050398911829", + ) + .unwrap(), + Fp::from_str( + "8292570860382474814437956284198549626305083694280495916838765504191531215719", + ) + .unwrap(), + Fp::from_str( + "27752166055295315896861968133714816463396127145387525431061852199916374334174", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "25286350190750659331352050652212587441953337718070847188880599248224021669632", + ) + .unwrap(), + Fp::from_str( + "16498832049600508872577329754971814475950970528852711408876393746284042426939", + ) + .unwrap(), + Fp::from_str( + "7392191762292840825729788948558830929865673368941777854236773613330343865068", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "21249495110129175388076546800063944487064178143642566808220266761214562773750", + ) + .unwrap(), + Fp::from_str( + "7465463298790484828954279498789998238969853918772021592217194784214120853718", + ) + .unwrap(), + Fp::from_str( + "13946746439585416804412368608158016335742973976921109234593198399138179526861", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "14793765608310114884548507442019029529042234633561254614180297773338630723865", + ) + .unwrap(), + Fp::from_str( + "5228748839064876489771894388718869363423171800475936916385717054387856756818", + ) + .unwrap(), + Fp::from_str( + "6229807347296064570491153926273834011507648416709272963104385150388106445223", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "26042678024105467318614085746031057804729215620050706867703802616290481588552", + ) + .unwrap(), + Fp::from_str( + "9053451132804987755572468447657406503881013072149560732536522060099876403552", + ) + .unwrap(), + Fp::from_str( + "14110853126804402715907233830338691719977501655861749631604908726533408970320", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "14288656666085023020088118422129853486594106622284724092353512077848915125114", + ) + .unwrap(), + Fp::from_str( + "1212682384116502438654433042071647615118591089203880808333382883873619921662", + ) + .unwrap(), + Fp::from_str( + "5988899602387758330558731670436738270751794428407170408962179005119993219587", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "1434317295830874756899512684995913826508209647433651974750835573907602685329", + ) + .unwrap(), + Fp::from_str( + "15203591931895943523011062306208531325016707325495264405003917136727119291747", + ) + .unwrap(), + Fp::from_str( + "23086489036620293939736469046064122088358371666195380813522996849945295023336", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "24105037587019929735298246017974918300824784678105378543277326520108773515634", + ) + .unwrap(), + Fp::from_str( + "12189227049757868086225331406747956954270474255091400827432756832961798683570", + ) + .unwrap(), + Fp::from_str( + "21952062283147919705092413713671635544216290792441007949439117937492432437237", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "17041683802932506063779028216158120947310712754022635110675172523772989287677", + ) + .unwrap(), + Fp::from_str( + 
"20042557027275102787522087348982852275244256400725639332062987749821748414919", + ) + .unwrap(), + Fp::from_str( + "13492385017578082290900892131214249854374840588907466293618874316679735484809", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "16869753134138798313943079259853823216757998094193096682584190132057006154113", + ) + .unwrap(), + Fp::from_str( + "16799687027512627515219958217734359517253987866039702314209148477371706686085", + ) + .unwrap(), + Fp::from_str( + "27812183967136745156253055220651069728071975872515090894475195502138561993819", + ) + .unwrap(), + ], + ], + } +} + +pub fn static_params() -> &'static ArithmeticSpongeParams { + static PARAMS: Lazy> = Lazy::new(params); + &PARAMS +} diff --git a/arrabiata/src/poseidon_3_60_0_5_5_fq.rs b/arrabiata/src/poseidon_3_60_0_5_5_fq.rs new file mode 100644 index 0000000000..2d284292a7 --- /dev/null +++ b/arrabiata/src/poseidon_3_60_0_5_5_fq.rs @@ -0,0 +1,903 @@ +use mina_curves::pasta::Fq; +use mina_poseidon::poseidon::ArithmeticSpongeParams; +use once_cell::sync::Lazy; + +/* Generated by params.sage */ + +use std::str::FromStr; + +fn params() -> ArithmeticSpongeParams { + ArithmeticSpongeParams { + mds: vec![ + vec![ + Fq::from_str( + "13841249378323021054377890016443265503914469864575312305567855128857444259487", + ) + .unwrap(), + Fq::from_str( + "13474541630870351706035366868551274745344722775965283971488625444792482938921", + ) + .unwrap(), + Fq::from_str( + "12151399309130104322934698634832550272633145452940229406831079937203453962655", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "23740172672243477188057832304277375136172708212574345504502655344614455663606", + ) + .unwrap(), + Fq::from_str( + "21546379895052337509105835711466137751271439534383664473903037983361236182649", + ) + .unwrap(), + Fq::from_str( + "23066889761122039690865858791355498392807739602525834512710419267656039395072", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "21257914241847355181936015256361686342208107590210872866396751638657186725269", + ) + .unwrap(), + Fq::from_str( + "26067439960311266372107947555750482427352889677249991931425889311346331301244", + ) + .unwrap(), + Fq::from_str( + "12928079059102104389399617486508810710959599542061627402082516871089532252204", + ) + .unwrap(), + ], + ], + round_constants: vec![ + vec![ + Fq::from_str( + "26670984196527839467673167978149437755844366956321273026656650542011198272197", + ) + .unwrap(), + Fq::from_str( + "26013004827491165884682917918428377897255879902000375849310858738009760470562", + ) + .unwrap(), + Fq::from_str( + "5420685313310124987480400044757399826914052519997427998887189691416232403245", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "7380609437725388759172027592963234162853451534319456286538277978633614936971", + ) + .unwrap(), + Fq::from_str( + "9966232995568830110180758641581023485115951463148554382368561029317340036027", + ) + .unwrap(), + Fq::from_str( + "28754092216755158675001416045004197886112586989030713157273469167863632395356", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "28357886559269113664297190499547194903533546901013467270751349984975123919000", + ) + .unwrap(), + Fq::from_str( + "19730820496452085203255679263480518276180414649569269459731768952000235952971", + ) + .unwrap(), + Fq::from_str( + "6101012598333622734057793807254456227076559878860659397169023699671171625527", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "19354296086288079144159688795327823522398391728875212438965010638655375455562", + ) + .unwrap(), + Fq::from_str( + 
"963886109865447521127743832003203330310842182698308283277967974115778110471", + ) + .unwrap(), + Fq::from_str( + "4281563786504771719885397038176034931017772107152448013142531584031665653873", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "22434977466990605082277346891649372581276997387643726171473862867763924117097", + ) + .unwrap(), + Fq::from_str( + "13174550665662255841009687929042395185477189619407134958480174063404635971634", + ) + .unwrap(), + Fq::from_str( + "26005789312462815321328498182081852420665006828682992304348915245386398610873", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "14261478113012663354046752356759087159765688199500179875954642395193867447371", + ) + .unwrap(), + Fq::from_str( + "5319741634899083444347353099902602066490888021034732217012556056837111902901", + ) + .unwrap(), + Fq::from_str( + "24281416757910026946666445460917181064060869667150825997973925467333387775115", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "15416895049592059092638358872355212898998998908558794698876839869683256456004", + ) + .unwrap(), + Fq::from_str( + "15390253357931539610194542193095570692890805122713279610257215942935464374719", + ) + .unwrap(), + Fq::from_str( + "20789588747734416786711125246746886476076139970160040037960650024173751844025", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "19268124713110905659901812882050400279035919397728249325940775594990663211110", + ) + .unwrap(), + Fq::from_str( + "12195786313905509291107072431512687605919725851081638502151678111133621832441", + ) + .unwrap(), + Fq::from_str( + "10996797254442470598461241872576533474264956098948895636942413022641449567843", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "16031386136507188699582593290815695992163943378269473017916874103066734548603", + ) + .unwrap(), + Fq::from_str( + "26393086215631991843496465996980424738477421979529423684963588301934817887804", + ) + .unwrap(), + Fq::from_str( + "10020908459421884208513359412128136167359628801846395421852688157171038382652", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "2275674645721881281344332559044371002228696505945826386093638383129626987820", + ) + .unwrap(), + Fq::from_str( + "20871954636987436302289348956897439215309251875529050679580127597263497272022", + ) + .unwrap(), + Fq::from_str( + "2805912099911037452690852710034714538324835563618608205533675796786012172540", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "12915788094434266346171730509109106471648198000850852516599691611952366277595", + ) + .unwrap(), + Fq::from_str( + "10455858116745097341023546462392222663807393446470812029549928814838698828197", + ) + .unwrap(), + Fq::from_str( + "18548429514862373104120945125763670495246981395802510059301488453312220138142", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "15373116094599762914903524023093729273461654612461234960666295498677422349953", + ) + .unwrap(), + Fq::from_str( + "3525617356615259862195106796336657658007178396368139338747079341939583684666", + ) + .unwrap(), + Fq::from_str( + "12023105750813777481149114228471616659998426860941299094537298344838126075345", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "26587730386641627909670986517353025269633372837450910451825098440250303207114", + ) + .unwrap(), + Fq::from_str( + "9074270097550873666867300148771600909014032474714272120531292635080024578351", + ) + .unwrap(), + Fq::from_str( + "16615827041072034692314621590279489654646723822721197974770794031303667078327", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + 
"20461572479624705365303610120753007669488016680882759421045733709021889519289", + ) + .unwrap(), + Fq::from_str( + "26965967315571780826140179998689464673459701276445944370512797087644375234978", + ) + .unwrap(), + Fq::from_str( + "28024057220417802371648952762950742790050995995116405638415899316978666577934", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "21950318517032882689049872365078159816107931788675573974069114493584495431253", + ) + .unwrap(), + Fq::from_str( + "19426502278378560572857838763658840255394172603117177227007593438933781199090", + ) + .unwrap(), + Fq::from_str( + "17109968150768182960870375097797329526321434967783065142585178988441418255970", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "4900697972992083407198577086201839674947019834151233581099269472024415286865", + ) + .unwrap(), + Fq::from_str( + "21257266248810969336901990518010397339194787485618799904331928858369131592715", + ) + .unwrap(), + Fq::from_str( + "7853338111392902876480479512761358182087148519956105729210068733325256641242", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "14953731463927686511781103109299419908680866708971902685494637646180828533194", + ) + .unwrap(), + Fq::from_str( + "1039568133358048217256645444858779690222593722360933428132029595856197231708", + ) + .unwrap(), + Fq::from_str( + "21957403689654416397064622231922871551099838312101526423356724466906188357692", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "26854225070716528408870645522239189278896427207500130710682305948803020238931", + ) + .unwrap(), + Fq::from_str( + "26943892344604571557701285028044912982696825664345334371487484451303282475647", + ) + .unwrap(), + Fq::from_str( + "1055486347679417723372054024522800671210435936363188971739620139282176021577", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "12039988523641802893655502894224498050885828461811563074139032284005801213403", + ) + .unwrap(), + Fq::from_str( + "8902157540895169828625719473847119761637916797736988512418726291566892375478", + ) + .unwrap(), + Fq::from_str( + "19808247018928888974783844516743773328294296650472582590579002155958871262899", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "1166679276532448467842854177763358821231571989731998006764195336801763455021", + ) + .unwrap(), + Fq::from_str( + "3606106570684658925413634110366206918044489665389891978146383152327787874509", + ) + .unwrap(), + Fq::from_str( + "6852721786435315500790559962812048058053541164453308515683899928808277370093", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "13806687138702327718737608018875659025193922209134042758115965685374173572631", + ) + .unwrap(), + Fq::from_str( + "22097540835610763700704213032962472026569342000865482925296449079325231775952", + ) + .unwrap(), + Fq::from_str( + "23854994668788070518164362677404876296654138145727636839907804381279830692184", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "14328285116930040297167480888093252040682222507378296723083675894276997543320", + ) + .unwrap(), + Fq::from_str( + "20209427422345103632021534022650468075947344519137728114160624684121322455710", + ) + .unwrap(), + Fq::from_str( + "27033255520536361062197858564300467125597172428228104135004669745440717000280", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "28248381813468035169220139433513611820173498138773431024157751299160860251101", + ) + .unwrap(), + Fq::from_str( + "21526871958088430765907319854118302865893119689569733282961302685994303096697", + ) + .unwrap(), + Fq::from_str( + "4746295740702898653769939686610795189271514521752909912067264117422778228375", + 
) + .unwrap(), + ], + vec![ + Fq::from_str( + "9101401092044932317062159778689819323115945875284594416907500031473455978135", + ) + .unwrap(), + Fq::from_str( + "27308974038596559880820319745148649731010999017812218694960076804027626479421", + ) + .unwrap(), + Fq::from_str( + "27024963874124411093581981994784661747591122603933013537982183388422444695841", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "21057964411478038759944416368533075071641109811620242561938792043892897144239", + ) + .unwrap(), + Fq::from_str( + "7934667723159326341262909811771836782101934647017223031384894873861929276033", + ) + .unwrap(), + Fq::from_str( + "24057087430176896270365331032150773107194301890269528289146941949706298382245", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "24872091794585731195123232743434894841960966480807109028172616559410714340378", + ) + .unwrap(), + Fq::from_str( + "15518684264925082319927189973387561377010439235348248553623063845923579402660", + ) + .unwrap(), + Fq::from_str( + "14948609980173019585444594725654208043748147783660285932086415887394611752597", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "10984170585530607867489913407826422000086932773059636533222090923195886168825", + ) + .unwrap(), + Fq::from_str( + "9321826953632892919090736157541940967379221312916602977566838169907380096870", + ) + .unwrap(), + Fq::from_str( + "17487804874644538928028029763348613533582076810998020571125589343942948265817", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "15275480059698629965360784554938662843156630611714827458212571654634240046605", + ) + .unwrap(), + Fq::from_str( + "26413670478329619782619613350537218886176471689804000980910258836369485101702", + ) + .unwrap(), + Fq::from_str( + "22110235271545922012664286788630290002548452970188851575240062241974646215797", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "5711011179980344348611755788272653483339509530994363809978304737350075716000", + ) + .unwrap(), + Fq::from_str( + "23953629463804192365925501930661733866204203811414692118288576654795847703470", + ) + .unwrap(), + Fq::from_str( + "5136563821326800571661592018591914751723715382808500909300542629456466994789", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "21002177360068475415405665266698677701230625062874672384664043086721899565078", + ) + .unwrap(), + Fq::from_str( + "16160705656234082318664468980148954204257542989847179985485042083802028404005", + ) + .unwrap(), + Fq::from_str( + "9584652138885123343129154370726916454016668838741647553957283602972361563118", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "11976558678099062367003807661135049597791370468907580902513666607265118739358", + ) + .unwrap(), + Fq::from_str( + "17770574027300583122009021603974295849535945441531452157529578670591628790562", + ) + .unwrap(), + Fq::from_str( + "8779857893007641767450158410931422967669604621388757865092237959422562566512", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "1106764359170752442062364091255316123175951798705034870298529962585584020572", + ) + .unwrap(), + Fq::from_str( + "14282358984863272931851137163247018760911529915646224138003781953088609751794", + ) + .unwrap(), + Fq::from_str( + "23068266846998739315579280127999495081981085207657183174657687901108103620833", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "8653390726873789662036821861474930339237642221033368679523999558683634039983", + ) + .unwrap(), + Fq::from_str( + "21189350969079886399299663440211026207756289316640749637497110786199356061458", + ) + .unwrap(), + Fq::from_str( + 
"5874508707195808729739662175554323653263260150428377752434606463967243900638", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "22562647947455785733109059717137970098038603103702164680034643705156923696289", + ) + .unwrap(), + Fq::from_str( + "28784044665662723346322790792841446800014639216013400198981660575412845188995", + ) + .unwrap(), + Fq::from_str( + "21617548860882346762468289662184539964972003273262239453456162866555087501784", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "7921707293863412887763762425340047557212461768642176100113995468887128120848", + ) + .unwrap(), + Fq::from_str( + "5079638071684390339031549260477894630888951581139107799635162857532169033203", + ) + .unwrap(), + Fq::from_str( + "1906474501718235725268246233868869738847898113905316888545913829010490521322", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "15010703057223301210797194186273140083891325609163657533528230982592512926640", + ) + .unwrap(), + Fq::from_str( + "10211530488461406492044942258309860029358718988798555599263583230449549386184", + ) + .unwrap(), + Fq::from_str( + "7926214251232883423235385814093337865027940862721299192910052080502780132651", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "1615858616585740214171190013584453494816062927677997768793286983960714636586", + ) + .unwrap(), + Fq::from_str( + "6300851832008087933477316424296862623492613015770943553911344820034487697725", + ) + .unwrap(), + Fq::from_str( + "16997495024966063782859434171461822461915808567641951044575476421161801044996", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "19093867154823521476631513454884258559801391479742181351464825851713886882804", + ) + .unwrap(), + Fq::from_str( + "9876623979316708918029048066191049762079725911343449898764854106968461086193", + ) + .unwrap(), + Fq::from_str( + "7438377435815100341595798523554450500128694745187932659511819662569177641879", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "13237643488507212932551762727914712522464902552479712766778808080089035951383", + ) + .unwrap(), + Fq::from_str( + "1319931255447271194675270558717799001431869893418472253731223591074983480311", + ) + .unwrap(), + Fq::from_str( + "7612481215526800118812690896530728866481443352597194934339693910825696483013", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "21648850356578602757279748986908901999065310414626834013406905545000047414361", + ) + .unwrap(), + Fq::from_str( + "5114318994465500116617192593880584287425659388356129222869162968973137431628", + ) + .unwrap(), + Fq::from_str( + "14274034502399536699480454286342319572996073205595632914757715057595605987972", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "26790719393278821612050895247884938435625613110754118299610499799014565923461", + ) + .unwrap(), + Fq::from_str( + "4610514154584214123082003352374212407564345220952368178746170358675112402908", + ) + .unwrap(), + Fq::from_str( + "1698630880910166583271254572454135552446512741333083347763483525435981457103", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "20552720912510443152155372945469698492199625763926700431467362374595133355503", + ) + .unwrap(), + Fq::from_str( + "13027934210042230481275252151174239452268620117959842710803374999950407743674", + ) + .unwrap(), + Fq::from_str( + "12928773210332716904967599591061654461064257952845874859809010183743576804183", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "9898388983730091833586496876143319487640591474664562180048110968261504227763", + ) + .unwrap(), + Fq::from_str( + 
"246696781145054333994875178319722243107187829835409556837932862368780043827", + ) + .unwrap(), + Fq::from_str( + "3643500059049840720789605431319295724288225718414327982164300578592452239677", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "3565425667601106009770782071884148125071341260281920273751052503326375058481", + ) + .unwrap(), + Fq::from_str( + "11864389139266614505889554458655451207930913861631435830492116979608557714974", + ) + .unwrap(), + Fq::from_str( + "8274657129590003581058306529364964194968347052625573424884303363838851417154", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "1275743209701804744206403454856572943723564090814290111897584099271022701779", + ) + .unwrap(), + Fq::from_str( + "21973662159342416905656953138735231968923660220069068196765641337893610648313", + ) + .unwrap(), + Fq::from_str( + "25894732957240328820500628771930855661060611454759858073185019889253632110910", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "24235983361660621912660346593434103797222337145613774587667715180072880402932", + ) + .unwrap(), + Fq::from_str( + "28064151545062294376369933084121731998678552213732523790998247369967442851222", + ) + .unwrap(), + Fq::from_str( + "17349893890535912362335854141505342728498068491067316154965484585819728433057", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "15761968906322825313255937439984951429369868993224115833647535952461101665403", + ) + .unwrap(), + Fq::from_str( + "12730132432828071741894158325728434195442074356127747454789604974311569256382", + ) + .unwrap(), + Fq::from_str( + "15546396661635542412156226214532349631117485335067100794480597316674413181741", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "26114367077546812088524649139586722615602988031071908939189781274012523720259", + ) + .unwrap(), + Fq::from_str( + "24120384959346861175854010829333655347217114314748265093541001656605937854873", + ) + .unwrap(), + Fq::from_str( + "19948360916941737903881302285818644687244482456463503655644971939056521509651", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "19726197518649987200467574937800886411552656360200944676459718139782252781161", + ) + .unwrap(), + Fq::from_str( + "28474607446546947575504882600162047140263193547832650495916965063574806568486", + ) + .unwrap(), + Fq::from_str( + "23853413930706347232024181413476257923113376782352708959863255759270297482643", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "12172967092520478036533278879321417885370166446952416302138802149554076887765", + ) + .unwrap(), + Fq::from_str( + "23009599418043991694882996890518263678873774603216345360596994131664213566862", + ) + .unwrap(), + Fq::from_str( + "10115258270635971171995411905168552695793622404699352160197062614831363007330", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "11872854830696109949651382814501844672018549027397589681743730297592698621988", + ) + .unwrap(), + Fq::from_str( + "27242375034974808335639708106328197436180434795006391706142292135621509320179", + ) + .unwrap(), + Fq::from_str( + "12240585689619491656334069880205175383446802092185679587077983154058337550941", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "21354477485311779262361916142854589865704748405272954734490890336687615138805", + ) + .unwrap(), + Fq::from_str( + "23957401946708184815950634880682726184066491070197994539079194824221621571898", + ) + .unwrap(), + Fq::from_str( + "8357936131918521617894415271376877545637994736685320868457215391964819334380", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + 
"23556508762893476575740526218250109012025360589353787405830042158273933308342", + ) + .unwrap(), + Fq::from_str( + "26459483525453534734424286694012535960391097465802298977757104519791714946079", + ) + .unwrap(), + Fq::from_str( + "10475806893121710950614249918125355895080858562674206501925615132176728614089", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "509921547172201949904713166493001584273069976849368713467345580264015783495", + ) + .unwrap(), + Fq::from_str( + "18125203608474211016238804504111319380002458161505700677366869384368972046217", + ) + .unwrap(), + Fq::from_str( + "21924276575871168170965471499148658434002761718539337627176793776405331995359", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "20766228186393497584471782099005176987747009290594643934202011294470303807376", + ) + .unwrap(), + Fq::from_str( + "14375963887626641531579004947270417067489282157820693664981439571201925786507", + ) + .unwrap(), + Fq::from_str( + "21159389101933617620301753100914640106752229791897141844262326158946032983368", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "4461134990679762924709186776053724197764516354075947521348677372357740214228", + ) + .unwrap(), + Fq::from_str( + "16252169654788390820854791971943967502618323077922450508362670811049952999068", + ) + .unwrap(), + Fq::from_str( + "12230304177499919762152198881532284912863465765960540568879309220349311440639", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "28841011969818167186628412469014473288587924982384920380582573823320491690149", + ) + .unwrap(), + Fq::from_str( + "2799323756756092759606808189644604436886324385103205704619694119270259152673", + ) + .unwrap(), + Fq::from_str( + "7552245287463888227716869296827175678848435662455435319330135096459780527367", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "14559878705009013461912521439177189262826320003053264136037120330161817373798", + ) + .unwrap(), + Fq::from_str( + "135149832293512941307855088962585950397809468736297008779181618805608043951", + ) + .unwrap(), + Fq::from_str( + "10076925352984375920319893131906409552943218948970513115140185752025602734367", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "21180884538099273847084389539895114475410666123121133636810594347139251692132", + ) + .unwrap(), + Fq::from_str( + "16264335492355657460979899000171576784721337003978974122245774555132733725877", + ) + .unwrap(), + Fq::from_str( + "11243620954294652682442140545184810781230200439960589447680494425973184831880", + ) + .unwrap(), + ], + vec![ + Fq::from_str( + "11960277331258705530567654798295965395742512385392290832518951609500109974170", + ) + .unwrap(), + Fq::from_str( + "12193804570375975216227507559799714885990198774010330434569978857678369247842", + ) + .unwrap(), + Fq::from_str( + "19383474201008690035488023798780015353919506197647243443206760029489523188626", + ) + .unwrap(), + ], + ], + } +} + +pub fn static_params() -> &'static ArithmeticSpongeParams { + static PARAMS: Lazy> = Lazy::new(params); + &PARAMS +} diff --git a/arrabiata/src/proof.rs b/arrabiata/src/proof.rs new file mode 100644 index 0000000000..676c547caa --- /dev/null +++ b/arrabiata/src/proof.rs @@ -0,0 +1,4 @@ +/// FIXME: a proof for the Nova recursive SNARK +// FIXME: type over curves +// FIXME: add a (de-)serializer to publish it somewhere +pub struct Proof {} diff --git a/arrabiata/src/prover.rs b/arrabiata/src/prover.rs new file mode 100644 index 0000000000..7e9ea7b36b --- /dev/null +++ b/arrabiata/src/prover.rs @@ -0,0 +1,21 @@ +//! 
A prover for the folding/accumulation scheme + +use crate::proof::Proof; +use ark_ec::AffineRepr; +use ark_ff::PrimeField; + +use crate::witness::Env; + +/// Generate a proof for the IVC circuit. +/// All the information to make a proof is available in the environment given +/// as a parameter. +pub fn prove< + Fp: PrimeField, + Fq: PrimeField, + E1: AffineRepr<ScalarField = Fp, BaseField = Fq>, + E2: AffineRepr<ScalarField = Fq, BaseField = Fp>, +>( + _env: &Env<Fp, Fq, E1, E2>, +) -> Result<Proof, String> { + unimplemented!() +} diff --git a/arrabiata/src/verifier.rs b/arrabiata/src/verifier.rs new file mode 100644 index 0000000000..96fe150143 --- /dev/null +++ b/arrabiata/src/verifier.rs @@ -0,0 +1 @@ +//! A verifier for the folding/accumulation scheme diff --git a/arrabiata/src/witness.rs b/arrabiata/src/witness.rs new file mode 100644 index 0000000000..179e9d520a --- /dev/null +++ b/arrabiata/src/witness.rs @@ -0,0 +1,1052 @@ +use ark_ec::{models::short_weierstrass::SWCurveConfig, AffineRepr}; +use ark_ff::PrimeField; +use ark_poly::Evaluations; +use kimchi::circuits::{domains::EvaluationDomains, gate::CurrOrNext}; +use log::{debug, info}; +use num_bigint::{BigInt, BigUint}; +use num_integer::Integer; +use o1_utils::field_helpers::FieldHelpers; +use poly_commitment::{commitment::CommitmentCurve, ipa::SRS, PolyComm, SRS as _}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use std::time::Instant; + +use crate::{ + columns::{Column, Gadget}, + interpreter::{Instruction, InterpreterEnv, Side}, + poseidon_3_60_0_5_5_fp, poseidon_3_60_0_5_5_fq, MAXIMUM_FIELD_SIZE_IN_BITS, NUMBER_OF_COLUMNS, + NUMBER_OF_PUBLIC_INPUTS, NUMBER_OF_SELECTORS, NUMBER_OF_VALUES_TO_ABSORB_PUBLIC_IO, + POSEIDON_ALPHA, POSEIDON_ROUNDS_FULL, POSEIDON_STATE_SIZE, +}; + +pub const IVC_STARTING_INSTRUCTION: Instruction = Instruction::Poseidon(0); + +/// An environment that can be shared between IVC instances. +/// +/// It contains all the accumulators that can be picked for a given fold +/// instance k, including the sponges. +/// +/// The environment is run over big integers to avoid performing +/// reductions at every step. Instead, the user implementing the interpreter can +/// reduce into the corresponding field when they want. +pub struct Env< + Fp: PrimeField, + Fq: PrimeField, + E1: AffineRepr<ScalarField = Fp, BaseField = Fq>, + E2: AffineRepr<ScalarField = Fq, BaseField = Fp>, +> { + // ---------------- + // Setup related (domains + SRS) + /// Domain for Fp + pub domain_fp: EvaluationDomains<Fp>, + + /// Domain for Fq + pub domain_fq: EvaluationDomains<Fq>, + + /// SRS for the first curve + pub srs_e1: SRS<E1>, + + /// SRS for the second curve + pub srs_e2: SRS<E2>, + // ---------------- + + // ---------------- + // Information related to the IVC, which will be used by the prover/verifier + // at the end of the whole execution + // FIXME: use a blinded comm and also fold the blinder + pub ivc_accumulator_e1: Vec<PolyComm<E1>>, + + // FIXME: use a blinded comm and also fold the blinder + pub ivc_accumulator_e2: Vec<PolyComm<E2>>, + + /// Commitments to the previous instances + pub previous_commitments_e1: Vec<PolyComm<E1>>, + pub previous_commitments_e2: Vec<PolyComm<E2>>, + // ---------------- + + // ---------------- + // Data only used by the interpreter while building the witness over time + /// The index of the latest allocated variable in the circuit. + /// It is used to allocate new variables without having to keep track of the + /// position. + pub idx_var: usize, + + pub idx_var_next_row: usize, + + /// The index of the latest allocated public inputs in the circuit. + /// It is used to allocate new public inputs without having to keep track of + /// the position. + pub idx_var_pi: usize, + + /// Current processing row. 
Used to build the witness. + pub current_row: usize, + + /// State of the current row in the execution trace + pub state: [BigInt; NUMBER_OF_COLUMNS], + + /// Next row in the execution trace. It is useful when we deal with + /// polynomials accessing "the next row", i.e. witness columns that we + /// evaluate at ζ and ζω. + pub next_state: [BigInt; NUMBER_OF_COLUMNS], + + /// Contains the public state + // FIXME: I don't like this design. Feel free to suggest a better solution + pub public_state: [BigInt; NUMBER_OF_PUBLIC_INPUTS], + + /// Selectors to activate the gadgets. + /// The size of the outer vector must be equal to the number of gadgets in + /// the circuit. + /// The size of the inner vector must be equal to the number of rows in + /// the circuit. + /// + /// The layout columns/rows is used to avoid rebuilding the arrays per + /// column when committing to the witness. + pub selectors: Vec<Vec<bool>>, + + /// While folding, we must keep track of the challenges the verifier would + /// have sent in the SNARK, and we must aggregate them. + // FIXME: nothing is done yet, and the challenges haven't been decided yet. + // See top-level documentation of the interpreter for more information. + pub challenges: Vec<BigInt>, + + /// Keep the currently executed instruction. + /// This can be used to identify which gadget the interpreter is currently + /// building. + pub current_instruction: Instruction, + + /// The sponges will be used to simulate the verifier messages, and will + /// also be used to verify the consistency of the computation by hashing the + /// public IO. + // IMPROVEME: use a list of BigInt? It might be faster as the CPU will + // already have in its cache the values, and we can use a flat array + pub sponge_e1: [BigInt; POSEIDON_STATE_SIZE], + pub sponge_e2: [BigInt; POSEIDON_STATE_SIZE], + + /// The current iteration of the IVC + pub current_iteration: u64, + + /// A previous hash, encoded in 2 chunks of 128 bits. + pub previous_hash: [u128; 2], + + /// The coin folding combiner will be used to generate the combination of + /// the folding instances + pub r: BigInt, + + /// Temporary registers for elliptic curve points in affine coordinates that + /// can be used to save values between instructions. + /// + /// These temporary registers can be loaded into the state by using the + /// function `load_temporary_accumulators`. + /// + /// The registers can, and must, be cleaned after the gadget is computed. + /// + /// The values are considered as BigInt, even though we should add some + /// type. As we want to apply the KISS method, we tend to avoid adding + /// types. We leave this for future work. + /// + /// Two registers are provided, represented by a tuple for the coordinates + /// (x, y). + pub temporary_accumulators: ((BigInt, BigInt), (BigInt, BigInt)), + + /// Index of the values to absorb in the sponge + pub idx_values_to_absorb: usize, + // ---------------- + /// The witness of the current instance of the circuit. + /// The size of the outer vector must be equal to the number of columns in the + /// circuit. + /// The size of the inner vector must be equal to the number of rows in + /// the circuit. + /// + /// The layout columns/rows is used to avoid rebuilding the witness per + /// column when committing to the witness. 
+ pub witness: Vec<Vec<BigInt>>, + + // -------------- + // Inputs + /// Initial input + pub z0: BigInt, + + /// Current input + pub zi: BigInt, + // --------------- + + // --------------- + // Only used to have type safety and think about the design at the + // type-level + pub _marker: std::marker::PhantomData<(Fp, Fq, E1, E2)>, + // --------------- +} + +// The condition on the parameters for E1 and E2 is to get the coefficients and +// convert them into biguint. +// The condition SWModelParameters is to get the parameters of the curve as +// biguint to use them to compute the slope in the elliptic curve addition +// algorithm. +impl< + Fp: PrimeField, + Fq: PrimeField, + E1: CommitmentCurve<ScalarField = Fp, BaseField = Fq>, + E2: CommitmentCurve<ScalarField = Fq, BaseField = Fp>, + > InterpreterEnv for Env<Fp, Fq, E1, E2> +where + <E1::Params as SWCurveConfig>::BaseField: PrimeField, + <E2::Params as SWCurveConfig>::BaseField: PrimeField, +{ + type Position = (Column, CurrOrNext); + + /// For efficiency, and for having a single interpreter, we do not use one + /// of the fields. We use a generic BigInt to represent the values. + /// When building the witness, we will reduce into the corresponding field. + // FIXME: it might not be efficient as I initially thought. We do need to + // make some transformations between biguint and bigint, with an extra cost + // for allocations. + type Variable = BigInt; + + fn allocate(&mut self) -> Self::Position { + assert!(self.idx_var < NUMBER_OF_COLUMNS, "Maximum number of columns reached ({NUMBER_OF_COLUMNS}), increase the number of columns"); + let pos = Column::X(self.idx_var); + self.idx_var += 1; + (pos, CurrOrNext::Curr) + } + + fn allocate_next_row(&mut self) -> Self::Position { + assert!(self.idx_var_next_row < NUMBER_OF_COLUMNS, "Maximum number of columns reached ({NUMBER_OF_COLUMNS}), increase the number of columns"); + let pos = Column::X(self.idx_var_next_row); + self.idx_var_next_row += 1; + (pos, CurrOrNext::Next) + } + + fn read_position(&self, pos: Self::Position) -> Self::Variable { + let (col, row) = pos; + let Column::X(idx) = col else { + unimplemented!("Only works for private inputs") + }; + match row { + CurrOrNext::Curr => self.state[idx].clone(), + CurrOrNext::Next => self.next_state[idx].clone(), + } + } + + fn allocate_public_input(&mut self) -> Self::Position { + assert!(self.idx_var_pi < NUMBER_OF_PUBLIC_INPUTS, "Maximum number of public inputs reached ({NUMBER_OF_PUBLIC_INPUTS}), increase the number of public inputs"); + let pos = Column::PublicInput(self.idx_var_pi); + self.idx_var_pi += 1; + (pos, CurrOrNext::Curr) + } + + fn write_column(&mut self, pos: Self::Position, v: Self::Variable) -> Self::Variable { + let (col, row) = pos; + let Column::X(idx) = col else { + unimplemented!("Only works for private inputs") + }; + let modulus: BigInt = if self.current_iteration % 2 == 0 { + Fp::modulus_biguint().into() + } else { + Fq::modulus_biguint().into() + }; + let v = v.mod_floor(&modulus); + match row { + CurrOrNext::Curr => { + self.state[idx] = v.clone(); + } + CurrOrNext::Next => { + self.next_state[idx] = v.clone(); + } + } + v + } + + fn write_public_input(&mut self, pos: Self::Position, v: BigInt) -> Self::Variable { + let (col, _row) = pos; + let Column::PublicInput(idx) = col else { + unimplemented!("Only works for public input columns") + }; + let modulus: BigInt = if self.current_iteration % 2 == 0 { + Fp::modulus_biguint().into() + } else { + Fq::modulus_biguint().into() + }; + let v = v.mod_floor(&modulus); + self.public_state[idx] = v.clone(); + v + } + + /// Activate the gadget for the current row + fn activate_gadget(&mut self, gadget: Gadget) { + // IMPROVEME: it should be called only once per row + self.selectors[gadget as usize][self.current_row] = true; + } +
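// Note: throughout this implementation, the working modulus follows the parity of current_iteration: values are reduced modulo the Fp modulus on even iterations and modulo the Fq modulus on odd ones, mirroring the curve switch of the folding scheme (see write_column and write_public_input above).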
+ fn constrain_boolean(&mut self, x: Self::Variable) { + let modulus: BigInt = if self.current_iteration % 2 == 0 { + Fp::modulus_biguint().into() + } else { + Fq::modulus_biguint().into() + }; + let x = x.mod_floor(&modulus); + assert!(x == BigInt::from(0_usize) || x == BigInt::from(1_usize)); + } + + fn constant(&self, v: BigInt) -> Self::Variable { + v + } + + fn add_constraint(&mut self, _x: Self::Variable) { + unimplemented!("Only when building the constraints") + } + + fn assert_zero(&mut self, var: Self::Variable) { + assert_eq!(var, BigInt::from(0_usize)); + } + + fn assert_equal(&mut self, x: Self::Variable, y: Self::Variable) { + assert_eq!(x, y); + } + + fn square(&mut self, pos: Self::Position, x: Self::Variable) -> Self::Variable { + let res = x.clone() * x.clone(); + self.write_column(pos, res.clone()); + res + } + + /// Flagged as unsafe as it does require an additional range check + unsafe fn bitmask_be( + &mut self, + x: &Self::Variable, + highest_bit: u32, + lowest_bit: u32, + pos: Self::Position, + ) -> Self::Variable { + let diff: u32 = highest_bit - lowest_bit; + if diff == 0 { + self.write_column(pos, BigInt::from(0_usize)) + } else { + assert!( + diff > 0, + "The difference between the highest and lowest bit should be greater than 0" + ); + let rht = (BigInt::from(1_usize) << diff) - BigInt::from(1_usize); + let lft = x >> lowest_bit; + let res: BigInt = lft & rht; + self.write_column(pos, res) + } + } + + // FIXME: for now, we use the row number and compute the square. + // This is only for testing purposes, and having something to build the + // witness. + fn fetch_input(&mut self, pos: Self::Position) -> Self::Variable { + let x = BigInt::from(self.current_row as u64); + self.write_column(pos, x.clone()); + x + } + + /// Reset the environment to build the next row + fn reset(&mut self) { + // Save the current state in the witness + self.state.iter().enumerate().for_each(|(i, x)| { + self.witness[i][self.current_row] = x.clone(); + }); + // We increment the row + // TODO: should we check that we are not going over the domain size? + self.current_row += 1; + // We reset the indices for the variables + self.idx_var = 0; + self.idx_var_next_row = 0; + self.idx_var_pi = 0; + // We keep track of the values we already set. 
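// Concretely, the next row becomes the new current row: values written via CurrOrNext::Next (by gadgets accessing "the next row", such as the elliptic curve scaling) are preserved as the starting state of the row we are about to build.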
+ self.state = self.next_state.clone(); + // And we reset the next state + self.next_state = std::array::from_fn(|_| BigInt::from(0_usize)); + } + + /// FIXME: check if we need to pick the left or right sponge + fn coin_folding_combiner(&mut self, pos: Self::Position) -> Self::Variable { + let r = if self.current_iteration % 2 == 0 { + self.sponge_e1[0].clone() + } else { + self.sponge_e2[0].clone() + }; + let (col, _) = pos; + let Column::X(idx) = col else { + unimplemented!("Only works for private columns") + }; + self.state[idx] = r.clone(); + self.r = r.clone(); + r + } + + fn load_poseidon_state(&mut self, pos: Self::Position, i: usize) -> Self::Variable { + let state = if self.current_iteration % 2 == 0 { + self.sponge_e1[i].clone() + } else { + self.sponge_e2[i].clone() + }; + self.write_column(pos, state) + } + + fn get_poseidon_round_constant( + &mut self, + pos: Self::Position, + round: usize, + i: usize, + ) -> Self::Variable { + let rc = if self.current_iteration % 2 == 0 { + poseidon_3_60_0_5_5_fp::static_params().round_constants[round][i] + .to_biguint() + .into() + } else { + poseidon_3_60_0_5_5_fq::static_params().round_constants[round][i] + .to_biguint() + .into() + }; + self.write_public_input(pos, rc) + } + + fn get_poseidon_mds_matrix(&mut self, i: usize, j: usize) -> Self::Variable { + if self.current_iteration % 2 == 0 { + poseidon_3_60_0_5_5_fp::static_params().mds[i][j] + .to_biguint() + .into() + } else { + poseidon_3_60_0_5_5_fq::static_params().mds[i][j] + .to_biguint() + .into() + } + } + + unsafe fn save_poseidon_state(&mut self, x: Self::Variable, i: usize) { + if self.current_iteration % 2 == 0 { + let modulus: BigInt = Fp::modulus_biguint().into(); + self.sponge_e1[i] = x.mod_floor(&modulus) + } else { + let modulus: BigInt = Fq::modulus_biguint().into(); + self.sponge_e2[i] = x.mod_floor(&modulus) + } + } + + // The following values are expected to be absorbed in order: + // - z0 + // - z1 + // - acc[0] + // - acc[1] + // - ... + // - acc[N_COL - 1] + // FIXME: for now, we will only absorb the accumulators as z0 and z1 are not + // updated yet. + unsafe fn fetch_value_to_absorb( + &mut self, + pos: Self::Position, + curr_round: usize, + ) -> Self::Variable { + let (col, _) = pos; + let Column::PublicInput(_idx) = col else { + panic!("Only works for public inputs") + }; + // If we are not in round 0, we must absorb nothing. + if curr_round != 0 { + self.write_public_input(pos, self.zero()) + } else { + // FIXME: we must absorb z0, z1 and i! + // We multiply by 2 as we have two coordinates + let idx = self.idx_values_to_absorb; + let res = if idx < 2 * NUMBER_OF_COLUMNS { + let idx_col = idx / 2; + debug!("Absorbing the accumulator for the column index {idx_col}. After this, there will still be {} elements to absorb", NUMBER_OF_VALUES_TO_ABSORB_PUBLIC_IO - idx - 1); + if self.current_iteration % 2 == 0 { + let (pt_x, pt_y) = self.ivc_accumulator_e2[idx_col] + .get_first_chunk() + .to_coordinates() + .unwrap(); + if idx % 2 == 0 { + self.write_public_input(pos, pt_x.to_biguint().into()) + } else { + self.write_public_input(pos, pt_y.to_biguint().into()) + } + } else { + let (pt_x, pt_y) = self.ivc_accumulator_e1[idx_col] + .get_first_chunk() + .to_coordinates() + .unwrap(); + if idx % 2 == 0 { + self.write_public_input(pos, pt_x.to_biguint().into()) + } else { + self.write_public_input(pos, pt_y.to_biguint().into()) + } + } + } else { + unimplemented!( + "We only absorb the accumulators for now. Of course, this is not sound." + ) + }; + self.idx_values_to_absorb += 1; + res + } + } +
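// To summarize the absorption schedule above: only round 0 absorbs a value per call, namely one affine coordinate of an accumulator commitment (x for an even idx_values_to_absorb, y for an odd one), i.e. 2 * NUMBER_OF_COLUMNS values in total; every other round absorbs zero.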
+ unsafe fn load_temporary_accumulators( + &mut self, + pos_x: Self::Position, + pos_y: Self::Position, + side: Side, + ) -> (Self::Variable, Self::Variable) { + match self.current_instruction { + Instruction::EllipticCurveScaling(i_comm, bit) => { + // If we're processing the leftmost bit (i.e. bit == 0), we must load + // the initial value into the accumulators from the environment. + // In the left accumulator, we keep track of the value we keep doubling. + // In the right accumulator, we keep the result. + if bit == 0 { + if self.current_iteration % 2 == 0 { + match side { + Side::Left => { + let pt = self.previous_commitments_e2[i_comm].get_first_chunk(); + // We suppose we never have a commitment equal to the + // point at infinity + let (pt_x, pt_y) = pt.to_coordinates().unwrap(); + let pt_x = self.write_column(pos_x, pt_x.to_biguint().into()); + let pt_y = self.write_column(pos_y, pt_y.to_biguint().into()); + (pt_x, pt_y) + } + // As it is the first iteration, we must use the point at infinity. + // However, to avoid handling the case equal to zero, we will + // use a blinder, which we will subtract at the end. + // As we suppose the probability of getting a folding combiner + // equal to zero is negligible, we know we have a negligible + // probability of requesting to compute `0 * P`. + // FIXME: ! check this statement ! + Side::Right => { + let pt = self.srs_e2.h; + let (pt_x, pt_y) = pt.to_coordinates().unwrap(); + let pt_x = self.write_column(pos_x, pt_x.to_biguint().into()); + let pt_y = self.write_column(pos_y, pt_y.to_biguint().into()); + (pt_x, pt_y) + } + } + } else { + match side { + Side::Left => { + let pt = self.previous_commitments_e1[i_comm].get_first_chunk(); + // We suppose we never have a commitment equal to the + // point at infinity + let (pt_x, pt_y) = pt.to_coordinates().unwrap(); + let pt_x = self.write_column(pos_x, pt_x.to_biguint().into()); + let pt_y = self.write_column(pos_y, pt_y.to_biguint().into()); + (pt_x, pt_y) + } + // As it is the first iteration, we must use the point at infinity. + // However, to avoid handling the case equal to zero, we will + // use a blinder, which we will subtract at the end. + // As we suppose the probability of getting a folding combiner + // equal to zero is negligible, we know we have a negligible + // probability of requesting to compute `0 * P`. + // FIXME: ! check this statement ! + Side::Right => { + let pt = self.srs_e1.h; + let (pt_x, pt_y) = pt.to_coordinates().unwrap(); + let pt_x = self.write_column(pos_x, pt_x.to_biguint().into()); + let pt_y = self.write_column(pos_y, pt_y.to_biguint().into()); + (pt_x, pt_y) + } + } + } + } else { + panic!("We should not load the temporary accumulators for bits other than 0 when using the elliptic curve scaling. It has been deactivated since we use the 'next row'"); + } + } +
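// In other words, the scaling r * P is computed as (r * P + H) - H, where H is the SRS blinder loaded on the Side::Right branch above: starting from H rather than the point at infinity avoids the zero case, and the extra H is meant to be subtracted at the end (see the FIXME above). For the subsequent bits, the accumulators travel through the "next row" mechanism instead of this function.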
+ Instruction::EllipticCurveAddition(i_comm) => { + // FIXME: we must get the scaled commitment, not simply the commitment + let (pt_x, pt_y): (BigInt, BigInt) = match side { + Side::Left => { + if self.current_iteration % 2 == 0 { + let pt = self.ivc_accumulator_e2[i_comm].get_first_chunk(); + let (x, y) = pt.to_coordinates().unwrap(); + (x.to_biguint().into(), y.to_biguint().into()) + } else { + let pt = self.ivc_accumulator_e1[i_comm].get_first_chunk(); + let (x, y) = pt.to_coordinates().unwrap(); + (x.to_biguint().into(), y.to_biguint().into()) + } + } + Side::Right => { + if self.current_iteration % 2 == 0 { + let pt = self.previous_commitments_e2[i_comm].get_first_chunk(); + let (x, y) = pt.to_coordinates().unwrap(); + (x.to_biguint().into(), y.to_biguint().into()) + } else { + let pt = self.previous_commitments_e1[i_comm].get_first_chunk(); + let (x, y) = pt.to_coordinates().unwrap(); + (x.to_biguint().into(), y.to_biguint().into()) + } + } + }; + let pt_x = self.write_column(pos_x, pt_x.clone()); + let pt_y = self.write_column(pos_y, pt_y.clone()); + (pt_x, pt_y) + } + _ => unimplemented!("For now, the accumulators can only be used by the elliptic curve gadgets, and {:?} is not supported. This should be changed as soon as the gadget is implemented.", self.current_instruction), + } + } + + unsafe fn save_temporary_accumulators( + &mut self, + x: Self::Variable, + y: Self::Variable, + side: Side, + ) { + match side { + Side::Left => { + self.temporary_accumulators.0 = (x, y); + } + Side::Right => { + self.temporary_accumulators.1 = (x, y); + } + } + } + + // It is unsafe as no constraint is added + unsafe fn is_same_ec_point( + &mut self, + pos: Self::Position, + x1: Self::Variable, + y1: Self::Variable, + x2: Self::Variable, + y2: Self::Variable, + ) -> Self::Variable { + let res = if x1 == x2 && y1 == y2 { + BigInt::from(1_usize) + } else { + BigInt::from(0_usize) + }; + self.write_column(pos, res) + } + + fn zero(&self) -> Self::Variable { + BigInt::from(0_usize) + } + + fn one(&self) -> Self::Variable { + BigInt::from(1_usize) + } + + /// Inverse of a variable + /// + /// # Safety + /// + /// Zero is not allowed as an input. + unsafe fn inverse(&mut self, pos: Self::Position, x: Self::Variable) -> Self::Variable { + let res = if self.current_iteration % 2 == 0 { + Fp::from_biguint(&x.to_biguint().unwrap()) + .unwrap() + .inverse() + .unwrap() + .to_biguint() + .into() + } else { + Fq::from_biguint(&x.to_biguint().unwrap()) + .unwrap() + .inverse() + .unwrap() + .to_biguint() + .into() + }; + self.write_column(pos, res) + } + + fn compute_lambda( + &mut self, + pos: Self::Position, + is_same_point: Self::Variable, + x1: Self::Variable, + y1: Self::Variable, + x2: Self::Variable, + y2: Self::Variable, + ) -> Self::Variable { + let modulus: BigInt = if self.current_iteration % 2 == 0 { + Fp::modulus_biguint().into() + } else { + Fq::modulus_biguint().into() + }; + // If it is not the same point, we compute lambda as: + // - λ = (Y1 - Y2) / (X1 - X2) + let (num, denom): (BigInt, BigInt) = if is_same_point == BigInt::from(0_usize) { + let num: BigInt = y1.clone() - y2.clone(); + let x1_minus_x2: BigInt = (x1.clone() - x2.clone()).mod_floor(&modulus); + // We temporarily store the inverse of the denominator into the + // given position. 
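// Given λ, the addition is completed with X3 = λ^2 - X1 - X2 and Y3 = λ(X1 - X3) - Y1, the same chord-and-tangent formulas instantiated by double_ec_point below for the doubling case.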
+            let denom = unsafe { self.inverse(pos, x1_minus_x2) };
+            (num, denom)
+        } else {
+            // Otherwise, we compute λ as:
+            // - λ = (3X1^2 + a) / (2Y1)
+            let denom = {
+                let double_y1 = y1.clone() + y1.clone();
+                // We temporarily store the inverse of the denominator into the
+                // given position.
+                unsafe { self.inverse(pos, double_y1) }
+            };
+            let num = {
+                let a: BigInt = if self.current_iteration % 2 == 0 {
+                    (E2::Params::COEFF_A).to_biguint().into()
+                } else {
+                    (E1::Params::COEFF_A).to_biguint().into()
+                };
+                let x1_square = x1.clone() * x1.clone();
+                let two_x1_square = x1_square.clone() + x1_square.clone();
+                two_x1_square + x1_square + a
+            };
+            (num, denom)
+        };
+        let res = (num * denom).mod_floor(&modulus);
+        self.write_column(pos, res)
+    }
+
+    /// Double the elliptic curve point given by the affine coordinates
+    /// `(x1, y1)` and save the result in the registers `pos_x` and `pos_y`.
+    fn double_ec_point(
+        &mut self,
+        pos_x: Self::Position,
+        pos_y: Self::Position,
+        x1: Self::Variable,
+        y1: Self::Variable,
+    ) -> (Self::Variable, Self::Variable) {
+        let modulus: BigInt = if self.current_iteration % 2 == 0 {
+            Fp::modulus_biguint().into()
+        } else {
+            Fq::modulus_biguint().into()
+        };
+        // - λ = (3X1^2 + a) / (2Y1)
+        // We compute λ and use an additional column as a temporary value;
+        // otherwise, we would get a constraint of degree higher than 5.
+        let lambda_pos = self.allocate();
+        let denom = {
+            let double_y1 = y1.clone() + y1.clone();
+            // We temporarily store the inverse of the denominator into the
+            // given position.
+            unsafe { self.inverse(lambda_pos, double_y1) }
+        };
+        let num = {
+            let a: BigInt = if self.current_iteration % 2 == 0 {
+                (E2::Params::COEFF_A).to_biguint().into()
+            } else {
+                (E1::Params::COEFF_A).to_biguint().into()
+            };
+            let x1_square = x1.clone() * x1.clone();
+            let two_x1_square = x1_square.clone() + x1_square.clone();
+            two_x1_square + x1_square + a
+        };
+        let lambda = (num * denom).mod_floor(&modulus);
+        self.write_column(lambda_pos, lambda.clone());
+        // - X3 = λ^2 - X1 - X2 (here X2 = X1, as we double the point)
+        let x3 = {
+            let double_x1 = x1.clone() + x1.clone();
+            let res = lambda.clone() * lambda.clone() - double_x1.clone();
+            self.write_column(pos_x, res.clone())
+        };
+        // - Y3 = λ(X1 - X3) - Y1
+        let y3 = {
+            let x1_minus_x3 = x1.clone() - x3.clone();
+            let res = lambda.clone() * x1_minus_x3 - y1.clone();
+            self.write_column(pos_y, res.clone())
+        };
+        (x3, y3)
+    }
+}
+
+impl<
+        Fp: PrimeField,
+        Fq: PrimeField,
+        E1: CommitmentCurve,
+        E2: CommitmentCurve,
+    > Env<Fp, Fq, E1, E2>
+{
+    pub fn new(
+        srs_log2_size: usize,
+        z0: BigInt,
+        sponge_e1: [BigInt; 3],
+        sponge_e2: [BigInt; 3],
+    ) -> Self {
+        {
+            assert!(Fp::MODULUS_BIT_SIZE <= MAXIMUM_FIELD_SIZE_IN_BITS.try_into().unwrap(), "The size of the field Fp is too large, it should be at most {MAXIMUM_FIELD_SIZE_IN_BITS} bits");
+            assert!(Fq::MODULUS_BIT_SIZE <= MAXIMUM_FIELD_SIZE_IN_BITS.try_into().unwrap(), "The size of the field Fq is too large, it should be at most {MAXIMUM_FIELD_SIZE_IN_BITS} bits");
+            let modulus_fp = Fp::modulus_biguint();
+            assert!(
+                (modulus_fp - BigUint::from(1_u64)).gcd(&BigUint::from(POSEIDON_ALPHA))
+                    == BigUint::from(1_u64),
+                "The modulus of Fp should be coprime with {POSEIDON_ALPHA}"
+            );
+            let modulus_fq = Fq::modulus_biguint();
+            assert!(
+                (modulus_fq - BigUint::from(1_u64)).gcd(&BigUint::from(POSEIDON_ALPHA))
+                    == BigUint::from(1_u64),
+                "The modulus of Fq should be coprime with {POSEIDON_ALPHA}"
+            );
+        }
+        let srs_size = 1 << srs_log2_size;
+        let domain_fp = EvaluationDomains::<Fp>::create(srs_size).unwrap();
+        let
domain_fq = EvaluationDomains::<Fq>::create(srs_size).unwrap();
+
+        info!("Create an SRS of size {srs_log2_size} for the first curve");
+        let srs_e1: SRS<E1> = {
+            let start = Instant::now();
+            let srs = SRS::create(srs_size);
+            debug!("SRS for E1 created in {:?}", start.elapsed());
+            let start = Instant::now();
+            srs.get_lagrange_basis(domain_fp.d1);
+            debug!("Lagrange basis for E1 added in {:?}", start.elapsed());
+            srs
+        };
+        info!("Create an SRS of size {srs_log2_size} for the second curve");
+        let srs_e2: SRS<E2> = {
+            let start = Instant::now();
+            let srs = SRS::create(srs_size);
+            debug!("SRS for E2 created in {:?}", start.elapsed());
+            let start = Instant::now();
+            srs.get_lagrange_basis(domain_fq.d1);
+            debug!("Lagrange basis for E2 added in {:?}", start.elapsed());
+            srs
+        };
+
+        let mut witness: Vec<Vec<BigInt>> = Vec::with_capacity(NUMBER_OF_COLUMNS);
+        {
+            let mut vec: Vec<BigInt> = Vec::with_capacity(srs_size);
+            (0..srs_size).for_each(|_| vec.push(BigInt::from(0_usize)));
+            (0..NUMBER_OF_COLUMNS).for_each(|_| witness.push(vec.clone()));
+        };
+
+        let mut selectors: Vec<Vec<bool>> = Vec::with_capacity(NUMBER_OF_SELECTORS);
+        {
+            let mut vec: Vec<bool> = Vec::with_capacity(srs_size);
+            (0..srs_size).for_each(|_| vec.push(false));
+            (0..NUMBER_OF_SELECTORS).for_each(|_| selectors.push(vec.clone()));
+        };
+
+        // Defaults are set to the blinders. We use the double of the blinder
+        // to make the EC scaling gadget happy.
+        let previous_commitments_e1: Vec<PolyComm<E1>> = (0..NUMBER_OF_COLUMNS)
+            .map(|_| PolyComm::new(vec![(srs_e1.h + srs_e1.h).into()]))
+            .collect();
+        let previous_commitments_e2: Vec<PolyComm<E2>> = (0..NUMBER_OF_COLUMNS)
+            .map(|_| PolyComm::new(vec![(srs_e2.h + srs_e2.h).into()]))
+            .collect();
+        // FIXME: zero will not work.
+        let ivc_accumulator_e1: Vec<PolyComm<E1>> = (0..NUMBER_OF_COLUMNS)
+            .map(|_| PolyComm::new(vec![srs_e1.h]))
+            .collect();
+        let ivc_accumulator_e2: Vec<PolyComm<E2>> = (0..NUMBER_OF_COLUMNS)
+            .map(|_| PolyComm::new(vec![srs_e2.h]))
+            .collect();
+
+        // FIXME: challenges
+        let challenges: Vec<BigInt> = vec![];
+
+        Self {
+            // -------
+            // Setup
+            domain_fp,
+            domain_fq,
+            srs_e1,
+            srs_e2,
+            // -------
+            // -------
+            // IVC only
+            ivc_accumulator_e1,
+            ivc_accumulator_e2,
+            previous_commitments_e1,
+            previous_commitments_e2,
+            // ------
+            // ------
+            idx_var: 0,
+            idx_var_next_row: 0,
+            idx_var_pi: 0,
+            current_row: 0,
+            state: std::array::from_fn(|_| BigInt::from(0_usize)),
+            next_state: std::array::from_fn(|_| BigInt::from(0_usize)),
+            public_state: std::array::from_fn(|_| BigInt::from(0_usize)),
+            selectors,
+            challenges,
+            current_instruction: IVC_STARTING_INSTRUCTION,
+            sponge_e1,
+            sponge_e2,
+            current_iteration: 0,
+            previous_hash: [0; 2],
+            r: BigInt::from(0_usize),
+            // Initialize the temporary accumulators with 0
+            temporary_accumulators: (
+                (BigInt::from(0_u64), BigInt::from(0_u64)),
+                (BigInt::from(0_u64), BigInt::from(0_u64)),
+            ),
+            idx_values_to_absorb: 0,
+            // ------
+            // ------
+            // Used by the interpreter
+            // Used to allocate variables
+            // Witness builder related
+            witness,
+            // ------
+            // Inputs
+            z0: z0.clone(),
+            zi: z0,
+            // ------
+            _marker: std::marker::PhantomData,
+        }
+    }
+
+    /// Reset the environment to build the next iteration
+    pub fn reset_for_next_iteration(&mut self) {
+        // Reset the state for the next row
+        self.current_row = 0;
+        self.state = std::array::from_fn(|_| BigInt::from(0_usize));
+        self.idx_var = 0;
+        self.current_instruction = IVC_STARTING_INSTRUCTION;
+        self.idx_values_to_absorb = 0;
+    }
+
+    /// The blinder used to commit, to avoid committing to the zero polynomial
+    /// and accumulating it in the IVC.
+    ///
+    /// It is part of the instance, and it is accumulated in the IVC.
+    pub fn accumulate_commitment_blinder(&mut self) {
+        // TODO
+    }
+
+    /// Compute the commitments to the current witness, and update the previous
+    /// instances.
+    // Might be worth renaming this function
+    pub fn compute_and_update_previous_commitments(&mut self) {
+        if self.current_iteration % 2 == 0 {
+            let comms: Vec<PolyComm<E1>> = self
+                .witness
+                .par_iter()
+                .map(|evals| {
+                    let evals: Vec<Fp> = evals
+                        .par_iter()
+                        .map(|x| Fp::from_biguint(&x.to_biguint().unwrap()).unwrap())
+                        .collect();
+                    let evals = Evaluations::from_vec_and_domain(evals.to_vec(), self.domain_fp.d1);
+                    self.srs_e1
+                        .commit_evaluations_non_hiding(self.domain_fp.d1, &evals)
+                })
+                .collect();
+            self.previous_commitments_e1 = comms
+        } else {
+            let comms: Vec<PolyComm<E2>> = self
+                .witness
+                .par_iter()
+                .map(|evals| {
+                    let evals: Vec<Fq> = evals
+                        .par_iter()
+                        .map(|x| Fq::from_biguint(&x.to_biguint().unwrap()).unwrap())
+                        .collect();
+                    let evals = Evaluations::from_vec_and_domain(evals.to_vec(), self.domain_fq.d1);
+                    self.srs_e2
+                        .commit_evaluations_non_hiding(self.domain_fq.d1, &evals)
+                })
+                .collect();
+            self.previous_commitments_e2 = comms
+        }
+    }
+
+    /// Compute the output of the application on the previous output
+    // TODO: we should compute the hash of the previous commitments, only on
+    // CPU?
+    pub fn compute_output(&mut self) {
+        self.zi = BigInt::from(42_usize)
+    }
+
+    pub fn fetch_instruction(&self) -> Instruction {
+        self.current_instruction
+    }
+
+    /// Describe the control-flow for the IVC circuit.
+    ///
+    /// For a step i + 1, the IVC circuit receives as public input the following
+    /// values:
+    ///
+    /// - The commitments to the previous witnesses.
+    /// - The previous challenges (α_{i}, β_{i}, γ_{i}) - the challenges β and γ
+    ///   are used by the permutation argument, while α is used by the quotient
+    ///   polynomial; they are generated after also absorbing the accumulator of
+    ///   the permutation argument.
+    /// - The previous accumulators (acc_1, ..., acc_17).
+    /// - The previous output z_i.
+    /// - The initial input z_0.
+    /// - The natural number i describing the previous step.
+    ///
+    /// The control flow is as follows:
+    /// - We compute the hash of the previous commitments and verify the hash
+    ///   corresponds to the public input:
+    ///
+    ///   ```text
+    ///   hash = H(i, acc_1, ..., acc_17, z_0, z_i)
+    ///   ```
+    ///
+    /// - We also have to check that the previous challenges (α, β, γ) have been
+    ///   correctly generated. Therefore, we must compute the hashes of the
+    ///   witnesses and verify they correspond to the public input.
+    ///
+    ///   TODO
+    ///
+    /// - We compute the output of the application (TODO)
+    ///
+    ///   ```text
+    ///   z_(i + 1) = F(w_i, z_i)
+    ///   ```
+    ///
+    /// - We decompose the scalar `r`, the random combiner, into bits to compute
+    ///   the MSM for the next step.
+    ///
+    /// - We compute the MSM (verifier)
+    ///
+    ///   ```text
+    ///   acc_(i + 1)_j = acc_i_j + r C_j
+    ///   ```
+    ///
+    ///   And also the cross-terms:
+    ///
+    ///   ```text
+    ///   E = E1 - r T1 - r^2 T2 - ... - r^d Td + r^(d+1) E2
+    ///     = E1 - r (T1 + r (T2 + ... + r (Td - r E2)))
+    ///   ```
+    ///
+    ///   where (d + 1) is the degree of the highest gate.
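+    ///
+    ///   As a (non-normative) sketch, the nested form above can be evaluated
+    ///   with Horner's rule, starting from the innermost term:
+    ///
+    ///   ```text
+    ///   e = Td - r E2
+    ///   for T in [T(d-1), ..., T1]: e = T + r e
+    ///   E = E1 - r e
+    ///   ```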
+    ///
+    /// - We compute the next hash we give to the next instance
+    ///
+    ///   ```text
+    ///   hash' = H(i + 1, acc'_1, ..., acc'_17, z_0, z_(i + 1))
+    ///   ```
+    pub fn fetch_next_instruction(&mut self) -> Instruction {
+        match self.current_instruction {
+            Instruction::Poseidon(i) => {
+                if i < POSEIDON_ROUNDS_FULL - 5 {
+                    Instruction::Poseidon(i + 5)
+                } else {
+                    // FIXME: we continue absorbing
+                    Instruction::Poseidon(0)
+                }
+            }
+            Instruction::EllipticCurveScaling(i_comm, bit) => {
+                // TODO: we still need to subtract (or not?) the blinder.
+                // Maybe we can avoid this by aggregating them.
+                // TODO: we also need to aggregate the cross-terms.
+                // Therefore i_comm must also take into account the number
+                // of cross-terms.
+                assert!(i_comm < NUMBER_OF_COLUMNS, "Maximum number of columns reached ({NUMBER_OF_COLUMNS}), increase the number of columns");
+                assert!(bit < MAXIMUM_FIELD_SIZE_IN_BITS, "Maximum number of bits reached ({MAXIMUM_FIELD_SIZE_IN_BITS}), increase the number of bits");
+                if bit < MAXIMUM_FIELD_SIZE_IN_BITS - 1 {
+                    Instruction::EllipticCurveScaling(i_comm, bit + 1)
+                } else if i_comm < NUMBER_OF_COLUMNS - 1 {
+                    Instruction::EllipticCurveScaling(i_comm + 1, 0)
+                } else {
+                    // We have computed all the bits for all the columns
+                    Instruction::NoOp
+                }
+            }
+            Instruction::EllipticCurveAddition(i_comm) => {
+                if i_comm < NUMBER_OF_COLUMNS - 1 {
+                    Instruction::EllipticCurveAddition(i_comm + 1)
+                } else {
+                    Instruction::NoOp
+                }
+            }
+            Instruction::NoOp => Instruction::NoOp,
+        }
+    }
+}
diff --git a/arrabiata/tests/constraints.rs b/arrabiata/tests/constraints.rs
new file mode 100644
index 0000000000..8506e5a183
--- /dev/null
+++ b/arrabiata/tests/constraints.rs
@@ -0,0 +1,166 @@
+use num_bigint::BigInt;
+use std::collections::HashMap;
+
+use arrabiata::{
+    columns::Gadget,
+    constraints,
+    interpreter::{self, Instruction},
+    poseidon_3_60_0_5_5_fp, poseidon_3_60_0_5_5_fq,
+};
+use mina_curves::pasta::fields::{Fp, Fq};
+
+fn helper_compute_constraints_gadget(instr: Instruction, exp_constraints: usize) {
+    let mut constraints_fp = {
+        let poseidon_mds = poseidon_3_60_0_5_5_fp::static_params().mds.clone();
+        constraints::Env::<Fp>::new(poseidon_mds.to_vec(), BigInt::from(0_usize))
+    };
+
+    interpreter::run_ivc(&mut constraints_fp, instr);
+    assert_eq!(constraints_fp.constraints.len(), exp_constraints);
+
+    let mut constraints_fq = {
+        let poseidon_mds = poseidon_3_60_0_5_5_fq::static_params().mds.clone();
+        constraints::Env::<Fq>::new(poseidon_mds.to_vec(), BigInt::from(0_usize))
+    };
+    interpreter::run_ivc(&mut constraints_fq, instr);
+    assert_eq!(constraints_fq.constraints.len(), exp_constraints);
+}
+
+fn helper_check_expected_degree_constraints(instr: Instruction, exp_degrees: HashMap<u64, usize>) {
+    let mut constraints_fp = {
+        let poseidon_mds = poseidon_3_60_0_5_5_fp::static_params().mds.clone();
+        constraints::Env::<Fp>::new(poseidon_mds.to_vec(), BigInt::from(0_usize))
+    };
+    interpreter::run_ivc(&mut constraints_fp, instr);
+
+    let mut actual_degrees: HashMap<u64, usize> = HashMap::new();
+    constraints_fp.constraints.iter().for_each(|c| {
+        let degree = c.degree(1, 0);
+        let count = actual_degrees.entry(degree).or_insert(0);
+        *count += 1;
+    });
+
+    exp_degrees.iter().for_each(|(degree, count)| {
+        assert_eq!(
+            actual_degrees.get(degree),
+            Some(count),
+            "Instruction {:?}: invalid count for degree {} (computed: {}, expected: {})",
+            instr,
+            degree,
+            actual_degrees.get(degree).unwrap(),
+            count
+        );
+    });
+}
+
+// Helper to verify the number of columns each gadget uses
+fn
helper_gadget_number_of_columns_used(
+    instr: Instruction,
+    exp_nb_columns: usize,
+    exp_nb_public_input: usize,
+) {
+    let mut constraints_fp = {
+        let poseidon_mds = poseidon_3_60_0_5_5_fp::static_params().mds.clone();
+        constraints::Env::<Fp>::new(poseidon_mds.to_vec(), BigInt::from(0_usize))
+    };
+    interpreter::run_ivc(&mut constraints_fp, instr);
+
+    let nb_columns = constraints_fp.idx_var;
+    assert_eq!(nb_columns, exp_nb_columns);
+
+    let nb_public_input = constraints_fp.idx_var_pi;
+    assert_eq!(nb_public_input, exp_nb_public_input);
+}
+
+fn helper_check_gadget_activated(instr: Instruction, gadget: Gadget) {
+    let mut constraints_fp = {
+        let poseidon_mds = poseidon_3_60_0_5_5_fp::static_params().mds.clone();
+        constraints::Env::<Fp>::new(poseidon_mds.to_vec(), BigInt::from(0_usize))
+    };
+    interpreter::run_ivc(&mut constraints_fp, instr);
+
+    assert_eq!(constraints_fp.activated_gadget, Some(gadget));
+}
+
+#[test]
+fn test_gadget_poseidon_next_row() {
+    let instr = Instruction::Poseidon(0);
+    helper_compute_constraints_gadget(instr, 15);
+
+    let mut exp_degrees = HashMap::new();
+    exp_degrees.insert(5, 15);
+    helper_check_expected_degree_constraints(instr, exp_degrees);
+
+    helper_gadget_number_of_columns_used(instr, 15, 17);
+
+    // We always have 2 additional public inputs, even if set to 0
+    let instr = Instruction::Poseidon(1);
+    helper_gadget_number_of_columns_used(instr, 15, 17);
+
+    helper_check_gadget_activated(instr, Gadget::Poseidon);
+}
+
+#[test]
+fn test_gadget_elliptic_curve_addition() {
+    let instr = Instruction::EllipticCurveAddition(0);
+    helper_compute_constraints_gadget(instr, 3);
+
+    let mut exp_degrees = HashMap::new();
+    exp_degrees.insert(3, 1);
+    exp_degrees.insert(2, 2);
+    helper_check_expected_degree_constraints(instr, exp_degrees);
+
+    helper_gadget_number_of_columns_used(instr, 8, 0);
+
+    helper_check_gadget_activated(instr, Gadget::EllipticCurveAddition);
+}
+
+#[test]
+fn test_ivc_total_number_of_constraints_ivc() {
+    let constraints_fp = {
+        let poseidon_mds = poseidon_3_60_0_5_5_fp::static_params().mds.clone();
+        constraints::Env::<Fp>::new(poseidon_mds.to_vec(), BigInt::from(0_usize))
+    };
+
+    let constraints = constraints_fp.get_all_constraints_for_ivc();
+    assert_eq!(constraints.len(), 28);
+}
+
+#[test]
+fn test_degree_of_constraints_ivc() {
+    let constraints_fp = {
+        let poseidon_mds = poseidon_3_60_0_5_5_fp::static_params().mds.clone();
+        constraints::Env::<Fp>::new(poseidon_mds.to_vec(), BigInt::from(0_usize))
+    };
+
+    let constraints = constraints_fp.get_all_constraints_for_ivc();
+
+    let mut degree_per_constraints = HashMap::new();
+    constraints.iter().for_each(|c| {
+        let degree = c.degree(1, 0);
+        let count = degree_per_constraints.entry(degree).or_insert(0);
+        *count += 1;
+    });
+
+    assert_eq!(degree_per_constraints.get(&1), Some(&1));
+    assert_eq!(degree_per_constraints.get(&2), Some(&11));
+    assert_eq!(degree_per_constraints.get(&3), Some(&1));
+    assert_eq!(degree_per_constraints.get(&4), None);
+    assert_eq!(degree_per_constraints.get(&5), Some(&15));
+}
+
+#[test]
+fn test_gadget_elliptic_curve_scaling() {
+    let instr = Instruction::EllipticCurveScaling(0, 0);
+    // FIXME: update when the gadget is finished
+    helper_compute_constraints_gadget(instr, 10);
+
+    let mut exp_degrees = HashMap::new();
+    exp_degrees.insert(1, 1);
+    exp_degrees.insert(2, 9);
+    helper_check_expected_degree_constraints(instr, exp_degrees);
+
+    helper_gadget_number_of_columns_used(instr, 10, 0);
+
+    helper_check_gadget_activated(instr, Gadget::EllipticCurveScaling);
+}
diff --git
a/arrabiata/tests/utils.rs b/arrabiata/tests/utils.rs
new file mode 100644
index 0000000000..67dd557e40
--- /dev/null
+++ b/arrabiata/tests/utils.rs
@@ -0,0 +1,18 @@
+//! This file aims to test the correctness of the functions from external
+//! libraries that are used in the project.
+//! It is mostly to ensure that the assumptions we make about the external
+//! libraries are correct.
+
+use num_bigint::BigInt;
+use num_integer::Integer;
+
+// Testing that modulo on negative numbers gives a positive value.
+// It is extensively used in the witness generation, therefore checking this
+// assumption is important.
+#[test]
+fn test_biguint_from_bigint() {
+    let a = BigInt::from(-9);
+    let modulus = BigInt::from(10);
+    let a = a.mod_floor(&modulus);
+    assert_eq!(a, BigInt::from(1));
+}
diff --git a/arrabiata/tests/witness.rs b/arrabiata/tests/witness.rs
new file mode 100644
index 0000000000..5c8239d889
--- /dev/null
+++ b/arrabiata/tests/witness.rs
@@ -0,0 +1,244 @@
+use ark_ec::{AffineRepr, Group};
+use ark_ff::{PrimeField, UniformRand};
+use arrabiata::{
+    interpreter::{self, Instruction, InterpreterEnv},
+    poseidon_3_60_0_5_5_fp,
+    witness::Env,
+    MAXIMUM_FIELD_SIZE_IN_BITS, POSEIDON_ROUNDS_FULL, POSEIDON_STATE_SIZE,
+};
+use mina_curves::pasta::{Fp, Fq, Pallas, ProjectivePallas, Vesta};
+use mina_poseidon::{constants::SpongeConstants, permutation::poseidon_block_cipher};
+use num_bigint::{BigInt, ToBigInt};
+use o1_utils::FieldHelpers;
+use poly_commitment::{commitment::CommitmentCurve, PolyComm};
+use rand::{CryptoRng, RngCore};
+
+// Used by the mina_poseidon library. Only for testing.
+#[derive(Clone)]
+pub struct PlonkSpongeConstants {}
+
+impl SpongeConstants for PlonkSpongeConstants {
+    const SPONGE_CAPACITY: usize = 1;
+    const SPONGE_WIDTH: usize = POSEIDON_STATE_SIZE;
+    const SPONGE_RATE: usize = 2;
+    const PERM_ROUNDS_FULL: usize = POSEIDON_ROUNDS_FULL;
+    const PERM_ROUNDS_PARTIAL: usize = 0;
+    const PERM_HALF_ROUNDS_FULL: usize = 0;
+    const PERM_SBOX: u32 = 5;
+    const PERM_FULL_MDS: bool = true;
+    const PERM_INITIAL_ARK: bool = false;
+}
+
+#[test]
+fn test_unit_witness_poseidon_next_row_gadget_one_full_hash() {
+    let srs_log2_size = 6;
+    let sponge: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| BigInt::from(42u64));
+    let mut env = Env::<Fp, Fq, Vesta, Pallas>::new(
+        srs_log2_size,
+        BigInt::from(1u64),
+        sponge.clone(),
+        sponge.clone(),
+    );
+
+    env.current_instruction = Instruction::Poseidon(0);
+
+    (0..(POSEIDON_ROUNDS_FULL / 5)).for_each(|i| {
+        interpreter::run_ivc(&mut env, Instruction::Poseidon(5 * i));
+        env.reset();
+    });
+    let exp_output = {
+        let mut state = sponge
+            .clone()
+            .to_vec()
+            .iter()
+            .map(|x| Fp::from_biguint(&x.to_biguint().unwrap()).unwrap())
+            .collect::<Vec<Fp>>();
+        state[0] += env.srs_e2.h.x;
+        state[1] += env.srs_e2.h.y;
+        poseidon_block_cipher::<Fp, PlonkSpongeConstants>(
+            poseidon_3_60_0_5_5_fp::static_params(),
+            &mut state,
+        );
+        state
+            .iter()
+            .map(|x| x.to_biguint().into())
+            .collect::<Vec<BigInt>>()
+    };
+    // Check correctness for current iteration
+    assert_eq!(env.sponge_e1.to_vec(), exp_output);
+    // Check the other sponge hasn't been modified
+    assert_eq!(env.sponge_e2, sponge.clone());
+
+    // Number of rows used by one full hash
+    assert_eq!(env.current_row, 13);
+}
+
+#[test]
+fn test_unit_witness_elliptic_curve_addition() {
+    let srs_log2_size = 6;
+    let sponge_e1: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| BigInt::from(42u64));
+    let mut env = Env::<Fp, Fq, Vesta, Pallas>::new(
+        srs_log2_size,
+        BigInt::from(1u64),
+        sponge_e1.clone(),
+        sponge_e1.clone(),
+    );
+
+    let instr =
Instruction::EllipticCurveAddition(0);
+    env.current_instruction = instr;
+
+    // If we are at iteration 0, we will compute the addition of points over
+    // Pallas, whose base field is Fp.
+    assert_eq!(env.current_iteration, 0);
+    let (exp_x3, exp_y3) = {
+        let res: Pallas = (env.ivc_accumulator_e2[0].get_first_chunk()
+            + env.previous_commitments_e2[0].get_first_chunk())
+        .into();
+        let (x3, y3) = res.to_coordinates().unwrap();
+        (
+            x3.to_biguint().to_bigint().unwrap(),
+            y3.to_biguint().to_bigint().unwrap(),
+        )
+    };
+    interpreter::run_ivc(&mut env, instr);
+    assert_eq!(exp_x3, env.state[6], "The x coordinate is incorrect");
+    assert_eq!(exp_y3, env.state[7], "The y coordinate is incorrect");
+
+    env.reset();
+    env.reset_for_next_iteration();
+    env.current_instruction = instr;
+    env.current_iteration += 1;
+
+    assert_eq!(env.current_iteration, 1);
+    let (exp_x3, exp_y3) = {
+        let res: Vesta = (env.ivc_accumulator_e1[0].get_first_chunk()
+            + env.previous_commitments_e1[0].get_first_chunk())
+        .into();
+        let (x3, y3) = res.to_coordinates().unwrap();
+        (
+            x3.to_biguint().to_bigint().unwrap(),
+            y3.to_biguint().to_bigint().unwrap(),
+        )
+    };
+    interpreter::run_ivc(&mut env, instr);
+    assert_eq!(exp_x3, env.state[6], "The x coordinate is incorrect");
+    assert_eq!(exp_y3, env.state[7], "The y coordinate is incorrect");
+
+    env.reset();
+    env.reset_for_next_iteration();
+    env.current_instruction = instr;
+    env.current_iteration += 1;
+
+    assert_eq!(env.current_iteration, 2);
+    let (exp_x3, exp_y3) = {
+        let res: Pallas = (env.ivc_accumulator_e2[0].get_first_chunk()
+            + env.previous_commitments_e2[0].get_first_chunk())
+        .into();
+        let (x3, y3) = res.to_coordinates().unwrap();
+        (
+            x3.to_biguint().to_bigint().unwrap(),
+            y3.to_biguint().to_bigint().unwrap(),
+        )
+    };
+    interpreter::run_ivc(&mut env, instr);
+
+    assert_eq!(exp_x3, env.state[6], "The x coordinate is incorrect");
+    assert_eq!(exp_y3, env.state[7], "The y coordinate is incorrect");
+}
+
+#[test]
+fn test_witness_double_elliptic_curve_point() {
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let srs_log2_size = 6;
+    let sponge_e1: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| BigInt::from(42u64));
+    let mut env = Env::<Fp, Fq, Vesta, Pallas>::new(
+        srs_log2_size,
+        BigInt::from(1u64),
+        sponge_e1.clone(),
+        sponge_e1.clone(),
+    );
+
+    env.current_instruction = Instruction::EllipticCurveAddition(0);
+
+    // Generate a random point
+    let p1: Pallas = {
+        let x = Fq::rand(&mut rng);
+        Pallas::generator().mul_bigint(x.into_bigint()).into()
+    };
+
+    // Doubling in the environment
+    let pos_x = env.allocate();
+    let pos_y = env.allocate();
+    let p1_x = env.write_column(pos_x, p1.x.to_biguint().into());
+    let p1_y = env.write_column(pos_y, p1.y.to_biguint().into());
+    let (res_x, res_y) = env.double_ec_point(pos_x, pos_y, p1_x, p1_y);
+
+    let exp_res: Pallas = (p1 + p1).into();
+    let exp_x: BigInt = exp_res.x.to_biguint().into();
+    let exp_y: BigInt = exp_res.y.to_biguint().into();
+
+    assert_eq!(res_x, exp_x, "The x coordinate is incorrect");
+    assert_eq!(res_y, exp_y, "The y coordinate is incorrect");
+}
+
+fn helper_elliptic_curve_scalar_multiplication<RNG>(r: BigInt, rng: &mut RNG)
+where
+    RNG: RngCore + CryptoRng,
+{
+    let srs_log2_size = 10;
+    let sponge_e1: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| r.clone());
+    let mut env = Env::<Fp, Fq, Vesta, Pallas>::new(
+        srs_log2_size,
+        BigInt::from(1u64),
+        sponge_e1.clone(),
+        sponge_e1.clone(),
+    );
+
+    let i_comm = 0;
+    let p1: Pallas = {
+        let x = Fq::rand(rng);
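+        // A random point, obtained by scaling the generator by a random
+        // scalar.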
Pallas::generator().mul_bigint(x.into_bigint()).into()
+    };
+    env.previous_commitments_e2[0] = PolyComm::new(vec![p1]);
+
+    // We only go up to the maximum field size in bits.
+    (0..MAXIMUM_FIELD_SIZE_IN_BITS).for_each(|bit_idx| {
+        let instr = Instruction::EllipticCurveScaling(i_comm, bit_idx);
+        env.current_instruction = instr;
+        interpreter::run_ivc(&mut env, instr);
+        env.reset();
+    });
+
+    let res_x: BigInt = env.state[0].clone();
+    let res_y: BigInt = env.state[1].clone();
+
+    let p1_proj: ProjectivePallas = p1.into();
+    // @volhovm TODO check if mul_bigint is what was intended
+    let p1_r: Pallas = p1_proj.mul_bigint(r.clone().to_u64_digits().1).into();
+    let exp_res: Pallas = (p1_r + env.srs_e2.h).into();
+
+    let exp_x: BigInt = exp_res.x.to_biguint().into();
+    let exp_y: BigInt = exp_res.y.to_biguint().into();
+    assert_eq!(res_x, exp_x, "The x coordinate is incorrect");
+    assert_eq!(res_y, exp_y, "The y coordinate is incorrect");
+}
+
+#[test]
+fn test_witness_elliptic_curve_scalar_multiplication() {
+    let mut rng = o1_utils::tests::make_test_rng(None);
+
+    // We start with doubling
+    helper_elliptic_curve_scalar_multiplication(BigInt::from(2u64), &mut rng);
+
+    // Special case of the identity
+    helper_elliptic_curve_scalar_multiplication(BigInt::from(1u64), &mut rng);
+
+    helper_elliptic_curve_scalar_multiplication(BigInt::from(3u64), &mut rng);
+
+    // The answer to everything
+    helper_elliptic_curve_scalar_multiplication(BigInt::from(42u64), &mut rng);
+
+    // A random scalar
+    let r: BigInt = Fp::rand(&mut rng).to_biguint().to_bigint().unwrap();
+    helper_elliptic_curve_scalar_multiplication(r, &mut rng);
+}
diff --git a/arrabiata/tests/witness_utils.rs b/arrabiata/tests/witness_utils.rs
new file mode 100644
index 0000000000..7e6c78e876
--- /dev/null
+++ b/arrabiata/tests/witness_utils.rs
@@ -0,0 +1,72 @@
+//! Tests for utility functions in the witness environment.
+//! Utilities are defined as functions that the user only calls indirectly.
+//! A user is expected to use the gadget methods.
+//! The API of the utilities is more subject to change.
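+//! For example, `write_column` reduces its input modulo the field before
+//! storing it; this behaviour is pinned down by the tests below.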
+
+use arrabiata::{interpreter::InterpreterEnv, witness::Env, POSEIDON_STATE_SIZE};
+use mina_curves::pasta::{Fp, Fq, Pallas, Vesta};
+use num_bigint::BigInt;
+use o1_utils::FieldHelpers;
+
+#[test]
+#[should_panic]
+fn test_constrain_boolean_witness_negative_value() {
+    let srs_log2_size = 2;
+    let mut env = {
+        let z0 = BigInt::from(1u64);
+        let sponge_e1: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| BigInt::from(0u64));
+        Env::<Fp, Fq, Vesta, Pallas>::new(srs_log2_size, z0, sponge_e1.clone(), sponge_e1.clone())
+    };
+
+    env.constrain_boolean(BigInt::from(-42));
+}
+
+#[test]
+fn test_constrain_boolean_witness_positive_and_negative_modulus() {
+    let srs_log2_size = 2;
+    let mut env = {
+        let z0 = BigInt::from(1u64);
+        let sponge_e1: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| BigInt::from(0u64));
+        Env::<Fp, Fq, Vesta, Pallas>::new(srs_log2_size, z0, sponge_e1.clone(), sponge_e1.clone())
+    };
+
+    let modulus: BigInt = Fp::modulus_biguint().into();
+    env.constrain_boolean(modulus.clone());
+    env.constrain_boolean(modulus.clone() + BigInt::from(1u64));
+    env.constrain_boolean(-modulus.clone());
+    env.constrain_boolean(-modulus.clone() + BigInt::from(1u64));
+}
+
+#[test]
+fn test_write_column_return_the_result_reduced_in_field() {
+    let srs_log2_size = 6;
+    let sponge_e1: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| BigInt::from(42u64));
+    let mut env = Env::<Fp, Fq, Vesta, Pallas>::new(
+        srs_log2_size,
+        BigInt::from(1u64),
+        sponge_e1.clone(),
+        sponge_e1.clone(),
+    );
+    let modulus: BigInt = Fp::modulus_biguint().into();
+    let pos_x = env.allocate();
+    let res = env.write_column(pos_x, modulus.clone() + BigInt::from(1u64));
+    assert_eq!(res, BigInt::from(1u64));
+    assert_eq!(env.state[0], BigInt::from(1u64));
+}
+
+#[test]
+fn test_write_public_return_the_result_reduced_in_field() {
+    let srs_log2_size = 6;
+    let sponge_e1: [BigInt; POSEIDON_STATE_SIZE] = std::array::from_fn(|_i| BigInt::from(42u64));
+    let mut env = Env::<Fp, Fq, Vesta, Pallas>::new(
+        srs_log2_size,
+        BigInt::from(1u64),
+        sponge_e1.clone(),
+        sponge_e1.clone(),
+    );
+    let modulus: BigInt = Fp::modulus_biguint().into();
+    let pos_x = env.allocate_public_input();
+    let res = env.write_public_input(pos_x, modulus.clone() + BigInt::from(1u64));
+    assert_eq!(res, BigInt::from(1u64));
+    assert_eq!(env.public_state[0], BigInt::from(1u64));
+}
diff --git a/book/Makefile b/book/Makefile
index 5416d1b7a3..1971eba6a1 100644
--- a/book/Makefile
+++ b/book/Makefile
@@ -59,4 +59,4 @@ serve: check
 #
 clean:
-	mdbook clean
\ No newline at end of file
+	mdbook clean
diff --git a/book/macros.txt b/book/macros.txt
index 26382f28ac..64e11da633 100644
--- a/book/macros.txt
+++ b/book/macros.txt
@@ -65,4 +65,4 @@
 \prg:{\small \mathsf{PRG}}
 \plonk:\mathcal{PlonK}
-\plookup:\textsf{Plookup}
\ No newline at end of file
+\plookup:\textsf{Plookup}
diff --git a/book/specifications/README.md b/book/specifications/README.md
index 820421c678..0afc221331 100644
--- a/book/specifications/README.md
+++ b/book/specifications/README.md
@@ -15,6 +15,11 @@ Install cargo-spec:
 $ cargo install cargo-spec
 ```
 
+You will also need markdownlint:
+```console
+$ npm install markdownlint-cli -g
+```
+
 ## Render
 
 To produce a specification, simply run the following command:
diff --git a/book/specifications/kimchi/Makefile b/book/specifications/kimchi/Makefile
index f957c4e2c6..e058f2c65c 100644
--- a/book/specifications/kimchi/Makefile
+++ b/book/specifications/kimchi/Makefile
@@ -9,15 +9,17 @@ OUT_FILE := ../../src/specs/kimchi.md
 # - MD012: Multiple consecutive blank lines
 # · because the template file can look
awkward sometimes
 # - MD013: Line length
-# · because we do not want to break the lines when we spec long tables
+# · because we do not want to break the lines when we spec long tables
 # - MD024: Multiple headers with the same content
 # · because otherwise we cannot have a common structure for the sub-headers in FFMul and FFAdd
-# - MD049: Emphasis style should be consistent
-# · because it interprets the underscore in LaTeX as emphasis
+# - MD049: Emphasis style should be consistent
+# · because it interprets the underscore in LaTeX as emphasis
+# - MD056: Table column count
+# · because the check fails on the hetzner CI runner otherwise; the failure seems to be spurious
 
 build:
 	cargo spec build --output-file $(OUT_FILE)
 	@which markdownlint &>/dev/null || (echo "Missing markdownlint-cli dependency: npm install -g markdownlint-cli" && exit 1)
-	markdownlint --ignore node_modules --disable=MD012 --disable=MD024 --disable=MD049 --disable=MD013 --disable=MD010 $(OUT_FILE)
+	markdownlint --ignore node_modules --disable=MD012 --disable=MD024 --disable=MD049 --disable=MD013 --disable=MD010 --disable=MD056 $(OUT_FILE)
 
 # watches specification-related files and rebuilds them on the fly
 watch:
diff --git a/book/specifications/kimchi/Specification.toml b/book/specifications/kimchi/Specification.toml
index 7ec3ec05fa..01b51bd184 100644
--- a/book/specifications/kimchi/Specification.toml
+++ b/book/specifications/kimchi/Specification.toml
@@ -32,6 +32,7 @@ table_xor = "../../../kimchi/src/circuits/lookup/tables/xor.rs"
 table_12bit = "../../../kimchi/src/circuits/lookup/tables/range_check.rs"
 lookup = "../../../kimchi/src/circuits/lookup/constraints.rs"
 lookup_index = "../../../kimchi/src/circuits/lookup/index.rs"
+runtime_tables = "../../../kimchi/src/circuits/lookup/runtime_tables.rs"
 
 # setup
 constraint_system = "../../../kimchi/src/circuits/constraints.rs"
diff --git a/book/specifications/kimchi/template.md b/book/specifications/kimchi/template.md
index f40ef6e5ce..95a9f8fe50 100644
--- a/book/specifications/kimchi/template.md
+++ b/book/specifications/kimchi/template.md
@@ -195,7 +195,7 @@ Similarly to the generic gate, each value taking part in a lookup can be scaled
 
 The lookup functionality is an opt-in feature of kimchi that can be used by custom gates. From the user's perspective, not using any gates that make use of lookups means that the feature will be disabled and there will be no overhead to the protocol.
 
-Refer to the [lookup RFC](../kimchi/lookup.md) for an overview of the lookup feature.
+Please refer to the [lookup RFC](../kimchi/lookup.md) for an overview of the lookup feature.
 
 In this section, we describe the tables kimchi supports, as well as the different lookup selectors (and their associated queries)
 
@@ -213,6 +213,12 @@ Kimchi currently supports two lookup tables:
 
 {sections.table_12bit}
 
+##### Runtime tables
+
+Another type of lookup table has been suggested in the [Extended Lookup Tables](../kimchi/extended-lookup-tables.md) proposal.
+
+{sections.runtime_tables}
+
 #### The Lookup Selectors
 
 **XorSelector**. Performs 4 queries to the XOR lookup table.
 
@@ -292,7 +298,7 @@ Here we describe basic gadgets that we build using a combination of the gates de
 
 In this section we specify the setup that goes into creating two indexes from a circuit:
 
-* A [*prover index*](#prover-index), necessary for the prover to to create proofs.
+* A [*prover index*](#prover-index), necessary for the prover to create proofs.
 * A [*verifier index*](#verifier-index), necessary for the verifier to verify proofs.
```admonish
@@ -377,7 +383,7 @@ Same as the prover index, we have a number of pre-computations as part of the ve
 
 ## Proof Construction & Verification
 
-Originally, kimchi is based on an interactive protocol that was transformed into a non-interactive one using the [Fiat-Shamir](https://o1-labs.github.io/mina-book/crypto/plonk/fiat_shamir.html) transform.
+Originally, kimchi is based on an interactive protocol that was transformed into a non-interactive one using the [Fiat-Shamir](https://o1-labs.github.io/proof-systems/plonk/fiat_shamir.html) transform.
 For this reason, it can be useful to visualize the high-level interactive protocol before the transformation:
 
 ```mermaid
diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md
index f30839da67..be318bb0a6 100644
--- a/book/src/SUMMARY.md
+++ b/book/src/SUMMARY.md
@@ -7,7 +7,7 @@
 - [Terminology](./fundamentals/zkbook_foundations.md)
 - [Groups](./fundamentals/zkbook_groups.md)
 - [Rings](./fundamentals/zkbook_rings.md)
-- [Fields](./fundamentals/zkbook.md)
+- [Fields](./fundamentals/zkbook_fields.md)
 - [Polynomials](./fundamentals/zkbook_polynomials.md)
 - [Multiplying Polynomials](./fundamentals/zkbook_multiplying_polynomials.md)
 - [Fast Fourier Transform](./fundamentals/zkbook_fft.md)
@@ -62,9 +62,12 @@
 
 # Pickles & Inductive Proof Systems
 
-- [Overview](./fundamentals/zkbook_ips.md)
+- [Overview](./pickles/overview.md)
+- [Inductive Proof Systems](./pickles/zkbook_ips.md)
 - [Accumulation](./pickles/accumulation.md)
 - [Deferred Computation](./pickles/deferred.md)
+- [Technical Diagrams](./pickles/diagrams.md)
+
 
 # Technical Specifications
diff --git a/book/src/fundamentals/custom_constraints.md b/book/src/fundamentals/custom_constraints.md
deleted file mode 100644
index 8773465d1c..0000000000
--- a/book/src/fundamentals/custom_constraints.md
+++ /dev/null
@@ -1 +0,0 @@
-# Custom constraints
diff --git a/book/src/fundamentals/zkbook_2pc/fkaes.md b/book/src/fundamentals/zkbook_2pc/fkaes.md
index 78208909c7..9c7003483a 100644
--- a/book/src/fundamentals/zkbook_2pc/fkaes.md
+++ b/book/src/fundamentals/zkbook_2pc/fkaes.md
@@ -8,4 +8,4 @@ Several instantiations of hash function based on fix-key AES are proposed inspir
 
 The TCCR hash based on fixed-key AES is defined as follows.
 $$\sH(i,x) = \mathsf{TCCR\_Hash}(i,x) = \pi(\pi(x)\oplus i)\oplus\pi(x),$$
-where $x$ is a $128$-bit string, $i$ a public $128$-bit $\mathsf{tweak}$, and $\pi(x) = \aes(k,x)$ for any fixed-key $k$.
\ No newline at end of file
+where $x$ is a $128$-bit string, $i$ a public $128$-bit $\mathsf{tweak}$, and $\pi(x) = \aes(k,x)$ for any fixed-key $k$.
diff --git a/book/src/fundamentals/zkbook.md b/book/src/fundamentals/zkbook_fields.md
similarity index 99%
rename from book/src/fundamentals/zkbook.md
rename to book/src/fundamentals/zkbook_fields.md
index a188640b1f..b378312964 100644
--- a/book/src/fundamentals/zkbook.md
+++ b/book/src/fundamentals/zkbook_fields.md
@@ -156,5 +156,3 @@ Actually, on a practical level, it's more accurate to model the complexity in te
 
 As a result you can see that the smaller $n$ is, the better, especially with respect to multiplication, which dominates performance considerations for implementations of zk-SNARKs, since they are dominated by elliptic curve operations that consist of field operations.
 
 While still in development, Mina used to use a field of 753 bits or 12 limbs and now uses a field of 255 bits or 4 limbs.
As a result, field multiplication became automatically sped up by a factor of $12^2 / 4^2 = 9$, so you can see it's very useful to try to shrink the field size.
-
-
diff --git a/book/src/fundamentals/zkbook_groups.md b/book/src/fundamentals/zkbook_groups.md
index cd16a26946..5910bbc8d4 100644
--- a/book/src/fundamentals/zkbook_groups.md
+++ b/book/src/fundamentals/zkbook_groups.md
@@ -112,7 +112,7 @@ Another really important assumption is the no-relation assumption (TODO: name).
 
 Basically what this says is that randomly selected group elements have no efficiently computable linear relations between them. Formally, letting $G$ be a group and $\mu$ a probability distribution on $G$, and $\mathcal{A}$ a set of programs (with resource bounds) that take in a list of group elements as inputs and outputs a list of integers of the same length.
 
-Then the $(G, \mu, \mathcal{A}, N,\varepsilon)$-no-relation assumption says for all $(P,t,m) \in \mathcal{A}$, for any $n \leq N$, if you sample $g_1, \dots, g_n$ according to $G$, letting $[e_1, \dots, e_n] = P([g_1, \dots, g_n])$, it is not the case that
+Then the $(G, \mu, \mathcal{A}, N,\varepsilon)$-no-relation assumption says for all $(P,t,m) \in \mathcal{A}$, for any $n \leq N$, if you sample $g_1, \dots, g_n$ according to $\mu$, letting $[e_1, \dots, e_n] = P([g_1, \dots, g_n])$, it is not the case that
 
 $$
 e_1 \cdot g_1 + \dots + e_n \cdot g_n = 0
diff --git a/book/src/fundamentals/zkbook_multiplying_polynomials.md b/book/src/fundamentals/zkbook_multiplying_polynomials.md
index 92fb4e8325..260b0e0442 100644
--- a/book/src/fundamentals/zkbook_multiplying_polynomials.md
+++ b/book/src/fundamentals/zkbook_multiplying_polynomials.md
@@ -7,7 +7,7 @@ The heart of Cooley-Tukey FFT is actually about converting between coefficient a
 
 Given polynomials $p$ and $q$ in dense coefficient representation, it works like this.
 1. Convert $p$ and $q$ from coefficient to evaluation form in $O(n\log n)$ using Cooley-Tukey FFT
 2. Compute $r = p*q$ in evaluation form by multiplying their points pairwise in $O(n)$
-3. Convert $r$ back to to coefficient form in $O(n\log n)$ using the inverse Cooley-Tukey FFT
+3. Convert $r$ back to coefficient form in $O(n\log n)$ using the inverse Cooley-Tukey FFT
 
 The key observation is that we can choose any $n$ distinct evaluation points to represent any degree $n - 1$ polynomial. The Cooley-Tukey FFT works by selecting points that yield an efficient FFT algorithm. These points are fixed and work for any polynomials of a given degree.
diff --git a/book/src/img/xor.png b/book/src/img/xor.png
deleted file mode 100644
index 28fb7ae536..0000000000
Binary files a/book/src/img/xor.png and /dev/null differ
diff --git a/book/src/kimchi/arguments.md b/book/src/kimchi/arguments.md
index 1867905caa..9040bfb94f 100644
--- a/book/src/kimchi/arguments.md
+++ b/book/src/kimchi/arguments.md
@@ -5,6 +5,7 @@ In the previous section we saw how we can prove that certain equations hold for
 But first, let's recall the table that summarizes some important notation that will be used extensively:
 
 ![kimchi](../img/kimchi.png)
+
 One of the first ideas that we must grasp is the notion of a **circuit**. A circuit can be thought of as a set of gates with wires connections between them. The simplest type of circuit that one could think of is a boolean circuit. Boolean circuits only have binary values: $1$ and $0$, `true` and `false`. From a very high level, boolean circuits are like an intricate network of pipes, and the values could be seen as _water_ or _no water_.
Then, gates will be like stopcocks, making water flow or not between the pipes. [This video](https://twitter.com/i/status/1188749430020698112) is a cool representation of this idea. Then, each of these _behaviours_ will represent a gate (i.e. logic gates). One can have circuits that can perform more operations, for instance arithmetic circuits. Here, the type of gates available will be arithmetic operations (additions and multiplications) and wires could have numeric values and we could perform any arbitrary computation. But if we loop the loop a bit more, we could come up with a single `Generic` gate that could represent any arithmetic operation at once. This thoughtful type of gate is the one gate whose concatenation is used in Plonk to describe any relation. Apart from its wires, these gates are tied to an array of **coefficients** that help describe the functionality. But the problem of this gate is that it takes a large set of them to represent any meaningful function. So instead, recent Plonk-like proof systems have appeared which use **custom gates** to represent repeatedly used functionalities more efficiently than as a series of generic gates connected to each other. Kimchi is one of these protocols. Currently, we give support to specific gates for the `Poseidon` hash function, the `CompleteAdd` operation for curve points, `VarBaseMul` for variable base multiplication, `EndoMulScalar` for endomorphism variable base scalar multiplication, `RangeCheck` for range checks and `ForeignFieldMul` and `ForeignFieldAdd` for foreign field arithmetic. Nonetheless, we have plans to further support many other gates soon, possibly even `Cairo` instructions.
@@ -15,4 +16,3 @@ The **execution trace** refers to the state of all the wires throughout the circ
 
 Going back to the main motivation of the scheme, recall that we wanted to check that certain equations hold in a given set of numbers. Here's where this claim starts to make sense. The total number of rows in the execution trace will give us a **domain**. That is, we define a mapping between each of the row indices of the execution trace and the elements of a multiplicative group $\mathbb{G}$ with as many elements as rows in the table. Two things to note here. First, if no such group exists we can pad with zeroes. Second, any multiplicative group has a generator $g$ whose powers generate the whole group. Then we can assign to each row a power of $g$. Why do we want to do this? Because this will be the set over which we want to check our equations: we can transform a claim about the elements of a group to a claim like _"these properties hold for each of the rows of this table"_. Interesting, right?
-
diff --git a/book/src/kimchi/extended-lookup-tables.md b/book/src/kimchi/extended-lookup-tables.md
index a2205ddebd..2b822fed9b 100644
--- a/book/src/kimchi/extended-lookup-tables.md
+++ b/book/src/kimchi/extended-lookup-tables.md
@@ -39,7 +39,8 @@ In order to support the desired goals, we first define 3 types of table:
 
 * **fixed tables** are tables declared as part of the constraint system, and
   are the same for every proof of that circuit.
 * **runtime tables** are tables whose contents are determined at proving time,
-  to be used for array accesses to data from the witness.
+  to be used for array accesses to data from the witness. Also called dynamic
+  tables in other projects.
 * **side-loaded tables** are tables that are committed to in a
   proof-independent way, so that they may be reused across proofs and/or
   signed without knowledge of the specific proof, but may be different for
   each
@@ -470,7 +471,7 @@ This also increases the complexity of the proof system.
 
 This design combines the 3 different 'modes' of random access memory, roughly
 approximating `ROM -> RAM` loading, `disk -> RAM` loading, and `RAM r/w`.
 This gives the maximum flexibility for circuit designers, and each primitive has
-uses for projects that we are persuing.
+uses for projects that we are pursuing.
 
 This design has been refined over several iterations to reduce the number of
 components exposed in the proof and the amount of computation that the verifier
diff --git a/book/src/kimchi/final_check.md b/book/src/kimchi/final_check.md
index 90acf49ad0..fad1059659 100644
--- a/book/src/kimchi/final_check.md
+++ b/book/src/kimchi/final_check.md
@@ -1,6 +1,6 @@
-# Understanding the implementation of the $f(\zeta) = Z_H(\zeta) t(\zeta)$ check
+# Understanding the implementation of the $f(\zeta) = Z_H(\zeta) t(\zeta)$ check
 
-Unlike the latest version of vanilla PLONK that implements the the final check using a polynomial opening (via Maller's optimization), we implement it manually. (That is to say, Izaak implemented Maller's optimization for 5-wires.)
+Unlike the latest version of vanilla $\plonk$ that implements the final check using a polynomial opening (via Maller's optimization), we implement it manually. (That is to say, Izaak implemented Maller's optimization for 5-wires.)
 
 But the check is not exactly $f(\zeta) = Z_H(\zeta) t(\zeta)$. This note describes how and why the implementation deviates a little.
 
@@ -15,14 +15,14 @@ What is it missing in the permutation part? Let's look more closely. This is wha
 
 $$
 \begin{align}
--1 \cdot z(\zeta \omega) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\
-  (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \cdot \\
-  (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \cdot \\
-  (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \cdot \\
-  (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \cdot \\
-  (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \cdot \\
-  (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \cdot \\
-  \beta \cdot \sigma[6](x)
+-1\, \cdot\, &z(\zeta \omega) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\
+ \cdot\, &(w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \\
+ \cdot\, &(w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \\
+ \cdot\, &(w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \\
+ \cdot\, &(w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \\
+ \cdot\, &(w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \\
+ \cdot\, &(w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \\
+ \cdot\, &\beta \cdot \sigma[6](x)
 \end{align}
 $$
 
@@ -30,70 +30,72 @@ In comparison, here is the list of the stuff we needed to have:
 
 1.
the two sides of the coin:
 
    $$\begin{align}
-    & z(\zeta) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\
-    & (w[0](\zeta) + \beta \cdot \text{shift}[0] \zeta + \gamma) \cdot \\
-    & (w[1](\zeta) + \beta \cdot \text{shift}[1] \zeta + \gamma) \cdot \\
-    & (w[2](\zeta) + \beta \cdot \text{shift}[2] \zeta + \gamma) \cdot \\
-    & (w[3](\zeta) + \beta \cdot \text{shift}[3] \zeta + \gamma) \cdot \\
-    & (w[4](\zeta) + \beta \cdot \text{shift}[4] \zeta + \gamma) \cdot \\
-    & (w[5](\zeta) + \beta \cdot \text{shift}[5] \zeta + \gamma) \cdot \\
-    & (w[6](\zeta) + \beta \cdot \text{shift}[6] \zeta + \gamma)
+    z(\zeta)\, \cdot\, &\mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\
+    \cdot\, &(w[0](\zeta) + \beta \cdot \mathsf{shift}[0] \zeta + \gamma) \\
+    \cdot\, &(w[1](\zeta) + \beta \cdot \mathsf{shift}[1] \zeta + \gamma) \\
+    \cdot\, &(w[2](\zeta) + \beta \cdot \mathsf{shift}[2] \zeta + \gamma) \\
+    \cdot\, &(w[3](\zeta) + \beta \cdot \mathsf{shift}[3] \zeta + \gamma) \\
+    \cdot\, &(w[4](\zeta) + \beta \cdot \mathsf{shift}[4] \zeta + \gamma) \\
+    \cdot\, &(w[5](\zeta) + \beta \cdot \mathsf{shift}[5] \zeta + \gamma) \\
+    \cdot\, &(w[6](\zeta) + \beta \cdot \mathsf{shift}[6] \zeta + \gamma)
    \end{align}$$
 
-   with $\text{shift}[0] = 1$
+   with $\mathsf{shift}[0] = 1$
 
2. and
 
    $$\begin{align}
-    & -1 \cdot z(\zeta \omega) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\
-    & (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \cdot \\
-    & (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \cdot \\
-    & (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \cdot \\
-    & (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \cdot \\
-    & (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \cdot \\
-    & (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \cdot \\
-    & (w[6](\zeta) + \beta \cdot \sigma[6](\zeta) + \gamma) \cdot
+    -1\, \cdot\, &z(\zeta \omega) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\
+    \cdot\, &(w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \\
+    \cdot\, &(w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \\
+    \cdot\, &(w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \\
+    \cdot\, &(w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \\
+    \cdot\, &(w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \\
+    \cdot\, &(w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \\
+    \cdot\, &(w[6](\zeta) + \beta \cdot \sigma[6](\zeta) + \gamma)
    \end{align}$$
 
3. the initialization:
 
-   $$(z(\zeta) - 1) L_1(\zeta) \alpha^{PERM1}$$
+   $$(z(\zeta) - 1) \cdot L_1(\zeta) \cdot \alpha^{\mathsf{PERM1}}$$
 
4. and the end of the accumulator:
 
-   $$(z(\zeta) - 1) L_{n-k}(\zeta) \alpha^{PERM2}$$
+   $$(z(\zeta) - 1) \cdot L_{n-k}(\zeta) \cdot \alpha^{\mathsf{PERM2}}$$
 
You can read more about why it looks like that in [this post](https://minaprotocol.com/blog/a-more-efficient-approach-to-zero-knowledge-for-plonk).
 
We can see clearly that the permutation stuff we have in f is lacking. It's just the subtraction part (equation 2), and even that is missing some terms.
It is missing exactly this:
 
-$$\begin{align}
-  & -1 \cdot z(\zeta \omega) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\
-  & (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \cdot \\
-  & (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \cdot \\
-  & (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \cdot \\
-  & (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \cdot \\
-  & (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \cdot \\
-  & (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \cdot \\
-  & (w[6](\zeta) + \gamma)
-  \end{align}$$
-
+$$
+\begin{align}
+  -1\ \cdot\, &z(\zeta \omega) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\
+  \cdot\, &(w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \\
+  \cdot\, &(w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \\
+  \cdot\, &(w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \\
+  \cdot\, &(w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \\
+  \cdot\, &(w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \\
+  \cdot\, &(w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \\
+  \cdot\, &(w[6](\zeta) + \gamma)
+\end{align}
+$$
+
So at the end, when we have to check for the identity $f(\zeta) = Z_H(\zeta) t(\zeta)$ we'll actually have to check something like this (I colored the missing parts on the left hand side of the equation):
 
$$
\begin{align}
-& f(\zeta) + \color{darkgreen}{pub(\zeta)} + \\
-& \color{darkred}{z(\zeta) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot} \\
-& \color{darkred}{(w[0](\zeta) + \beta \zeta + \gamma) \cdot} \\
-& \color{darkred}{(w[1](\zeta) + \beta \cdot \text{shift}[0] \zeta + \gamma) \cdot} \\
-& \color{darkred}{(w[2](\zeta) + \beta \cdot \text{shift}[1] \zeta + \gamma) \cdot} \\
-& \color{darkred}{(w[3](\zeta) + \beta \cdot \text{shift}[2] \zeta + \gamma) \cdot} \\
-& \color{darkred}{(w[4](\zeta) + \beta \cdot \text{shift}[3] \zeta + \gamma) \cdot} \\
-& \color{darkred}{(w[5](\zeta) + \beta \cdot \text{shift}[4] \zeta + \gamma) \cdot} \\
-& \color{darkred}{(w[6](\zeta) + \beta \cdot \text{shift}[5] \zeta + \gamma) +} \\
-& \color{blue}{- z(\zeta \omega) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot} \\
-& \color{blue}{(w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \cdot} \\
-& \color{blue}{(w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \cdot} \\
-& \color{blue}{(w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \cdot} \\
-& \color{blue}{(w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \cdot} \\
-& \color{blue}{(w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \cdot} \\
-& \color{blue}{(w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \cdot} \\
-& \color{blue}{(w[6](\zeta) + \gamma) +} \\
-& \color{purple}{(z(\zeta) - 1) L_1(\zeta) \alpha^{PERM1} +} \\
-& \color{darkblue}{(z(\zeta) - 1) L_{n-k}(\zeta) \alpha^{PERM2}} \\
+f(\zeta) &+ \color{darkgreen}{\mathsf{pub}(\zeta)} \\
+& \color{darkred}{+ z(\zeta) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}}} \\
+& \qquad \color{darkred}{\cdot (w[0](\zeta) + \beta \zeta + \gamma)} \\
+& \qquad \color{darkred}{\cdot (w[1](\zeta) + \beta \cdot \mathsf{shift}[0] \zeta + \gamma)} \\
+& \qquad \color{darkred}{\cdot (w[2](\zeta) + \beta \cdot \mathsf{shift}[1] \zeta + \gamma)} \\
+& \qquad \color{darkred}{\cdot (w[3](\zeta) + \beta \cdot \mathsf{shift}[2] \zeta + \gamma)} \\
+& \qquad \color{darkred}{\cdot (w[4](\zeta) + \beta \cdot \mathsf{shift}[3] \zeta + \gamma)} \\
+& \qquad \color{darkred}{\cdot (w[5](\zeta) + \beta \cdot \mathsf{shift}[4] \zeta + \gamma)} \\
+& \qquad
\color{darkred}{\cdot (w[6](\zeta) + \beta \cdot \mathsf{shift}[5] \zeta + \gamma)} \\
+& \color{blue}{- z(\zeta \omega) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}}} \\
+& \qquad \color{blue}{\cdot (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma)} \\
+& \qquad \color{blue}{\cdot (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma)} \\
+& \qquad \color{blue}{\cdot (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma)} \\
+& \qquad \color{blue}{\cdot (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma)} \\
+& \qquad \color{blue}{\cdot (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma)} \\
+& \qquad \color{blue}{\cdot (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma)} \\
+& \qquad \color{blue}{\cdot (w[6](\zeta) + \gamma)} \\
+& \color{purple}{+ (z(\zeta) - 1) \cdot L_1(\zeta) \cdot \alpha^{\mathsf{PERM1}}} \\
+& \color{darkblue}{+ (z(\zeta) - 1) \cdot L_{n-k}(\zeta) \cdot \alpha^{\mathsf{PERM2}}} \\
& = t(\zeta)(\zeta^n - 1)
\end{align}
$$
 
@@ -103,26 +105,26 @@ I highlight what changes in each step. First, we just move things around:
 
$$
\begin{align}
-& f(\zeta) + pub(\zeta) + \\
-& z(\zeta) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\
-& (w[0](\zeta) + \beta \cdot \text{shift}[0] \zeta + \gamma) \cdot \\
-& (w[1](\zeta) + \beta \cdot \text{shift}[1] \zeta + \gamma) \cdot \\
-& (w[2](\zeta) + \beta \cdot \text{shift}[2] \zeta + \gamma) \cdot \\
-& (w[3](\zeta) + \beta \cdot \text{shift}[3] \zeta + \gamma) \cdot \\
-& (w[4](\zeta) + \beta \cdot \text{shift}[4] \zeta + \gamma) \cdot \\
-& (w[5](\zeta) + \beta \cdot \text{shift}[5] \zeta + \gamma) \cdot \\
-& (w[6](\zeta) + \beta \cdot \text{shift}[6] \zeta + \gamma) + \\
-& - z(\zeta \omega) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\
-& (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \cdot \\
-& (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \cdot \\
-& (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \cdot \\
-& (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \cdot \\
-& (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \cdot \\
-& (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \cdot \\
-& (w[6](\zeta) + \gamma) + \\
-& \color{darkred}{- t(\zeta)(\zeta^n - 1)} \\
-& = \color{darkred}{(1 - z(\zeta)) L_1(\zeta) \alpha^{PERM1} +} \\
-& \color{darkred}{(1 - z(\zeta)) L_{n-k}(\zeta) \alpha^{PERM2}} \\
+f(\zeta) &+ \mathsf{pub}(\zeta) \\
+ &+ z(\zeta) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\
+& \qquad \cdot (w[0](\zeta) + \beta \cdot \mathsf{shift}[0] \zeta + \gamma) \\
+& \qquad \cdot (w[1](\zeta) + \beta \cdot \mathsf{shift}[1] \zeta + \gamma) \\
+& \qquad \cdot (w[2](\zeta) + \beta \cdot \mathsf{shift}[2] \zeta + \gamma) \\
+& \qquad \cdot (w[3](\zeta) + \beta \cdot \mathsf{shift}[3] \zeta + \gamma) \\
+& \qquad \cdot (w[4](\zeta) + \beta \cdot \mathsf{shift}[4] \zeta + \gamma) \\
+& \qquad \cdot (w[5](\zeta) + \beta \cdot \mathsf{shift}[5] \zeta + \gamma) \\
+& \qquad \cdot (w[6](\zeta) + \beta \cdot \mathsf{shift}[6] \zeta + \gamma) \\
+ &- z(\zeta \omega) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\
+& \qquad \cdot (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \\
+& \qquad \cdot (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \\
+& \qquad \cdot (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \\
+& \qquad \cdot (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \\
+& \qquad \cdot (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \\
+& \qquad \cdot (w[5](\zeta) +
\beta \cdot \sigma[5](\zeta) + \gamma) \\ +& \qquad \cdot (w[6](\zeta) + \gamma) \\ +&\color{darkred}{- t(\zeta)(\zeta^n - 1)} \\ + &= \color{darkred}{(1 - z(\zeta)) L_1(\zeta) \alpha^{\mathsf{PERM1}}} \\ + & \qquad \color{darkred}{+ (1 - z(\zeta)) L_{n-k}(\zeta) \alpha^{\mathsf{PERM2}}} \\ \end{align} $$ @@ -130,64 +132,64 @@ here are the actual lagrange basis calculated with the [formula here](../plonk/l $$ \begin{align} -& f(\zeta) + pub(\zeta) + \\ -& z(\zeta) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\ -& (w[0](\zeta) + \beta \cdot \text{shift}[0] \zeta + \gamma) \cdot \\ -& (w[1](\zeta) + \beta \cdot \text{shift}[1] \zeta + \gamma) \cdot \\ -& (w[2](\zeta) + \beta \cdot \text{shift}[2] \zeta + \gamma) \cdot \\ -& (w[3](\zeta) + \beta \cdot \text{shift}[3] \zeta + \gamma) \cdot \\ -& (w[4](\zeta) + \beta \cdot \text{shift}[4] \zeta + \gamma) \cdot \\ -& (w[5](\zeta) + \beta \cdot \text{shift}[5] \zeta + \gamma) \cdot \\ -& (w[6](\zeta) + \beta \cdot \text{shift}[6] \zeta + \gamma) + \\ -& - z(\zeta \omega) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\ -& (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \cdot \\ -& (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \cdot \\ -& (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \cdot \\ -& (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \cdot \\ -& (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \cdot \\ -& (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \cdot \\ -& (w[6](\zeta) + \gamma) + \\ -& - t(\zeta)(\zeta^n - 1) \\ -& = \color{darkred}{(1 - z(\zeta))[\frac{(\zeta^n - 1)}{n(\zeta - 1)} \alpha^{PERM1} + \frac{\omega^{n-k}(\zeta^n - 1)}{n(\zeta - \omega^{n-k})} \alpha^{PERM2}]} +f(\zeta) + &\mathsf{pub}(\zeta) \\ ++ & z(\zeta) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\ +& \quad \cdot (w[0](\zeta) + \beta \cdot \mathsf{shift}[0] \zeta + \gamma) \\ +& \quad \cdot (w[1](\zeta) + \beta \cdot \mathsf{shift}[1] \zeta + \gamma) \\ +& \quad \cdot (w[2](\zeta) + \beta \cdot \mathsf{shift}[2] \zeta + \gamma) \\ +& \quad \cdot (w[3](\zeta) + \beta \cdot \mathsf{shift}[3] \zeta + \gamma) \\ +& \quad \cdot (w[4](\zeta) + \beta \cdot \mathsf{shift}[4] \zeta + \gamma) \\ +& \quad \cdot (w[5](\zeta) + \beta \cdot \mathsf{shift}[5] \zeta + \gamma) \\ +& \quad \cdot (w[6](\zeta) + \beta \cdot \mathsf{shift}[6] \zeta + \gamma) + \\ +- & z(\zeta \omega) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\ +& \quad \cdot (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \\ +& \quad \cdot (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \\ +& \quad \cdot (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \\ +& \quad \cdot (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \\ +& \quad \cdot (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \\ +& \quad \cdot (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \\ +& \quad \cdot (w[6](\zeta) + \gamma) + \\ +- & t(\zeta)(\zeta^n - 1) \\ += & \color{darkred}{(1 - z(\zeta))[\frac{(\zeta^n - 1)}{n(\zeta - 1)} \alpha^{\mathsf{PERM1}} + \frac{\omega^{n-k}(\zeta^n - 1)}{n(\zeta - \omega^{n-k})} \alpha^{\mathsf{PERM2}}]} \end{align} $$ - + finally we extract some terms from the lagrange basis: $$ \begin{align} & \color{darkred}{[} \\ -& \; f(\zeta) + pub(\zeta) + \\ -& \; z(\zeta) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\ -& \; (w[0](\zeta) + \beta \cdot \text{shift}[0] \zeta + \gamma) \cdot \\ -& \; (w[1](\zeta) + \beta \cdot \text{shift}[1] \zeta + \gamma) \cdot \\ -& \; (w[2](\zeta) + \beta \cdot \text{shift}[2] \zeta + \gamma) 
\cdot \\ -& \; (w[3](\zeta) + \beta \cdot \text{shift}[3] \zeta + \gamma) \cdot \\ -& \; (w[4](\zeta) + \beta \cdot \text{shift}[4] \zeta + \gamma) \cdot \\ -& \; (w[5](\zeta) + \beta \cdot \text{shift}[5] \zeta + \gamma) \cdot \\ -& \; (w[6](\zeta) + \beta \cdot \text{shift}[6] \zeta + \gamma) + \\ -& \; - z(\zeta \omega) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\ -& \; (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \cdot \\ -& \; (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \cdot \\ -& \; (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \cdot \\ -& \; (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \cdot \\ -& \; (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \cdot \\ -& \; (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \cdot \\ -& \; (w[6](\zeta) + \gamma) + \\ -& \; - t(\zeta)(\zeta^n - 1) \\ +& f(\zeta) + \mathsf{pub}(\zeta) \\ +& + z(\zeta) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\ +& \qquad \cdot (w[0](\zeta) + \beta \cdot \mathsf{shift}[0] \zeta + \gamma) \\ +& \qquad \cdot (w[1](\zeta) + \beta \cdot \mathsf{shift}[1] \zeta + \gamma) \\ +& \qquad \cdot (w[2](\zeta) + \beta \cdot \mathsf{shift}[2] \zeta + \gamma) \\ +& \qquad \cdot (w[3](\zeta) + \beta \cdot \mathsf{shift}[3] \zeta + \gamma) \\ +& \qquad \cdot (w[4](\zeta) + \beta \cdot \mathsf{shift}[4] \zeta + \gamma) \\ +& \qquad \cdot (w[5](\zeta) + \beta \cdot \mathsf{shift}[5] \zeta + \gamma) \\ +& \qquad \cdot (w[6](\zeta) + \beta \cdot \mathsf{shift}[6] \zeta + \gamma) + \\ +& - z(\zeta \omega) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\ +& \qquad \cdot (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \\ +& \qquad \cdot (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \\ +& \qquad \cdot (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \\ +& \qquad \cdot (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \\ +& \qquad \cdot (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \\ +& \qquad \cdot (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \\ +& \qquad \cdot (w[6](\zeta) + \gamma) \\ +& - t(\zeta)(\zeta^n - 1) \\ & \color{darkred}{] \cdot (\zeta - 1)(\zeta - \omega^{n-k})} & \\ -& = \color{darkred}{(1 - z(\zeta))[\frac{(\zeta^n - 1)(\zeta - \omega^{n-k})}{n} \alpha^{PERM1} + \frac{\omega^{n-k}(\zeta^n - 1)(\zeta - 1)}{n} \alpha^{PERM2}]} +& = \color{darkred}{(1 - z(\zeta))\Bigg[\frac{(\zeta^n - 1)(\zeta - \omega^{n-k})}{n} \alpha^{\mathsf{PERM1}} + \frac{\omega^{n-k}(\zeta^n - 1)(\zeta - 1)}{n} \alpha^{\mathsf{PERM2}}\Bigg]} \end{align} $$ -with $\alpha^{PERM0} = \alpha^{17}, \alpha^{PERM1} = \alpha^{18}, \alpha^{PERM2} = \alpha^{19}$ +with $\alpha^{\mathsf{PERM0}} = \alpha^{17}, \alpha^{\mathsf{PERM1}} = \alpha^{18}, \alpha^{\mathsf{PERM2}} = \alpha^{19}$ -Why do we do things this way? Most likely to reduce +Why do we do things this way? Most likely to reduce Also, about the code: -* the division by $n$ in the $\alpha^{PERM1}$ and the $\alpha^{PERM2}$ terms is missing (why?) -* the multiplication by $w^{n-k}$ in the $\alpha^{PERM2}$ term is missing (why?) +* the division by $n$ in the $\alpha^{\mathsf{PERM1}}$ and the $\alpha^{\mathsf{PERM2}}$ terms is missing (why?) +* the multiplication by $w^{n-k}$ in the $\alpha^{\mathsf{PERM2}}$ term is missing (why?) As these terms are constants, and do not prevent the division by $Z_H$ to form $t$, $t$ also omits them. In other word, they cancel one another. @@ -211,29 +213,25 @@ t += &bnd; // t is evaluated at zeta and sent... 
```

Notice that **`bnd` is not divided by the vanishing polynomial**. Also what's `perm`? Let's answer the latter:

$$
\begin{align}
-perm(x) = & \\
- & a^{PERM0} \cdot zkpl(x) \cdot [ \\
- & \;\; z(x) \cdot \\
- & \;\; (w[0](x) + \gamma + x \cdot \beta \cdot \text{shift}[0]) \cdot \\
- & \;\; (w[1](x) + \gamma + x \cdot \beta \cdot \text{shift}[1]) \cdot \cdots \\
- & \; - \\
- & \;\; z(x \cdot w) \cdot \\
- & \;\; (w[0](x) + \gamma + \sigma[0] \cdot \beta) \cdot \\
- & \;\; (w[1](x) + \gamma + \sigma[1] \cdot \beta) \cdot ... \\
+\mathsf{perm}(x) =\ & \alpha^{\mathsf{PERM0}} \cdot \mathsf{zkpl}(x) \cdot [ \\
+ & z(x) \cdot (w[0](x) + \gamma + x \cdot \beta \cdot \mathsf{shift}[0]) \\
+ & \qquad \cdot (w[1](x) + \gamma + x \cdot \beta \cdot \mathsf{shift}[1]) \cdot \ldots \\
+ - &z(x \cdot w) \cdot (w[0](x) + \gamma + \sigma[0] \cdot \beta) \\
+ & \qquad \quad \ \ \cdot (w[1](x) + \gamma + \sigma[1] \cdot \beta) \cdot \ldots \\
 &]
\end{align}
$$

and `bnd` is:

-$$bnd(x) =
-  a^{PERM1} \cdot \frac{z(x) - 1}{x - 1}
+$$\mathsf{bnd}(x) =
+  \alpha^{\mathsf{PERM1}} \cdot \frac{z(x) - 1}{x - 1}
   +
-  a^{PERM2} \cdot \frac{z(x) - 1}{x - sid[n-k]}
+  \alpha^{\mathsf{PERM2}} \cdot \frac{z(x) - 1}{x - \mathsf{sid}[n-k]}
$$

you can see that some terms are missing in `bnd`, specifically the numerator $x^n - 1$. Well, it would have been cancelled by the division by the vanishing polynomial, and this is the reason why we didn't divide that term by $Z_H(x) = x^n - 1$.

diff --git a/book/src/kimchi/foreign_field_add.md b/book/src/kimchi/foreign_field_add.md
index 4425df795f..5168d8718b 100644
--- a/book/src/kimchi/foreign_field_add.md
+++ b/book/src/kimchi/foreign_field_add.md
@@ -5,17 +5,17 @@ This document is meant to explain the foreign field addition gate in Kimchi.

## Overview

The goal of this gate is to perform the following operation between a left $a$ and a right $b$ input, to obtain a result $r$

-$$a + s * b = r \mod f$$
-where $a,b,r\in\mathbb{F}_f$ belong to the foreign field of modulus $f$ and we work over a native field $\mathbb{F}_n$ with modulus $n$, and $s$ is a flag that is either $-1$ or $1$ to indicate whether it is a subtraction or addition gate.
+$$a + s \cdot b = r \mod f$$
+where $a,b,r\in\mathbb{F}_f$ belong to the _foreign field_ $\mathbb{F}_f$ of modulus $f$ and we work over a native field $\mathbb{F}_n$ of modulus $n$, and $s \in \{-1, 1\}$ is a flag indicating whether it is a subtraction or addition gate.

-If $f < n$ then we can easily perform the above computation. But in this gate we are interested in the contrary case in which $f>n$. In order to deal with this, we will divide foreign field elements into limbs that fit in our native field. We want to be compatible with the foreign field multiplication gate, and thus the parameters we will be using are the following:
+If $2 \cdot f < n$ then we can easily perform the above computation natively since no overflows happen. But in this gate we are interested in the contrary case in which $f>n$. In order to deal with this, we will divide foreign field elements into limbs that fit in our native field.
We want to be compatible with the foreign field multiplication gate, and thus the parameters we will be using are the following:

- 3 limbs of 88 bits each (for a total of 264 bits)
-- $2^{264} > f$
-- our foreign field will have 256-bit length
-- our native field has 255-bit length
+- So $f < 2^{264}$, concretely:
+  - The modulus $f$ of the foreign field $\mathbb{F}_f$ is 256 bits long
+  - The modulus $n$ of our native field $\mathbb{F}_n$ is 255 bits long

-In other words, using 3 limbs of 88 bits each allows us to represent any foreign field element in the range $[0,2^{264})$ for foreign field addition, but only up to $2^{259}$ for foreign field multiplication. Thus, with the current configuration of our limbs, our foreign field must be smaller than $2^{259}$ (because $2^{264} \cdot 2^{255} > {2^{259}}^2 + 2^{259}$, more on this in the [FFmul RFC](https://github.com/o1-labs/rfcs/blob/main/0006-ffmul-revised.md).
+In other words, using 3 limbs of 88 bits each allows us to represent any foreign field element in the range $[0,2^{264})$ for foreign field addition, but only up to $2^{259}$ for foreign field multiplication. Thus, with the current configuration of our limbs, our foreign field must be smaller than $2^{259}$ (because $2^{264} \cdot 2^{255} > {2^{259}}^2 + 2^{259}$; more on this in [Foreign Field Multiplication](../kimchi/foreign_field_mul.md) or the original [FFmul RFC](https://github.com/o1-labs/rfcs/blob/main/0006-ffmul-revised.md)).

### Splitting the addition

@@ -30,11 +30,11 @@ b = (-------b0-------|-------b1-------|-------b2-------)
 =
 r = (-------r0-------|-------r1-------|-------r2-------) mod(f)
 ```

-We will perform the addition in 3 steps, one per limb. Now, when we split long additions in this way, we must be careful with carry bits between limbs. Also, even if $a$ and $b$ are foreign field elements, it could be the case that $a + b$ is already larger than the modulus (such addition could be at most $2f - 2$). But it could also be the case that the subtraction produces an underflow because $a < b$ (with a difference of at most $1 - f$). Thus we will have to consider the more general case. That is,
+We will perform the addition in 3 steps, one per limb. Now, when we split long additions in this way, we must be careful with carry bits between limbs. Also, even if $a$ and $b$ are foreign field elements, it could be the case that $a + b$ is already larger than the modulus (in this case $a + b$ could be at most $2f - 2$). But it could also be the case that the subtraction produces an underflow because $a < b$ (with a difference of at most $1 - f$). Thus we will have to consider the more general case. That is,

$$ a + s \cdot b = q \cdot f + r \mod 2^{264}$$

-with a field overflow term $q$ that will be either $0$ (if no underflow nor overflow is produced) or $1$ (if there is overflow with $s = 1$) or $-1$ (if there is underflow with $s = -1$). Looking at this in limb form, we have:
+with a field overflow term $q \in \{-1,0,1\}$ that will be either $0$ (if no underflow nor overflow is produced), $1$ (if there is overflow with $s = 1$) or $-1$ (if there is underflow with $s = -1$). Looking at this in limb form, we have:

```text
bits  0..............87|88...........175|176...........263

@@ -52,7 +52,7 @@ f = (-------f0-------|-------f1-------|-------f2-------)
 r = (-------r0-------|-------r1-------|-------r2-------)
 ```

-First, if $a + b$ was larger than $f$, then we will have a field overflow (represented by $q = 1$) and we will have to subtract $f$ from the sum $a + b$ to obtain $r$.
Whereas the foreign field overflow necessitates an overflow bit $q$ for the foreign field equation above, when $q = 1$ there is a corresponding subtraction that may introduce carries (or even borrows) between the limbs. This is because $r = a + b - q \cdot f \mod 2^{264}$. Therefore, in the equations for the limbs we will use a carry flag $c_i$ for limb $i$ to represent both carries and borrows. The carry flags can be anything in $\{-1, 0, 1\}$, where $c_i = -1$ represents a borrow and $c_i = 1$ represents a carry. Next we see more clearly how this works. +First, if $a + b$ is larger than $f$, then we will have a field overflow (represented by $q = 1$) and thus will have to subtract $f$ from the sum $a + b$ to obtain $r$. Whereas the foreign field overflow necessitates an overflow bit $q$ for the foreign field equation above, when $q = 1$ there is a corresponding subtraction that may introduce carries (or even borrows) between the limbs. This is because $r = a + b - q \cdot f \mod 2^{264}$. Therefore, in the equations for the limbs we will use a carry flag $c_i$ for limb $i$ to represent both carries and borrows. The carry flags $c_i$ are in $\{-1, 0, 1\}$, where $c_i = -1$ represents a borrow and $c_i = 1$ represents a carry. Next we explain how this works. In order to perform this operation in parts, we first take a look at the least significant limb, which is the easiest part. This means we want to know how to compute $r_0$. First, if the addition of the bits in $a_0$ and $b_0$ produce a carry (or borrow) bit, then it should propagate to the second limb. That means one has to subtract $2^{88}$ from $a_0 + b_0$, add $1$ to $a_1 + b_1$ and set the low carry flag $c_0$ to 1 (otherwise it is zero). Thus, the equation for the lowest bit is @@ -60,7 +60,7 @@ $$a_0 + s \cdot b_0 = q \cdot f_0 + r_0 + c_0 \cdot 2^{88}$$ Or put in another way, this is equivalent to saying that $a_0 + b_0 - q \cdot f_0 - r_0$ is a multiple of $2^{88}$ (or, the existence of the carry coefficient $c_0$). -This kind of equation needs an additional check that the carry coefficient $c_0$ is a correct carry value (`0`, `1`, or `-1`). We will use this idea for the remaining limbs as well. +This kind of equation needs an additional check that the carry coefficient $c_0$ is a correct carry value (belonging to $\{-1,0,1\}$). We will use this idea for the remaining limbs as well. Looking at the second limb, we first need to observe that the addition of $a_1$ and $b_1$ can, not only produce a carry bit $c_1$, but they may need to take into account the carry bit from the first limb; $c_0$. Similarly to the above, @@ -85,7 +85,7 @@ s = 1 | -1 b = (-------b0-------|-------b1-------|-------b2-------) > > > = c_0 c_1 c_2 -q = -1 |0 | 1 +q = -1 | 0 | 1 · f = (-------f0-------|-------f1-------|-------f2-------) + @@ -95,24 +95,27 @@ r = (-------r0-------|-------r1-------|-------r2-------) Our witness computation is currently using the `BigUint` library, which takes care of all of the intermediate carry limbs itself so we do not have to address all casuistries ourselves (such as propagating the low carry flag to the high limb in case the middle limb is all zeros). -### Upper bound check +### Upper bound check for foreign field membership Last but not least, we should perform some range checks to make sure that the result $r$ is contained in $\mathbb{F}_f$. 
This is important because there could be other values of the result which still fit in $<2^{264}$ but are larger than $f$, and we must make sure that the final result is the minimum one (that we will be referring to as $r_{min}$ in the following).

-Ideally, we would like to reuse some gates that we already have. In particular, we can perform range checks for $0\leq X <2^{3\ell}=2^{3\cdot 88}$. But we want to check that $0 \leq r_{min} < f$. The way we can tweak this gate to behave as we want, is the following. First, the above inequality is equivalent to saying that $-f \leq r_{min} - f < 0$. Then we add $2^{264}$ on both sides to obtain $2^{264} - f \leq r_{min} - f + 2^{264} < 2^{264}$. Thus, all there is left to check is that a given bound value is indeed computed correctly. Meaning, that an upperbound term $u$ is correctly obtained as $r_{min} + 2^{264} - f$. Even if we could apply this check for all intermediate results of foreign field additions, it is sufficient to apply it only once at the end of the computations.
+Ideally, we would like to reuse some gates that we already have. In particular, we can perform range checks for $0\leq X <2^{3\ell}=2^{3\cdot 88}$. But we want to check that $0 \leq r_{min} < f$, which is a smaller range. The way we can tweak the existing range check gate to behave as we want is as follows.
+
+First, the above inequality is equivalent to $-f \leq r_{min} - f < 0$. Add $2^{264}$ on both sides to obtain $2^{264} - f \leq r_{min} - f + 2^{264} < 2^{264}$. Thus, we can perform the standard $r_{min} \in [0,2^{264})$ check together with a check that $r_{min} - f + 2^{264} \in [0,2^{264})$; together they imply $r_{min} \in [0,f)$ as intended. All that is left to check is that a given upper bound term is correctly obtained; calling it $u$, we want to enforce $u := r_{min} + 2^{264} - f$.

-This is very similar to a foreign field addition, but simpler: the field overflow bit will always be $1$, the sign is positive $1$, and the right input is formed by the limbs $(0, 0, 2^{88})$ standing for the number $2^{264}$. There could be intermediate limb carry bits $k_0$ and $k_1$ as above. Observe that, because the sum is meant to be $<2^{264}$, then the carry bit for the most significant limb should always be zero $k_2 = 0$, so we do not use it. Happily, we can apply the addition gate again to perform the addition limb-wise, by selecting the following parameters:
+
+The following computation is very similar to a foreign field addition ($r_{min} + 2^{264} = u \mod f$), but simpler. The field overflow bit will always be $1$, the main operation sign is positive $1$ (we are doing addition, not subtraction), and the right input $2^{264}$, call it $g$, is represented by the limbs $(0, 0, 2^{88})$. There could be intermediate limb carry bits $k_0$ and $k_1$ as in the general foreign field addition protocol. Observe that, because the sum is meant to be $<2^{264}$, the carry bit for the most significant limb should always be zero, $k_2 = 0$, so this condition is enforced implicitly by omitting $k_2$ from the equations.
Happily, we can apply the addition gate again to perform the addition limb-wise, by selecting the following parameters:

$$
\begin{aligned}
-a_0 &=& r_{min_{0}}, \\ a_1 &=& r_{min_{1}}, \\ a_2 &=& r_{min_{2}} , \\
-b_0 &=& 0, \\ b_1 &=& 0, \\ b_2 &=& 2^{88} , \\
+a_0 &=& r_{min_{0}} \\ a_1 &=& r_{min_{1}} \\ a_2 &=& r_{min_{2}} \\
+b_0 &=& 0 \\ b_1 &=& 0 \\ b_2 &=& 2^{88} \\
s &=& 1 \\
q &=& 1
\end{aligned}
$$

-Calling $u$ the upper bound term, the equation $r_{min} + 2^{264} - f $ can be expressed as $r_{min} + 2^{264} = 1 * f + u$. Finally, we perform a range check on the sum $u$, and we would know that $r_{min} < f$.
+Calling $u$ the upper bound term, the computation $u = r_{min} + 2^{264} - f$ can be expressed as $r_{min} + 2^{264} = 1 \cdot f + u$. Finally, we perform a range check on the sum $u$, and we then know that $r_{min} < f$.

```text

@@ -142,6 +145,8 @@ $$

But now we also have to check that $0\leq r$. But this is implicitly covered by $r$ being a field element of at most 264 bits (range check).

+When we have a chain of additions $a_i + b_i = r_i$ with $a_i = {r_{i-1}}$, we could apply the field membership check naively to every intermediate $r_i$; however, it is sufficient to apply it only once, at the end of the computation, for $r_{n}$, and keep the intermediate $r_i \in [0,2^{264})$ in a "raw" form. Generally speaking, treating intermediate values lazily saves a lot of computation in many different foreign field addition and multiplication scenarios.
+
### Subtractions

Mathematically speaking, a subtraction within a field is no more than an addition over that field. Negative elements are not different from "positive" elements in finite fields (or in any modular arithmetic). Our witness computation code computes negative sums by adding the modulus to the result. To give a general example, the element $-e$ within a field $\mathbb{F}_m$ of order $m$ and $e < m$ is nothing but $m - e$. Nonetheless, for arbitrarily sized elements (not just those smaller than the modulus), the actual field element could be any $c \cdot m - e$, for any multiple $c \cdot m$ of the modulus. Thus, representing negative elements directly as "absolute" field elements may incur in additional computations involving multiplications and thus would result in a less efficient mechanism.
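To make the limb-wise equations above concrete, here is a minimal out-of-circuit sketch of the witness computation for one $a + s \cdot b = q \cdot f + r$ step, covering both addition and subtraction through the sign $s$. The function name and integer representation are hypothetical, chosen for illustration; Kimchi's actual witness code works with field elements and `BigUint` as noted above.

```rust
// A sketch (not Kimchi's API) of limb-wise foreign field
// addition/subtraction a + s*b = q*f + r over three 88-bit limbs in
// little-endian order, with s in {-1, 1} and q in {-1, 0, 1} assumed
// to be already known, as in the text above.
const LIMB_BITS: u32 = 88;
const LIMB: i128 = 1 << LIMB_BITS;

/// Returns the result limbs r and the carry flags (c0, c1); each flag
/// lands in {-1, 0, 1}, with -1 playing the role of a borrow.
fn ffadd_limbs(
    a: [i128; 3],
    b: [i128; 3],
    f: [i128; 3],
    s: i128,
    q: i128,
) -> ([i128; 3], [i128; 2]) {
    let mut r = [0i128; 3];
    let mut flags = [0i128; 2];
    let mut carry = 0i128;
    for i in 0..3 {
        // a_i + s*b_i + carry_in = q*f_i + r_i + carry_out * 2^88
        let sum = a[i] + s * b[i] - q * f[i] + carry;
        carry = sum.div_euclid(LIMB);
        r[i] = sum.rem_euclid(LIMB); // canonical limb in [0, 2^88)
        if i < 2 {
            flags[i] = carry;
        }
    }
    // The top carry vanishes: a + s*b - q*f lands in [0, f) < 2^264.
    debug_assert_eq!(carry, 0);
    (r, flags)
}
```

The same routine covers the bound addition from the previous section, since that one is just the special case $b = (0, 0, 2^{88})$, $s = 1$, $q = 1$.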
@@ -155,14 +160,14 @@ So far, one can recompute the result as follows:

```text
bits  0..............87|88...........175|176...........263
r = (-------r0-------|-------r1-------|-------r2-------)
-=
+=
a = (-------a0-------|-------a1-------|-------a2-------)
+
s = 1 | -1
·
b = (-------b0-------|-------b1-------|-------b2-------)
--
-q = -1 |0 | 1
+-
+q = -1 | 0 | 1
·
f = (-------f0-------|-------f1-------|-------f2-------)
>    >    >

@@ -174,18 +179,18 @@ Inspired by the halving approach in foreign field multiplication, an optimized v

```text
bits  0..............87|88...........175|176...........263
r = (-------r0------- -------r1-------|-------r2-------)
-=
+=
a = (-------a0------- -------a1-------|-------a2-------)
+
s = 1 | -1
·
b = (-------b0------- -------b1-------|-------b2-------)
--
-q = -1 |0 | 1
+-
+q = -1 | 0 | 1
·
f = (-------f0------- -------f1-------|-------f2-------)
>    >
- c 0
+ c 0
```

These are the new equations:

@@ -193,18 +198,18 @@ f = (-------f0------- -------f1-------|-------f2-------)

$$
\begin{aligned}
r_{bot} &= (a_0 + 2^{88} \cdot a_1) + s \cdot (b_0 + 2^{88} \cdot b_1) - q \cdot (f_0 + 2^{88} \cdot f_1) - c \cdot 2^{176} \\
-r_{top} &= a_2 + s \cdot b_2 - q \cdot f_2 + c
+r_{top} &= a_2 + s \cdot b_2 - q \cdot f_2 + c
\end{aligned}
$$

-with $r_{top} = r_2$ and $c = c_1$.
+with $r_{top} = r_2$ and $c = c_1$.

## Gadget

-The full foreign field addition/subtraction gadget will be composed of:
+The full foreign field addition/subtraction gadget will be composed of:
- $1$ public input row containing the value $1$;
- $n$ rows with `ForeignFieldAdd` gate type:
-  - for the actual $n$ chained additions or subtractions;
+  - for the actual $n$ chained additions or subtractions;
- $1$ `ForeignFieldAdd` row for the bound addition;
- $1$ `Zero` row for the final bound addition.
- $1$ `RangeCheck` gadget for the first left input of the chain $a_1 := a$;

@@ -227,7 +232,7 @@ A total of 20 rows with 15 columns in Kimchi for 1 addition. All ranges below ar

| 9n+7..9n+10 | `multi-range-check` | $u$ |

-This mechanism can chain foreign field additions together. Initially, there are $n$ foreign field addition gates, followed by a foreign field addition gate for the bound addition (whose current row corresponds to the next row of the last foreign field addition gate), and an auxiliary `Zero` row that holds the upper bound. At the end, an initial left input range check is performed, which is followed by a $n$ pairs of range check gates for the right input and intermediate result (which become the left input for the next iteration). After the chained inputs checks, a final range check on the bound takes place.
+This mechanism can chain foreign field additions together. Initially, there are $n$ foreign field addition gates, followed by a foreign field addition gate for the bound addition (whose current row corresponds to the next row of the last foreign field addition gate), and an auxiliary `Zero` row that holds the upper bound. At the end, an initial left input range check is performed, which is followed by $n$ pairs of range check gates for the right input and intermediate result (which become the left input for the next iteration). After the chained input checks, a final range check on the bound takes place.
For example, chaining the following set of 3 instructions would result in a full gadget with 37 rows:

@@ -250,7 +255,7 @@ $$add(add(add(a,b),c),d)$$

| 30..33 | `multi-range-check` | $a+b+c+d$ |
| 34..37 | `multi-range-check` | bound |

-Nonetheless, such an exhaustive set of checks are not necessary for completeness nor soundness. In particular, only the very final range check for the bound is required. Thus, a shorter gadget that is equally valid and takes $(8*n+4)$ fewer rows could be possible if we can assume that the inputs of each addition are correct foreign field elements. It would follow the next layout (with inclusive ranges):
+Nonetheless, such an exhaustive set of checks is not necessary for completeness nor soundness. In particular, only the very final range check for the bound is required. Thus, a shorter gadget that is equally valid and takes $(8\cdot n+4)$ fewer rows could be possible if we can assume that the inputs of each addition are correct foreign field elements. It would follow the layout below (with inclusive ranges):

| Row(s) | Gate type(s) | Witness |
| -------- | ----------------------------------------------------- | ------- |

@@ -260,7 +265,7 @@ Nonetheless, such an exhaustive set of checks are not necessary for completeness

| n+2 | `Zero` | |
| n+3..n+6 | `multi-range-check` for `bound` | $u$ |

-Otherwise, we would need range checks for each new input of the chain, but none for intermediate results; implying $4\cdot n$ fewer rows.
+Otherwise, we would need range checks for each new input of the chain, but none for intermediate results; implying $4\cdot n$ fewer rows.

| Row(s) | Gate type(s) | Witness |
| --------------- | ----------------------------------------------------- | --------- |

@@ -273,7 +278,7 @@ Otherwise, we would need range checks for each new input of the chain, but none

| 5n+7..5n+10 | `multi-range-check` for bound | $u$ |

-For more details see the Bound Addition section of the [Foreign Field Multiplication RFC](https://github.com/o1-labs/rfcs/blob/main/0006-ffmul-revised.md).
+For more details see the Bound Addition section in [Foreign Field Multiplication](../kimchi/foreign_field_mul.md) or the original [Foreign Field Multiplication RFC](https://github.com/o1-labs/rfcs/blob/main/0006-ffmul-revised.md).

### Layout

@@ -326,4 +331,4 @@ TODO: decide if we really want to keep this check or leave it implicit as it is

When we use this gate as part of a larger chained gadget, we should optimize the number of range check rows to avoid redundant checks. In particular, if the result of an addition becomes one input of another addition, there is no need to check twice that the limbs of that term have the right length.

-The sign is now part of the coefficients of the gate. This will allow us to remove the sign check constraint and release one witness cell element. But more importantly, it brings soundness to the gate as it is now possible to wire the $1$ public input to the overflow witness of the final bound check of every addition chain. \ No newline at end of file
+The sign is now part of the coefficients of the gate. This will allow us to remove the sign check constraint and release one witness cell element. But more importantly, it brings soundness to the gate as it is now possible to wire the $1$ public input to the overflow witness of the final bound check of every addition chain.
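The equivalence underlying the final bound check, namely that $u = r + 2^{264} - f$ passing a $[0, 2^{264})$ range check is the same statement as $r < f$, can be sanity-checked exhaustively at toy scale. The snippet below is a sketch with $t = 8$ standing in for $264$; the modulus value is made up for the example.

```rust
// Toy-scale check of the upper bound trick: for r in [0, 2^t),
// u = r + 2^t - f is in [0, 2^t) exactly when r < f.
fn main() {
    let t = 8u32; // stand-in for the real t = 264
    let f = 201u64; // stand-in foreign modulus, f < 2^t
    for r in 0..(1u64 << t) {
        let u = r + (1 << t) - f; // the bound witness u
        assert_eq!(u < (1 << t), r < f); // range check on u <=> r in [0, f)
    }
    println!("bound check equivalence holds for every r below 2^{t}");
}
```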
diff --git a/book/src/kimchi/foreign_field_mul.md b/book/src/kimchi/foreign_field_mul.md
index a2bbf6c361..a2b37ed753 100644
--- a/book/src/kimchi/foreign_field_mul.md
+++ b/book/src/kimchi/foreign_field_mul.md
@@ -1,6 +1,7 @@
-# Foreign Field Multiplication RFC
+# Foreign Field Multiplication

-This document explains how we constrain foreign field multiplication (i.e. non-native field multiplication) in the Kimchi proof system.
+This document is the original RFC explaining how we constrain foreign field multiplication (i.e. non-native field multiplication) in the Kimchi proof system.
+For a more recent RFC, see [FFmul RFC](https://github.com/o1-labs/rfcs/blob/main/0006-ffmul-revised.md).

**Changelog**

@@ -175,7 +176,7 @@ $$
m_{max} = \frac{t + \log_2(n)}{2}.
$$

-With $t=258, n=255$ we have 
+With $t=258, n=255$ we have

$$
\begin{aligned}

@@ -456,7 +457,7 @@ All that remains is to range check $v_0$ and $v_1$

When unconstrained, computing $u_0 = p_0 + 2^{\ell} \cdot p_{10} - (r_0 + 2^{\ell} \cdot r_1)$ could require borrowing. Fortunately, we have the constraint that the $2\ell$ least significant bits of $u_0$ are `0` (i.e. $u_0 = 2^{2\ell} \cdot v_0$), which means borrowing cannot happen.

-Borrowing is prevented because when the the $2\ell$ least significant bits of the result are `0` it is not possible for the minuend to be less than the subtrahend. We can prove this by contradiction.
+Borrowing is prevented because when the $2\ell$ least significant bits of the result are `0` it is not possible for the minuend to be less than the subtrahend. We can prove this by contradiction.

Let

@@ -552,7 +553,7 @@ This requires a single constraint of the form

> 9. $a_n \cdot b_n - q_n \cdot f_n = r_n$

-with all of the terms expanded into the limbs according the the above equations. The values $a_n, b_n, q_n, f_n$ and $r_n$ do not need to be in the witness.
+with all of the terms expanded into the limbs according to the above equations. The values $a_n, b_n, q_n, f_n$ and $r_n$ do not need to be in the witness.

## Range check both sides of $a \cdot b = q \cdot f + r$

@@ -609,7 +610,7 @@ $$

### Bound checks

-To perform the above range checks we use the *upper bound check* method described in the [Foreign Field Addition](../kimchi/foreign_field_add.md#upper-bound-check).
+To perform the above range checks we use the *upper bound check* method described in the upper bound check section of [Foreign Field Addition](../kimchi/foreign_field_add.md#upper-bound-check).

The upper bound check is as follows. We must constrain $0 \le q < f$ over the positive integers, so

@@ -1008,7 +1009,7 @@ We have proven that $q'_{carry2}$ is always zero, so that allows use to simplify

In other words, we have eliminated constraint (7) and removed $q'_{carry2}$ from the witness.

-Since we already needed to range-check $q$ or $q'$, the total number of new constraints added is 4: 3 added to to `ForeignFieldMul` and 1 added to `multi-range-check` gadget for constraining the decomposition of $q'_{01}$.
+Since we already needed to range-check $q$ or $q'$, the total number of new constraints added is 4: 3 added to `ForeignFieldMul` and 1 added to the `multi-range-check` gadget for constraining the decomposition of $q'_{01}$.

This saves 2 rows per multiplication.
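The native-modulus constraint $a_n \cdot b_n - q_n \cdot f_n = r_n$ above, combined with the limb checks modulo $2^t$, is an instance of a CRT-style argument: if $a \cdot b = q \cdot f + r$ holds both modulo $2^t$ and modulo the native modulus $n$, and both sides are smaller than $2^t \cdot n$, then it holds over the integers. A toy-sized sketch of that argument, with tiny made-up moduli in place of the real 264-bit and 255-bit parameters:

```rust
// Toy illustration of the CRT argument behind foreign field
// multiplication: check a*b = q*f + r modulo 2^t and modulo n.
fn main() {
    let (f, n, t) = (101u64, 97u64, 7u32); // made-up f, n and 2^t
    let (a, b) = (57u64, 88u64);
    let (q, r) = ((a * b) / f, (a * b) % f); // honest witness
    // Both congruences hold ...
    assert_eq!((a * b) % (1 << t), (q * f + r) % (1 << t));
    assert_eq!((a * b) % n, (q * f + r) % n);
    // ... and since both sides are below 2^t * n (with gcd(2^t, n) = 1),
    // the two congruences already force integer equality.
    assert!(a * b < (1u64 << t) * n && q * f + r < (1u64 << t) * n);
}
```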
diff --git a/book/src/kimchi/gates.md b/book/src/kimchi/gates.md
index 529538377e..f15693eaa4 100644
--- a/book/src/kimchi/gates.md
+++ b/book/src/kimchi/gates.md
@@ -186,7 +186,7 @@ Finally, providing $q(X) = \text{gates}(X)/v_{\mathbb{G}}(X)$ and performing the

_The prover knows a polynomial_ $\text{gates}(X)$ _that equals zero on any_ $x\in\{1,g,g^2,g^3\}$.

-Nonetheless, it would still remain to verify that $\text{gates}(X)$ actually corresponds to the encoding of the actual constraints. Meaning, checking that this polynomial encodes the column witnesses, and the agreed circuit. So instead of providing just $\text{gates}(X)$ (actually, a commmitment to it), the prover can send commitments to each of th $15$ witness polynomials, so that the verifier can reconstruct the huge constraint using the encodings of the circuit (which is known ahead).
+Nonetheless, it would still remain to verify that $\text{gates}(X)$ actually corresponds to the encoding of the actual constraints. That is, checking that this polynomial encodes the column witnesses and the agreed circuit. So instead of providing just $\text{gates}(X)$ (actually, a commitment to it), the prover can send commitments to each of the $15$ witness polynomials, so that the verifier can reconstruct the huge constraint using the encodings of the circuit (which is known ahead).

diff --git a/book/src/kimchi/lookup.md b/book/src/kimchi/lookup.md
index a261672716..7893d92121 100644
--- a/book/src/kimchi/lookup.md
+++ b/book/src/kimchi/lookup.md
@@ -1,44 +1,55 @@
-# Plookup in Kimchi
+# $\plookup$ in Kimchi

-In 2020, [plookup](https://eprint.iacr.org/2020/315.pdf) showed how to create lookup proofs. Proofs that some witness values are part of a [lookup table](https://en.wikipedia.org/wiki/Lookup_table). Two years later, an independent team published [plonkup](https://eprint.iacr.org/2022/086) showing how to integrate Plookup into Plonk.
+In 2020, [$\plookup$](https://eprint.iacr.org/2020/315.pdf) showed how to create lookup proofs, i.e. proofs that some witness values are part of a [lookup table](https://en.wikipedia.org/wiki/Lookup_table). Two years later, an independent team published [plonkup](https://eprint.iacr.org/2022/086) showing how to integrate $\plookup$ into $\plonk$.

-This document specifies how we integrate plookup in kimchi. It assumes that the reader understands the basics behind plookup.
+This document specifies how we integrate $\plookup$ in kimchi. It assumes that the reader understands the basics behind $\plookup$.

## Overview

-We integrate plookup in kimchi with the following differences:
+We integrate $\plookup$ in kimchi with the following differences:

* we snake-ify the sorted table instead of wrapping it around (see later)
* we allow fixed-ahead-of-time linear combinations of columns of the queries we make
-* we only use a single table (XOR) at the moment of this writing
+* we implement different tables, such as RangeCheck and XOR
* we allow several lookups (or queries) to be performed within the same row
* zero-knowledgeness is added in a specific way (see later)

The following document explains the protocol in more detail

-### Recap on the grand product argument of plookup
+### Recap on the grand product argument of $\plookup$

-As per the Plookup paper, the prover will have to compute three vectors:
+As per the $\plookup$ paper, the prover will have to compute three vectors:

* $f$, the (secret) **query vector**, containing the witness values that the prover wants to prove are part of the lookup table.
* $t$, the (public) **lookup table**.
* $s$, the (secret) concatenation of $f$ and $t$, sorted by $t$ (where elements are listed in the order they are listed in $t$).

-Essentially, plookup proves that all the elements in $f$ are indeed in the lookup table $t$ if and only if the following multisets are equal:
+Essentially, $\plookup$ proves that all the elements in $f$ are indeed in the lookup table $t$ if and only if the following multisets are equal:

* $\{(1+\beta)f, \text{diff}(t)\}$
* $\text{diff}(\text{sorted}(f, t))$

-where $\text{diff}$ is a new set derived by applying a "randomized difference" between every successive pairs of a vector. For example:
+where $\text{diff}$ is a new set derived by applying a "randomized difference"
+between every successive pair of a vector, and $(f, t)$ is the set union of $f$
+and $t$.
+
+More precisely, for a set $S =
+\{s_{0}, s_{1}, \cdots, s_{n} \}$, $\text{diff}(S)$ is defined as the set
+$\{s_{0} + \beta s_{1}, s_{1} + \beta s_{2}, \cdots, s_{n - 1} + \beta s_{n}\}$.
+
+For example, with:

* $f = \{5, 4, 1, 5\}$
* $t = \{1, 4, 5\}$
+
+we have:
+* $\text{sorted}(f,t) = \{1, 1, 4, 4, 5, 5, 5\}$
* $\{\color{blue}{(1+\beta)f}, \color{green}{\text{diff}(t)}\} = \{\color{blue}{(1+\beta)5, (1+\beta)4, (1+\beta)1, (1+\beta)5}, \color{green}{1+\beta 4, 4+\beta 5}\}$
* $\text{diff}(\text{sorted}(f, t)) = \{1+\beta 1, 1+\beta 4, 4+\beta 4, 4+\beta 5, 5+\beta 5, 5+\beta 5\}$

> Note: This assumes that the lookup table is a single column. You will see in the next section how to address lookup tables with more than one column.

-The equality between the multisets can be proved with the permutation argument of plonk, which would look like enforcing constraints on the following accumulator:
+The equality between the multisets can be proved with the permutation argument of $\plonk$, which would look like enforcing constraints on the following accumulator:

* init: $\mathsf{acc}_0 = 1$
* final: $\mathsf{acc}_n = 1$

@@ -47,17 +58,18 @@ The equality between the multisets can be proved with the permutation argument o

$$
\mathsf{acc}_i = \mathsf{acc}_{i-1} \cdot \frac{(\gamma + (1+\beta) f_{i-1}) \cdot (\gamma + t_{i-1} + \beta t_i)}{(\gamma + s_{i-1} + \beta s_{i})(\gamma + s_{n+i-1} + \beta s_{n+i})}
$$

-Note that the plookup paper uses a slightly different equation to make the proof work. It is possible that the proof would work with the above equation, but for simplicity let's just use the equation published in plookup:
+Note that the $\plookup$ paper uses a slightly different equation to make the proof work. It is possible that the proof would work with the above equation, but for simplicity let's just use the equation published in $\plookup$:

$$
\mathsf{acc}_i = \mathsf{acc}_{i-1} \cdot \frac{(1+\beta) \cdot (\gamma + f_{i-1}) \cdot (\gamma(1 + \beta) + t_{i-1} + \beta t_i)}{(\gamma(1+\beta) + s_{i-1} + \beta s_{i})(\gamma(1+\beta) + s_{n+i-1} + \beta s_{n+i})}
$$

-> Note: in plookup $s$ is longer than $n$ ($|s| = |f| + |t|$), and thus it needs to be split into multiple vectors to enforce the constraint at every $i \in [[0;n]]$. This leads to the two terms in the denominator as shown above, so that the degree of $\gamma (1 + \beta)$ is equal in the nominator and denominator.
+> Note: in $\plookup$ $s$ is longer than $n$ ($|s| = |f| + |t|$), and thus it needs to be split into multiple vectors to enforce the constraint at every $i \in [[0;n]]$.
This leads to the two terms in the denominator as shown above, so that the degree of $\gamma (1 + \beta)$ is equal in the numerator and denominator.

### Lookup tables

-Kimchi uses a single **lookup table** at the moment of this writing; the XOR table. The XOR table for values of 1 bit is the following:
+
+Kimchi uses different lookup tables, including RangeCheck and XOR. The XOR table for values of 1 bit is the following:

| l | r | o |

@@ -73,7 +85,7 @@ Note: the $(0, 0, 0)$ **entry** is at the very end on purpose (as it will be use

### Querying the table

-The plookup paper handles a vector of lookups $f$ which we do not have. So the first step is to create such a table from the witness columns (or registers). To do this, we define the following objects:
+The $\plookup$ paper handles a vector of lookups $f$ which we do not have. So the first step is to create such a table from the witness columns (or registers). To do this, we define the following objects:

* a **query** tells us what registers, in what order, and scaled by how much, are part of a query
* a **query selector** tells us which rows are using the query. It is pretty much the same as a [gate selector]().

@@ -86,10 +98,16 @@ For example, the following **query** tells us that we want to check if $r_0 \opl

| :---: | :---: | :---: |
| 1, $r_0$ | 1, $r_2$ | 2, $r_1$ |

-The grand product argument for the lookup consraint will look like this at this point:
+$r_0$, $r_1$ and $r_2$ will be the result of evaluating the wire polynomials
+$w_0$, $w_1$ and $w_2$, respectively, at $g^i$.
+To perform vector lookups (i.e. lookups over a list of values, not a single
+element), we use a standard technique which consists of coining a combiner value
+$j$ and summing the individual elements of the list using powers of this coin.
+
+The grand product argument for the lookup constraint will look like this at this point:

$$
-\mathsf{acc}_i = \mathsf{acc}_{i-1} \cdot \frac{(1+\beta) \cdot {\color{green}(\gamma + w_0(g^i) + j \cdot w_2(g^i) + j^2 \cdot 2 \cdot w_1(g^i))} \cdot (\gamma(1 + \beta) + t_{i-1} + \beta t_i)}{(\gamma(1+\beta) + s_{i-1} + \beta s_{i})(\gamma(1+\beta) + s_{n+i-1} + \beta s_{n+i})}
+\mathsf{acc}_i = \mathsf{acc}_{i-1} \cdot \frac{(1+\beta) \cdot {\color{green}(\gamma + j^0 \cdot 1 \cdot w_0(g^i) + j \cdot 1 \cdot w_2(g^i) + j^2 \cdot 2 \cdot w_1(g^i))} \cdot (\gamma(1 + \beta) + t_{i-1} + \beta t_i)}{(\gamma(1+\beta) + s_{i-1} + \beta s_{i})(\gamma(1+\beta) + s_{n+i-1} + \beta s_{n+i})}
$$

Not all rows need to perform queries into a lookup table. We will use a query selector in the next section to make the constraints work with this in mind.

@@ -104,7 +122,7 @@ The associated **query selector** tells us on which rows the query into the XOR

| 1 | 0 |

-Both the (XOR) lookup table and the query are built-ins in kimchi. The query selector is derived from the circuit at setup time. Currently only the ChaCha gates make use of the lookups.
+Both the (XOR) lookup table and the query are built-ins in kimchi. The query selector is derived from the circuit at setup time.
With the selectors, the grand product argument for the lookup constraint has the following form:

@@ -116,7 +134,7 @@ where $\color{green}{\mathsf{query}}$ is constructed so that a dummy query ($0 \

$$
\begin{align}
-\mathsf{query} := &\ \mathsf{selector} \cdot (\gamma + w_0(g^i) + j \cdot w_2(g^i) + j^2 \cdot 2 \cdot w_1(g^i)) + \\
+\mathsf{query} := &\ \mathsf{selector} \cdot (\gamma + j^0 \cdot 1 \cdot w_0(g^i) + j \cdot 1 \cdot w_2(g^i) + j^2 \cdot 2 \cdot w_1(g^i)) + \\
&\ (1- \mathsf{selector}) \cdot (\gamma + 0 + j \cdot 0 + j^2 \cdot 0)
\end{align}
$$

@@ -125,7 +143,10 @@ Since we would like to allow multiple table lookups per row, we define multiple **queries**, where each query is associated with a **lookup selector**.

-At the moment of this writing, the `ChaCha` gates all perform $4$ queries in a row. Thus, $4$ is trivially the largest number of queries that happen in a row.
+Previously, ChaCha20 was implemented in Kimchi but has since been removed, as it was no longer needed.
+You can still find the implementation
+[here](https://github.com/o1-labs/proof-systems/blob/601e0adb2a4ba325c9a76468b091ded2bc7b0f70/kimchi/src/circuits/polynomials/chacha.rs).
+The `ChaCha` gates all perform $4$ queries in a row. Thus, $4$ is trivially the largest number of queries that happen in a row.

**Important**: to make constraints work, this means that each row must make $4$ queries. Potentially some or all of them are dummy queries.

@@ -196,7 +217,7 @@ The denominator thus becomes $\prod_{k=1}^{5} (\gamma (1+\beta) + s_{kn+i-1} + \

There are two things that we haven't touched on:

-* The vector $t$ representing the **combined lookup table** (after its columns have been combined with a joint combiner $j$). The **non-combined loookup table** is fixed at setup time and derived based on the lookup tables used in the circuit (for now only one, the XOR lookup table, can be used in the circuit).
+* The vector $t$ representing the **combined lookup table** (after its columns have been combined with a joint combiner $j$). The **non-combined lookup table** is fixed at setup time and derived based on the lookup tables used in the circuit.
* The vector $s$ representing the sorted multiset of both the queries and the lookup table. This is created by the prover and sent as commitment to the verifier.

The first vector $t$ is quite straightforward to think about:

@@ -208,7 +229,7 @@ What about the second vector $s$?

## The sorted vector $s$

-We said earlier that in original plonk the size of $s$ is equal to $|s| = |f|+|t|$, where $f$ encodes queries, and $t$ encodes the lookup table.
+We said earlier that in the original $\plookup$ the size of $s$ is equal to $|s| = |f|+|t|$, where $f$ encodes queries, and $t$ encodes the lookup table.

With our multi-query approach, the second vector $s$ is of the size

$$n \cdot |\#\text{queries}| + |\text{lookup\_table}|$$

That is, it contains the $n$ elements of each **query vectors** (the actual values being looked up, after being combined with the joint combinator, that's $4$ per row), as well as the elements of our lookup table (after being combined as well). Because the vector $s$ is larger than the domain size $n$, it is split into several vectors of size $n$. Specifically, in the plonkup paper, the two halves of $s$, which are then interpolated as $h_1$ and $h_2$.
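As an illustration of how $s$ is assembled, the sketch below reproduces the toy example from the recap ($f = \{5,4,1,5\}$, $t = \{1,4,5\}$). The helper is hypothetical and ignores the joint combiner, padding and snake-ifying that the real prover code has to deal with.

```rust
// A sketch of building the sorted multiset s = sorted(f, t):
// concatenate the queries f with the table t, then order the result
// by each value's position in t. Not Kimchi's actual code.
fn sorted_by_table(f: &[u64], t: &[u64]) -> Vec<u64> {
    let mut s: Vec<u64> = f.iter().chain(t.iter()).copied().collect();
    s.sort_by_key(|v| t.iter().position(|x| x == v).expect("query not in table"));
    s
}

fn main() {
    let (f, t) = ([5u64, 4, 1, 5], [1u64, 4, 5]);
    let s = sorted_by_table(&f, &t);
    assert_eq!(s, vec![1, 1, 4, 4, 5, 5, 5]);
    // s (of size |f| + |t| > n in general) is then chunked into
    // size-n vectors, interpolated as h_1, h_2, ... and committed.
}
```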
-The denominator in plonk in the vector form is
+The denominator in $\plookup$ in the vector form is

$$
\big(\gamma(1+\beta) + s_{i-1} + \beta s_{i}\big)\big(\gamma(1+\beta)+s_{n+i-1} + \beta s_{n+i}\big)
$$

@@ -234,11 +255,11 @@ which is equivalent to checking that $h_1(g^{n-1}) = h_2(1)$.

## The sorted vector $s$ in kimchi

-Since this vector is known only by the prover, and is evaluated as part of the protocol, zero-knowledge must be added to the corresponding polynomial (in case of plookup approach, to $h_1(X),h_2(X)$). To do this in kimchi, we use the same technique as with the other prover polynomials: we randomize the last evaluations (or rows, on the domain) of the polynomial.
+Since this vector is known only by the prover, and is evaluated as part of the protocol, zero-knowledge must be added to the corresponding polynomial (in case of the $\plookup$ approach, to $h_1(X),h_2(X)$). To do this in kimchi, we use the same technique as with the other prover polynomials: we randomize the last evaluations (or rows, on the domain) of the polynomial.

This means two things for the lookup grand product argument:

-1. We cannot use the wrap around trick to make sure that the list is split in two correctly (enforced by $L_{n-1}(h_1(x) - h_2(g \cdot x)) = 0$ which is equivalent to $h_1(g^{n-1}) = h_2(1)$ in the plookup paper)
+1. We cannot use the wrap around trick to make sure that the list is split in two correctly (enforced by $L_{n-1}(h_1(x) - h_2(g \cdot x)) = 0$ which is equivalent to $h_1(g^{n-1}) = h_2(1)$ in the $\plookup$ paper)
2. We have even less space to store an entire query vector. Which is actually super correct, as the witness also has some zero-knowledge rows at the end that should not be part of the queries anyway.

The first problem can be solved in two ways:

diff --git a/book/src/kimchi/maller_15.md b/book/src/kimchi/maller_15.md
index 7e1311e033..d88fe3a924 100644
--- a/book/src/kimchi/maller_15.md
+++ b/book/src/kimchi/maller_15.md
@@ -1,4 +1,4 @@
-# Maller's optimization for kimchi
+# Maller's optimization for Kimchi

This document proposes a protocol change for [kimchi](../specs/kimchi.md).

@@ -17,10 +17,10 @@ $$

They could do this like so:

$$
-com(L) = com(f) - Z_H(\zeta) \cdot com(t)
+\mathsf{com}(L) = \mathsf{com}(f) - Z_H(\zeta) \cdot \mathsf{com}(t)
$$

-Since they already know $f$, they can produce $com(f)$, the only thing they need is $com(t)$. So the protocol looks like that:
+Since they already know $f$, they can produce $\mathsf{com}(f)$; the only thing they need is $\mathsf{com}(t)$. So the protocol looks like this:

![maller 15 1](../img/maller_15_1.png)

@@ -39,7 +39,7 @@ In the rest of this document, we review the details and considerations needed to

## How to deal with a chunked $t$?

-There's one challenge that prevents us from directly using this approach: $com(t)$ is typically sent and received in several commitments (called chunks or segments throughout the codebase). As $t$ is the largest polynomial, usually exceeding the size of the SRS (which is by default set to be the size of the domain).
+There's one challenge that prevents us from directly using this approach: $\mathsf{com}(t)$ is typically sent and received in several commitments (called chunks or segments throughout the codebase). This is because $t$ is the largest polynomial, usually exceeding the size of the SRS (which is by default set to be the size of the domain).

### The verifier side

$$
L = L_0 + x^n L_1 + x^{2n} L_2 + \cdots
$$

where every $L_i$ is of degree $n-1$ at most.
-Then we have that +Then we have that $$ -com(L) = com(L_0) + com(x^n \cdot L_1) + com(x^{2n} \cdot L_2) + \cdots +\mathsf{com}(L) = \mathsf{com}(L_0) + \mathsf{com}(x^n \cdot L_1) + \mathsf{com}(x^{2n} \cdot L_2) + \cdots $$ Which the verifier can't produce because of the powers of $x^n$, but we can linearize it as we already know which $x$ we're going to evaluate that polynomial with: $$ -com(\tilde L) = com(L_0) + \zeta^n \cdot com(L_1) + \zeta^{2n} \cdot com(L_2) + \cdots +\mathsf{com}(\tilde L) = \mathsf{com}(L_0) + \zeta^n \cdot \mathsf{com}(L_1) + \zeta^{2n} \cdot \mathsf{com}(L_2) + \cdots $$ ### The prover side @@ -83,11 +83,11 @@ $$ To compute it, there are two rules to follow: * when two commitment are **added** together, their associated blinding factors get added as well: - $$com(a) + com(b) \implies r_a + r_b$$ + $$\mathsf{com}(a) + \mathsf{com}(b) \implies r_a + r_b$$ * when **scaling** a commitment, its blinding factor gets scalled too: - $$n \cdot com(a) \implies n \cdot r_a$$ + $$n \cdot \mathsf{com}(a) \implies n \cdot r_a$$ -As such, if we know $r_f$ and $r_t$, we can compute: +As such, if we know $r_f$ and $r_t$, we can compute: $$ r_{\tilde L} = r_{\tilde f} + (\zeta^n-1) \cdot r_{\tilde t} @@ -97,7 +97,7 @@ The prover knows the blinding factor of the commitment to $t$, as they committed 1. **The commitments we sent them**. In the linearization process, the verifier actually gets rid of most prover commitments, except for the commitment to the last sigma commitment $S_{\sigma6}$. (TODO: link to the relevant part in the spec) As this commitment is public, it is not blinded. 2. **The public commitments**. Think commitment to selector polynomials or the public input polynomial. These commitments are not blinded, and thus do not impact the calculation of the blinding factor. -3. **The evaluations we sent them**. Instead of using commitments to the wires when recreating $f$, the verifier uses the (verified) evaluations of these in $\zeta$. If we scale our commitment $com(z)$ with any of these scalars, we will have to do the same with $r_z$. +3. **The evaluations we sent them**. Instead of using commitments to the wires when recreating $f$, the verifier uses the (verified) evaluations of these in $\zeta$. If we scale our commitment $\mathsf{com}(z)$ with any of these scalars, we will have to do the same with $r_z$. Thus, the blinding factor of $\tilde L$ is $(\zeta^n-1) \cdot r_{\tilde t}$. 
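Since the two rules are just the homomorphic structure of the commitments, the blinding factor of any linear combination can be tracked purely on the scalar side. Below is a minimal sketch of that bookkeeping, with a toy integer type standing in for a field element and names chosen only for the example.

```rust
// Toy bookkeeping for blinding factors: adding commitments adds
// blinders, scaling a commitment scales its blinder.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Blinder(i64); // stand-in for a field element

impl Blinder {
    fn add(self, other: Blinder) -> Blinder { Blinder(self.0 + other.0) } // com(a) + com(b)
    fn scale(self, n: i64) -> Blinder { Blinder(n * self.0) }             // n * com(a)
}

fn main() {
    let (r_f, r_t) = (Blinder(3), Blinder(5));
    let zeta_n_minus_1 = 7; // stand-in for the scalar ζ^n - 1
    // r_L = r_f + (ζ^n - 1) * r_t, matching the formula above.
    let r_l = r_f.add(r_t.scale(zeta_n_minus_1));
    assert_eq!(r_l, Blinder(3 + 7 * 5));
}
```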
@@ -109,24 +109,24 @@ It should be equal to the following: $$ \begin{align} & \tilde f(\zeta) - \tilde t(\zeta)(\zeta^n - 1) = \\ -& \frac{1 - z(\zeta)}{(\zeta - 1)(\zeta - \omega^{n-k})}\left[ \frac{(\zeta^n - 1)(\zeta - \omega^{n-k})}{n} \alpha^{PERM1} + \frac{\omega^{n-k}(\zeta^n - 1)(\zeta - 1)}{n} \alpha^{PERM2} \right] \\ -& - pub(\zeta) \\ -& \; - z(\zeta) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\ -& \; (w[0](\zeta) + \beta \cdot \text{shift}[0] \zeta + \gamma) \cdot \\ -& \; (w[1](\zeta) + \beta \cdot \text{shift}[1] \zeta + \gamma) \cdot \\ -& \; (w[2](\zeta) + \beta \cdot \text{shift}[2] \zeta + \gamma) \cdot \\ -& \; (w[3](\zeta) + \beta \cdot \text{shift}[3] \zeta + \gamma) \cdot \\ -& \; (w[4](\zeta) + \beta \cdot \text{shift}[4] \zeta + \gamma) \cdot \\ -& \; (w[5](\zeta) + \beta \cdot \text{shift}[5] \zeta + \gamma) \cdot \\ -& \; (w[6](\zeta) + \beta \cdot \text{shift}[6] \zeta + \gamma) + \\ -& \; + z(\zeta \omega) \cdot zkpm(\zeta) \cdot \alpha^{PERM0} \cdot \\ -& \; (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \cdot \\ -& \; (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \cdot \\ -& \; (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \cdot \\ -& \; (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \cdot \\ -& \; (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \cdot \\ -& \; (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \cdot \\ -& \; (w[6](\zeta) + \gamma) + \\ +& \frac{1 - z(\zeta)}{(\zeta - 1)(\zeta - \omega^{n-k})}\left[ \frac{(\zeta^n - 1)(\zeta - \omega^{n-k})}{n} \alpha^{\mathsf{PERM1}} + \frac{\omega^{n-k}(\zeta^n - 1)(\zeta - 1)}{n} \alpha^{\mathsf{PERM2}} \right] \\ +& - \mathsf{pub}(\zeta) \\ +& \; - z(\zeta) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\ +& \; \qquad \qquad \cdot (w[0](\zeta) + \beta \cdot \mathsf{shift}[0] \zeta + \gamma) \\ +& \; \qquad \qquad \cdot (w[1](\zeta) + \beta \cdot \mathsf{shift}[1] \zeta + \gamma) \\ +& \; \qquad \qquad \cdot (w[2](\zeta) + \beta \cdot \mathsf{shift}[2] \zeta + \gamma) \\ +& \; \qquad \qquad \cdot (w[3](\zeta) + \beta \cdot \mathsf{shift}[3] \zeta + \gamma) \\ +& \; \qquad \qquad \cdot (w[4](\zeta) + \beta \cdot \mathsf{shift}[4] \zeta + \gamma) \\ +& \; \qquad \qquad \cdot (w[5](\zeta) + \beta \cdot \mathsf{shift}[5] \zeta + \gamma) \\ +& \; \qquad \qquad \cdot (w[6](\zeta) + \beta \cdot \mathsf{shift}[6] \zeta + \gamma) \\ +& \; + z(\zeta \omega) \cdot \mathsf{zkpm}(\zeta) \cdot \alpha^{\mathsf{PERM0}} \\ +& \; \qquad \qquad \cdot (w[0](\zeta) + \beta \cdot \sigma[0](\zeta) + \gamma) \\ +& \; \qquad \qquad \cdot (w[1](\zeta) + \beta \cdot \sigma[1](\zeta) + \gamma) \\ +& \; \qquad \qquad \cdot (w[2](\zeta) + \beta \cdot \sigma[2](\zeta) + \gamma) \\ +& \; \qquad \qquad \cdot (w[3](\zeta) + \beta \cdot \sigma[3](\zeta) + \gamma) \\ +& \; \qquad \qquad \cdot (w[4](\zeta) + \beta \cdot \sigma[4](\zeta) + \gamma) \\ +& \; \qquad \qquad \cdot (w[5](\zeta) + \beta \cdot \sigma[5](\zeta) + \gamma) \\ +& \; \qquad \qquad \cdot (w[6](\zeta) + \gamma) + \end{align} $$ @@ -146,25 +146,25 @@ Now here's how we need to modify the current protocol: 3. The prover must create a linearized polynomial $\tilde L$ by creating a linearized polynomial $\tilde f$ and a linearized polynomial $\tilde t$ and computing: $$\tilde L = \tilde f + (\zeta^n-1) \cdot \tilde t$$ 4. While the verifier can compute the evaluation of $\tilde L(\zeta)$ by themselves, they don't know the evaluation of $\tilde L(\zeta \omega)$, so the prover needs to send that. -5. 
The verifier must recreate $com(\tilde L)$, the commitment to $\tilde L$, themselves so that they can verify the evaluation proofs of both $\tilde L(\zeta)$ and $\tilde L(\zeta\omega)$.
+5. The verifier must recreate $\mathsf{com}(\tilde L)$, the commitment to $\tilde L$, themselves so that they can verify the evaluation proofs of both $\tilde L(\zeta)$ and $\tilde L(\zeta\omega)$.
6. The evaluation of $\tilde L(\zeta \omega)$ must be absorbed in both sponges (Fq and Fr).

![maller 15 2](../img/maller_15_2.png)

The proposal is implemented in [#150](https://github.com/o1-labs/proof-systems/pull/150) with the following details:

-* the $\tilde L$ polynomial is called $ft$.
-* the evaluation of $\tilde L(\zeta)$ is called $ft_eval0$.
-* the evaluation $\tilde L(\zeta\omega)$ is called $ft_eval1$
+* the $\tilde L$ polynomial is called `ft`.
+* the evaluation of $\tilde L(\zeta)$ is called `ft_eval0`.
+* the evaluation $\tilde L(\zeta\omega)$ is called `ft_eval1`.

diff --git a/book/src/kimchi/overview.md b/book/src/kimchi/overview.md
index 6ecf0f95e4..74b290bc7a 100644
--- a/book/src/kimchi/overview.md
+++ b/book/src/kimchi/overview.md
@@ -45,10 +45,10 @@ So what is the simplest way one could think of to perform this one-time check? P

So instead, we can multiply each term in the sum by a random number. The reason why this trick works is the independence between random numbers. That is, if two different polynomials $f(X)$ and $g(X)$ are both equal to zero on a given $X=x$, then with very high probability the same $x$ will be a root of the random combination $\alpha\cdot f(x) + \beta\cdot g(x) = 0$. If applied to the whole statement, we could transform the $n$ equations into a single equation,

-$$\bigwedge_{i_n} p_i(X) \stackrel{?}{=} 0 \text{\quad iff w.h.p. \quad} \sum_{i=0}^{n} \rho_i \cdot p_i(X) \stackrel{?}{=} 0$$
+$$\bigwedge_{i=0}^{n} p_i(X) =_? 0 \Leftrightarrow_{w.h.p.} \sum_{i=0}^{n} \rho_i \cdot p_i(X) =_? 0$$

This sounds great so far. But we are forgetting about an important part of proof systems which is proof length. For the above claim to be sound, the random values used for aggregation should be verifier-chosen, or at least prover-independent. So if the verifier had to communicate with the prover to inform about the random values being used, we would get an overhead of $n$ field elements. Instead, we take advantage of another technique that is called **powers-of-alpha**. Here, we make the assumption that powers of a random value $\alpha^i$ are indistinguishable from actual random values $\rho_i$. Then, we can twist the above claim to use only one random element $\alpha$ to be agreed with the prover as:

-$$\bigwedge_{i_n} p_i(X) \stackrel{?}{=} 0 \text{\quad iff w.h.p. \quad} \sum_{i=0}^{n} \alpha^i \cdot p_i(X) \stackrel{?}{=} 0$$
+$$\bigwedge_{i=0}^{n} p_i(X) =_? 0 \Leftrightarrow_{w.h.p.} \sum_{i=0}^{n} \alpha^i \cdot p_i(X) =_?
0$$ diff --git a/book/src/pickles/.gitignore b/book/src/pickles/.gitignore new file mode 100644 index 0000000000..fd2f78e90c --- /dev/null +++ b/book/src/pickles/.gitignore @@ -0,0 +1,2 @@ +.*.bkp +.*.dtmp diff --git a/book/src/pickles/accumulation.md b/book/src/pickles/accumulation.md index f1af0deb01..548bf049fe 100644 --- a/book/src/pickles/accumulation.md +++ b/book/src/pickles/accumulation.md @@ -1,8 +1,8 @@ -# Accumulation +# Accumulation ## Introduction -The trick below was originally described in [Halo](https://eprint.iacr.org/2020/499.pdf), +The trick below was originally described in [Halo](https://eprint.iacr.org/2019/1021.pdf), however we are going to base this post on the abstraction of "accumulation schemes" described by Bünz, Chiesa, Mishra and Spooner in [Proof-Carrying Data from Accumulation Schemes](https://eprint.iacr.org/2020/499.pdf), in particular the scheme in Appendix A. 2. Relevant resources include: @@ -145,7 +145,7 @@ The communication complexity will be solved by a well-known folding argument, ho First a reduction from PCS to an inner product relation. -## Reduction: $\relation_{\mathsf{PCS},d} \to \relation_{\mathsf{IPA},\ell}$ +## Reduction 1: $\relation_{\mathsf{PCS},d} \to \relation_{\mathsf{IPA},\ell}$ Formally the relation of the inner product argument is: @@ -194,7 +194,7 @@ $$ this also works, however we (in Kimchi) simply hash to the curve to sample $H$. -## Reduction: $\relation_{\mathsf{IPA},\ell} \to \relation_{\mathsf{IPA},\ell/2}$ +## Reduction 2 (Incremental Step): $\relation_{\mathsf{IPA},\ell} \to \relation_{\mathsf{IPA},\ell/2}$ **Note:** The folding argument described below is the particular variant implemented in Kimchi, although some of the variable names are different. @@ -486,7 +486,7 @@ This would require half as much communication as the naive proof. A modest impro However, we can iteratively apply this transformation until we reach an instance of constant size: -## Reduction: $\relIPA{\ell} \to \ldots \to \relIPA{1}$ +## Reduction 2 (Full): $\relIPA{\ell} \to \ldots \to \relIPA{1}$ That the process above can simply be applied again to the new $(C', \vec{G}', H, \vec{\openx}', v) \in \relation_{\mathsf{IPA}, \ell/2}$ instance as well. By doing so $k = \log_2(\ell)$ times the total communication is brought down to $2 k$ $\GG$-elements @@ -498,21 +498,22 @@ we let $\vec{G}^{(i)}$, $\vec{f}^{(i)}$, $\vec{\openx}^{(i)}$ be the $\vec{G}'$, $\vec{f}'$, $\vec{\openx}'$ vectors respectively after $i$ recursive applications, with $\vec{G}^{(0)}$, $\vec{f}^{(0)}$, $\vec{\openx}^{(0)}$ being the original instance. +We will drop the vector arrow for $\vec{G}^{(k)}$ (and similarly for $f, x$), since after the last step the vectors have size $1$; we refer to this single value instead. We denote by $\chalfold_i$ the challenge of the $i$'th application. -## Reduction: $\relation_{\mathsf{IPA},1} \to \relation_{\mathsf{Acc},\overset{\rightarrow}{G} }$ +## Reduction 3: $\relation_{\mathsf{IPA},1} \to \relation_{\mathsf{Acc},\overset{\rightarrow}{G} }$ While the proof for $\relIPA{\ell}$ above has $O(\log(\ell))$-size, the verifier's time-complexity is $O(\ell)$: -- Computing $\vec{G}^{(k)}$ from $\vec{G}^{(0)}$ using $\vec{\chalfold}$ takes $O(\ell)$. -- Computing $\vec{\openx}^{(k)}$ from $\vec{\openx}^{(0)}$ using $\vec{\chalfold}$ takes $O(\ell)$. +- Computing $G^{(k)}$ from $\vec{G}^{(0)}$ using $\vec{\chalfold}$ takes $O(\ell)$. +- Computing $\openx^{(k)}$ from $\vec{\openx}^{(0)}$ using $\vec{\chalfold}$ takes $O(\ell)$. 
The rest of the verifier's computation is only $O(\log(\ell))$, namely computing: - Sampling all the challenges $\chalfold \sample \FF$. - Computing $C^{(i)} \gets [\chalfold_i^{-1}] \cdot L^{(i)} + C^{(i-1)} + [\chalfold_i] \cdot R^{(i)}$ for every $i$ -However, upon inspection, the naive claim that computing $\vec{\openx}^{(k)}$ takes $O(\ell)$ turns out not to be true: +However, upon inspection, the more pessimistic claim that computing $\vec{\openx}^{(k)}$ takes $O(\ell)$ turns out to be false: **Claim:** Define $ \hpoly(X) = \prod_{i=1}^{k} \left(1 + \chalfold_{k-i+1} \cdot X^{2^{i-1}}\right) $, then $ -\vec{\openx}^{(k)} = \hpoly(\openx) -$ for all $\openx$. +\openx^{(k)} = \hpoly(\openx) +$ for all $\openx$. Therefore, $\openx^{(k)}$ can be evaluated in $O(k) = O(\log \ell)$. **Proof:** This can be verified by looking at the expansion of $\hpoly(X)$. -In slightly more detail: -an equivalent claim is that $\openx^{(k)} = \sum_{i=1}^{\ell} h_i \cdot \openx^{i-1}$ -where $\hpoly(X) = \sum_{i=1}^\ell h_i \cdot X^{i-1}$. -Let $\vec{b}$ be the bit-decomposition of the index $i$ and observe that: +Define $\{h_i\}_{i=1}^{\ell}$ to be the coefficients of $\hpoly(X)$, that is $\hpoly(X) = \sum_{i=1}^\ell h_i \cdot X^{i-1}$. +Then the claim is equivalent to $\openx^{(k)} = \sum_{i=1}^{\ell} h_i \cdot \openx^{i-1}$. +Let $b(i,j)$ denote the $j$th bit in the bit-decomposition of the index $i$ and observe that: $$ -h_i = \sum_{b_j} b_j \cdot \chalfold_{k-i}, \text{ where } i = \sum_{j} b_j \cdot 2^j +h_i = \prod_{j=1}^k \chalfold_{k-j}^{b(i,j)} \text{\qquad where\qquad } \sum_{j} b(i,j) \cdot 2^j = i $$ -Which is simply a special case of the binomial theorem for the product: -$$(1 + \chalfold_1) \cdot (1 + \chalfold_2) \cdots (1 + \chalfold_k)$$ +Now, compare this with how the $k$th element of $x^{(i)}$ is constructed: +$$ +\begin{align*} +x^{(1)}_k &= x^{(0)}_k + \alpha_1 \cdot x^{(0)}_{n/2+k}\\ +x^{(2)}_k &= x^{(1)}_k + \alpha_2 \cdot x^{(1)}_{n/4+k}\\ +&= x^{(0)}_k + \alpha_1 \cdot x^{(0)}_{n/2+k} + \alpha_2 \cdot (x^{(0)}_{n/4+k} + \alpha_1 \cdot x^{(0)}_{n/2 + n/4 +k})\\ +&= \sum_{i=0}^3 x^{(0)}_{i \cdot \frac{n}{4} + k} \cdot \big( \prod_{j=0}^1 \alpha_{j+1}^{b(i,j)} \big) +\end{align*} +$$ +Recalling that $x^{(0)}_k = x^k$, it is easy to see that this generalizes exactly to the expression for $h_i$ derived above, which shows that evaluating via $h(X)$ is correct. + +Finally, regarding evaluation complexity, it is clear that $\hpoly$ can be evaluated in $O(k = \log \ell)$ time as a product of $k$ factors. +This concludes the proof. -Looking at $\hpoly$ it can clearly can be evaluated in $O(k = \log \ell)$ time, -computing $\vec{\openx}^{(k)}$ therefore takes just $O(\log \ell)$ time! #### The "Halo Trick" The "Halo trick" resides in observing that this is also the case for $\vec{G}^{(k)}$: -since it is folded the same way as $\vec{\openx}$. It is not hard to convince one-self (using the same type of argument as above) that: +since it is folded the same way as $\vec{\openx}^{(k)}$. It is not hard to convince oneself (using the same type of argument as above) that: $$ -\vec{G}^{(k)} = \langle \vec{h}, \vec{G} \rangle +G^{(k)} = \langle \vec{h}, \vec{G} \rangle $$ Where $\vec{h}$ is the coefficients of $h(X)$ (like $\vec{f}$ is the coefficients of $f(X)$), i.e. $h(X) = \sum_{i = 1}^{\ell} h_i X^{i-1}$ @@ -613,7 +622,7 @@ $(\statement = (\accCom, \vec{\accChal}), \witness = \epsilon) \in \relAcc$ **Note:** The above can be optimized, in particular there is no need for the prover to send $\accCom$. 
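To make the $O(\log \ell)$ evaluation concrete, here is a minimal Rust sketch of evaluating $\hpoly$ as a product of $k$ factors rather than expanding its $\ell = 2^k$ coefficients. It assumes an arkworks-style `Field` trait and a challenge ordering in which `chals[0]` pairs with the highest power of two (this is the role `b_poly` plays in the codebase); the function name and ordering here are illustrative, not the library API.

```rust
use ark_ff::Field;

/// Evaluate h(X) = prod_{i=0}^{k-1} (1 + chals[i] * x^{2^{k-1-i}})
/// in O(k) field operations.
fn eval_h<F: Field>(chals: &[F], x: F) -> F {
    let k = chals.len();
    // pow_twos[j] = x^{2^j}, built with k successive squarings
    let mut pow_twos = Vec::with_capacity(k);
    let mut acc = x;
    for _ in 0..k {
        pow_twos.push(acc);
        acc = acc.square();
    }
    (0..k)
        .map(|i| F::one() + chals[i] * pow_twos[k - 1 - i])
        .product()
}
```

The same function gives the verifier $\openx^{(k)} = \hpoly(\openx)$ in logarithmic time; only the group side $G^{(k)} = \langle \vec{h}, \vec{G} \rangle$ remains linear-time, which is exactly the work the accumulator defers.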
-## Reduction: $\relation_{\mathsf{Acc}, \overset{\rightarrow}{G}} \to \relation_{\mathsf{PCS}, d}$ +## Reduction 4: $\relation_{\mathsf{Acc}, \overset{\rightarrow}{G}} \to \relation_{\mathsf{PCS}, d}$ Tying the final knot in the diagram. @@ -637,6 +646,13 @@ y &= \sum_i \ \chalfold^{i-1} \cdot h^{(i)}(u) \in \FF \\ C &= \sum_i \ [\chalfold^{i-1}] \cdot U^{(i)} \in \GG \end{align} $$ +Alternatively: +$$ +\begin{align} +y &= \sum_i \ u^{i-1} \cdot h^{(i)}(\chaleval) \in \FF \\ +C &= \sum_i \ [u^{i-1}] \cdot U^{(i)} \in \GG +\end{align} +$$ And outputs the following claim: @@ -656,7 +672,9 @@ Taking a union bound over all $n$ terms leads to soundness error $\frac{n \ell}{|\FF|}$. The reduction above requires $n$ $\GG$ operations and $O(n \log \ell)$ $\FF$ operations. -**Addition of Polynomial Relations:** additional polynomial commitments (i.e. from PlonK) can be added to the randomized sums $(C, y)$ above and opened at $\chaleval$ as well: in which case the prover proves the claimed openings at $\chaleval$ before sampling the challenge $u$. +## Support for Arbitrary Polynomial Relations + +Additional polynomial commitments (i.e. from PlonK) can be added to the randomized sums $(C, y)$ above and opened at $\chaleval$ as well: in which case the prover proves the claimed openings at $\chaleval$ before sampling the challenge $u$. This is done in Kimchi/Pickles: the $\chaleval$ and $u$ above are the same as in the Kimchi code. The combined $y$ (including both the $h(\cdot)$ evaluations and polynomial commitment openings at $\chaleval$ and $\chaleval \omega$) is called `combined_inner_product` in Kimchi. @@ -673,7 +691,8 @@ Cycle of reductions with the added polynomial relations from PlonK. This $\relation_{\mathsf{PCS},\ell}$ instance is reduced back into a single $\relAcc$ instance, which is included with the proof. -**Multiple Accumulators (the case of PCD):** +## Multiple Accumulators (PCD Case) + From the section above it may seem like there is always going to be a single $\relAcc$ instance; this is indeed the case if the proof only verifies a single proof, "Incremental Verifiable Computation" (IVC) in the literature. If the proof verifies multiple proofs, "Proof-Carrying Data" (PCD), then there will be multiple accumulators: @@ -709,16 +728,17 @@ Let $\mathcal{C} \subseteq \FF$ be the challenge space (128-bit GLV decomposed c 1. PlonK verifier on $\pi$ outputs polynomial relations (in Purple in Fig. 4). 1. Checking $\relation_{\mathsf{Acc}, \vec{G}}$ and polynomial relations (from PlonK) to $\relation_{\mathsf{PCS},d}$ (the dotted arrows): 1. Sample $\chaleval \sample \mathcal{C}$ (evaluation point) using the Poseidon sponge. - 1. Read claimed evaluations at $\chaleval$ and $\omega \chaleval$ (`ProofEvaluations`). - 1. Sample $\chalu \sample \mathcal{C}$ (commitment combination challenge) using the Poseidon sponge. - 1. Sample $\chalv \sample \mathcal{C}$ (evaluation combination challenge) using the Poseidon sponge. - 1. Compute $C \in \GG$ with $\chalu$ from: + 1. Read claimed evaluations at $\chaleval$ and $\omega \chaleval$ (`PointEvaluations`). + 1. Sample $\chalv \sample \mathcal{C}$ (commitment combination challenge, `polyscale`) using the Poseidon sponge. + 1. Sample $\chalu \sample \mathcal{C}$ (evaluation combination challenge, `evalscale`) using the Poseidon sponge. + - The $(\chalv, \chalu)$ notation is consistent with the Plonk paper, where $\chalv$ recombines commitments and $\chalu$ recombines evaluations. + 1. 
Compute $C \in \GG$ with $\chalv$ from: - $\accCom^{(1)}, \ldots, \accCom^{(n)}$ (`RecursionChallenge.comm` $\in \GG$) - Polynomial commitments from PlonK (`ProverCommitments`) - 1. Compute $\openy_{\chaleval}$ (part of `combined_inner_product`) with $\chalu$ from: + 1. Compute $\openy_{\chaleval}$ (part of `combined_inner_product`) with $\chalv$ from: - The evaluations of $h^{(1)}(\chaleval), \ldots, h^{(n)}(\chaleval)$ - Polynomial openings from PlonK (`ProofEvaluations`) at $\chaleval$ - 1. Compute $\openy_{\chaleval\omega}$ (part of `combined_inner_product`) with $\chalu$ from: + 1. Compute $\openy_{\chaleval\omega}$ (part of `combined_inner_product`) with $\chalv$ from: - The evaluations of $h^{(1)}(\chaleval\omega), \ldots, h^{(n)}(\chaleval\omega)$ - Polynomial openings from PlonK (`ProofEvaluations`) at $\chaleval \cdot \omega$ @@ -737,7 +757,7 @@ Let $\mathcal{C} \subseteq \FF$ be the challenge space (128-bit GLV decomposed c 1. Checking $\relation_{\mathsf{PCS}, d} \to \relation_{\mathsf{IPA},\ell}$. 1. Sample $\genOpen \sample \GG$ using the Poseidon sponge: hash to curve. - 1. Compute $\openy \gets \openy_{\chaleval} + \chalv \cdot \openy_{\chaleval\omega}$. + 1. Compute $\openy \gets \openy_{\chaleval} + \chalu \cdot \openy_{\chaleval\omega}$. 1. Update $\comm' \gets \comm + [\openy] \cdot \genOpen$. 1. Checking $\relation_{\mathsf{IPA}, \ell} \to \relation_{\mathsf{IPA},1}$:
Check the correctness of the folding argument, for every $i = 1, \ldots, k$: @@ -751,10 +771,12 @@ Let $\mathcal{C} \subseteq \FF$ be the challenge space (128-bit GLV decomposed c 1. Checking $\relation_{\mathsf{IPA},1} \to \relation_{\mathsf{Acc}, \vec{G}}$ 1. Receive $c$ from the prover. 1. Define $\hpoly$ from $\vec{\chalfold}$ (folding challenges, computed above). - 1. Compute $\openy' \gets c \cdot (\hpoly(\chaleval) + \chalv \cdot \hpoly(\chaleval \omega))$, this works since: + 1. Compute a combined evaluation of the IPA challenge polynomial on two points: $b = \hpoly(\chaleval) + \chalu \cdot \hpoly(\chaleval \omega)$. + - Computationally, $b$ is obtained inside the bulletproof folding procedure, by folding the vector $b_{\mathsf{init}}$ such that $b_{\mathsf{init},j} = \zeta^j + \chalu \cdot \zeta^j \omega^j$ using the same standard bulletproof challenges that constitute $h(X)$. This $b_{\mathsf{init}}$ is a $\chalu$-recombined evaluation point. The equality is by linearity of IPA recombination. + 1. Compute $\openy' \gets c \cdot b = c \cdot (\hpoly(\chaleval) + \chalu \cdot \hpoly(\chaleval \omega))$, this works since: $$ \openx^{(\rounds)} = - \openx^{(\rounds)}_{\chaleval} + \chalv \cdot \openx^{(\rounds)}_{\chaleval\omega} + \openx^{(\rounds)}_{\chaleval} + \chalu \cdot \openx^{(\rounds)}_{\chaleval\omega} $$ See [Different functionalities](../plonk/inner_product_api.html) for more details or [the relevant code](https://github.com/o1-labs/proof-systems/blob/76c678d3db9878730f8a4eead65d1e038a509916/poly-commitment/src/commitment.rs#L785). @@ -767,7 +789,7 @@ Note that the accumulator verifier must be proven (in addition to the Kimchi/Plo Note that the "cycles of curves" (e.g. Pasta cycle) does not show up in this part of the code: a separate accumulator is needed for each curve and the final verifier must check both accumulators to deem the combined recursive proof valid. -This takes the form of `passthough` in pickles. +This takes the form of `passthrough` data in pickles. Note, however, that the accumulation verifier makes use of both $\GG$-operations and $\FF$-operations, therefore it (like the Kimchi verifier) also requires [deferred computation](deferred.html). diff --git a/book/src/pickles/deferred.md b/book/src/pickles/deferred.md index b308c167ef..0144abb54e 100644 --- a/book/src/pickles/deferred.md +++ b/book/src/pickles/deferred.md @@ -33,7 +33,7 @@ The solution is to "pass" a value $v$ between the two proofs, in other words to **Insufficient Field Size:** $p < q$ hence $v$ cannot fit in $\mathbb{F}_p$. -**No Binding:** More concerning, there is *no binding* between the $\tilde{v}$ in the $\mathbb{F}_p$-witness and the $v$ in the $\mathbb{F}_q$-witness: a malicious prover could choose completely unreleated values. This violates soundness of the overall $\mathbb{F}_q/\mathbb{F}_q$-relation being proved. +**No Binding:** More concerning, there is *no binding* between the $\tilde{v}$ in the $\mathbb{F}_p$-witness and the $v$ in the $\mathbb{F}_q$-witness: a malicious prover could choose completely unrelated values. This violates soundness of the overall $\mathbb{F}_p/\mathbb{F}_q$-relation being proved. 
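As a toy illustration of how the size problem can be handled (plain `u128` integer arithmetic on the representative of $v$, not the real Pasta moduli, and this alone does nothing for the binding problem): because $p$ and $q$ have the same bit length, $q < 2p$, so any $v \in [0, q)$ can be shipped as a high part plus a low bit, each of which fits below $p$.

```rust
/// Split v in [0, q) into (v >> 1, v & 1); the high part is at most
/// (q - 1) / 2, which is below p whenever q < 2p.
fn decompose(v: u128) -> (u128, u128) {
    (v >> 1, v & 1)
}

/// Inverse of `decompose`: v = 2 * high + low_bit.
fn recompose(high: u128, low_bit: u128) -> u128 {
    2 * high + low_bit
}

fn main() {
    let v: u128 = 0x1234_5678_9abc_def0;
    let (high, low_bit) = decompose(v);
    assert_eq!(recompose(high, low_bit), v);
}
```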
#### Problem 1: Decompose diff --git a/book/src/pickles/diagrams.md b/book/src/pickles/diagrams.md new file mode 100644 index 0000000000..63783603a6 --- /dev/null +++ b/book/src/pickles/diagrams.md @@ -0,0 +1,29 @@ +# Pickles Technical Diagrams + +This section contains a series of diagrams giving an overview of the implementation of pickles, closely following the structures and abstractions in the code. However, they are very technical and are primarily intended for developer use. The primary source of all the images in this section is the [`pickles_structure.drawio`](./pickles_structure.drawio) file; if one edits it, one must re-generate the `.svg` renders by doing "export -> selection". + +The legend of the diagrams is quite straightforward: +- The black boxes are data structures that have names and labels following the implementation. - `MFNStep`/`MFNWrap` is an abbreviation of `MessagesForNextStep` and `MessagesForNextWrap` that is used for brevity. Most other datatypes are exactly the same as in the codebase. +- The blue boxes are computations. Sometimes, when the computation is trivial or only vaguely indicated, it is denoted as a text sign directly on an arrow. +- Arrows are blue by default and denote moving a piece of data from one place to another with no (or very little) change. Light blue arrows denote a witness query that is implemented through the `handler` mechanism. The "chicken foot" connector means that this arrow accesses just one field in an array: such an arrow could connect e.g. an input field of type `old_a: A` in a structure `Vec<(A,B)>` to an output `new_a: A`, which just means that we are inside a `for` loop and this computation is done for all the elements in the vector/array. +- A field is sometimes coloured to denote how many steps ago this piece of data was created. The absence of colour means either that (1) the data structure contains different subfields of different origin, or that (2) it was not coloured but it could be. The colours are assigned according to the following convention: + +![](./pickles_structure_legend_1.svg) + + +### Wrap Computation and Deferred Values + +The following is the diagram that explains the Wrap computation. The left half of it corresponds to the `wrap.ml` file and the general logic of proof creation, while the right part of it is `wrap_main.ml`/`wrap_verifier.ml` and explains in-circuit computation. + +[ ![](pickles_structure_wrap.svg) ](./pickles_structure_wrap.svg) + +This diagram explains the `deferred_values` computation, which is at the heart of `wrap.ml`, represented as a blue box in the middle of the left pane of the main Wrap diagram. Deferred values for Wrap are computed as follows: + +[ ![](pickles_structure_wrap_deferred_values.svg) ](./pickles_structure_wrap_deferred_values.svg) + +### Step Computation + +The following is the diagram that explains the Step computation, similarly to Wrap. The left half of it corresponds to the general logic in `step.ml`, and the right part of it is `step_main.ml` and explains in-circuit computation. We provide no `deferred_values` computation diagram for Step, but it is conceptually very similar to the one already presented for Wrap. + +[ ![](pickles_structure_step.svg) ](./pickles_structure_step.svg) diff --git a/book/src/pickles/overview.md b/book/src/pickles/overview.md index f69a423807..7d2d8bdb7a 100644 --- a/book/src/pickles/overview.md +++ b/book/src/pickles/overview.md @@ -1,14 +1,30 @@ -# Overview - -0. 
Run PlonK verifiers on proofs $\pi_1, \ldots, \pi_n$ outputs polynomial relations. -1. Checking $\relation_{\mathsf{Acc}, \vec{G}}$ and polynomial relations to $\relation_{PCS,d}$: - 1. Sample $\zeta \sample \mathcal{C}$ using the Poseidon sponge. - 2. Compute $C \in \GG$ from $U^{(1)}, \ldots, U^{(n)}$ from -2. Checking $\relation_{\mathsf{IPA}, \ell} \to \relation_{\mathsf{IPA},1}$ -3. Checking - - - `RecursionChallenge.comm` $\in \GG$ fields from each "input proof" - - The polynomial commitments in the - - $u \in [0, 2^{128})$ challenge (provided in GLV form). -3. Compute $y$ from: - - `RecursionChallenge.prev_challenges` $\in \FF^\ell$ fields from each "input proof" +# Overview of Pickles + +Pickles is a recursion layer built on top of Kimchi. The complexity of pickles as a protocol lies in specifying how to verify previous kimchi proofs inside the current ones. Working over two curves requires us to have different circuits and a "mirrored" structure; some computations are deferred for efficiency, and one needs to carefully keep track of the accumulators. In this section we provide a general overview of pickles, while the next sections in the same chapter dive into the actual implementation details. + +Pickles works over [Pasta](../specs/pasta.md), a cycle of curves consisting of Pallas and Vesta, and thus it defines two generic circuits, one for each curve. Each can be thought of as a parallel instantiation of a kimchi proof system. These circuits are not symmetric and have somewhat different functions: +- **Step circuit**: this is the main circuit that contains application logic. Each step circuit verifies a statement and potentially several (at most 2) other wrap proofs. +- **Wrap circuit**: this circuit merely verifies the step circuit, and does not have its own application logic. The intuition is that every time an application statement is proven it's done in Step, and then the resulting proof is immediately wrapped using Wrap. + + +#### General Circuit Structure + +Both Step and Wrap circuits additionally do a lot of recursive verification of the previous steps. Without getting too technical, Step (without loss of generality) does the following: +1. Execute the application logic statement (e.g. the Mina transaction is valid) +2. Verify that the previous Wrap proof is (first-)half-valid (perform only main checks that are efficient for the curve) +3. Verify that the previous Step proof is (second-)half-valid (perform the secondary checks that were inefficient to perform when the previous Step was Wrapped) +4. Verify that the previous Step correctly aggregated the previous accumulator, e.g. $\mathsf{acc}_2 = \mathsf{Aggregate}(\mathsf{acc}_1, \pi_{\mathsf{step},2})$. + +The diagram roughly illustrates the interplay of the two kimchi instances. + +![Overview](./pickles_structure_overview.svg) + + +We note that the Step circuit may repeat items 2-3 to handle the following case: when Step 2 consumes several Wrap 1.X (e.g. Wrap 1.1, Wrap 1.2, etc.), it must perform all these main Wrap 1.X checks, but also all the deferred Step 1.X checks where Wrap 1.X wraps exactly Step 1.X. + + +#### On Accumulators + +The accumulator is an abstraction introduced for the purpose of this diagram. In practice, each kimchi proof consists of (1) commitments to polynomials, (2) evaluations of these polynomials, and (3) the opening proof. What we refer to as the accumulator here is actually the commitment inside the opening proof. 
It is called `sg` in the implementation and is semantically a polynomial commitment to `h(X)` (`b_poly` in the code) --- the $\ell$-sized polynomial built from the IPA challenges. It's a very important polynomial -- it can be evaluated in logarithmic time, but verifying the commitment takes linear time, so the fact that `sg` is a commitment to `h(X)` is never proven inside the circuit. For more details, see [Proof-Carrying Data from Accumulation Schemes](https://eprint.iacr.org/2020/499.pdf), Appendix A.2, where `sg` is called `U`. + +In pickles, we "absorb" this commitment `sg` from the previous step while creating a new proof. That is, for example, Step 1 will produce this commitment, denoted as `acc1` on the diagram, as part of its opening proof, and Step 2 will absorb this commitment. And this "absorption" is what Wrap 2 will prove (and, partially, Step 3 will also refer to the challenges used to build `acc1`, but this detail is completely avoided in this overview). In the end, `acc2` will be the result of Step 2, so in a way `acc2` "aggregates" `acc1`, which somewhat justifies the language used. diff --git a/book/src/pickles/pickles_structure.drawio b/book/src/pickles/pickles_structure.drawio new file mode 100644 index 0000000000..c6cb8d7f67 --- /dev/null +++ b/book/src/pickles/pickles_structure.drawio @@ -0,0 +1,3476 @@
diff --git a/book/src/pickles/pickles_structure_legend_1.svg b/book/src/pickles/pickles_structure_legend_1.svg new file mode 100644 index 0000000000..04a08f2308 --- /dev/null +++ b/book/src/pickles/pickles_structure_legend_1.svg @@ -0,0 +1,4 @@
Produced 3 iterations ago, in the Step before the last Step
Produced 3 iteration...
Last Step:
Produced 1 iteration ago, in the last Step
Last Step:...
Produced 2 iterations ago, in the previous Wrap
Produced 2 iteration...
This Wrap:
Produced in this iteration
This Wrap:...
Produced 2 iterations ago, in the previous Step
Produced 2 iteration...
This Step:
Produced in this iteration
This Step:...
Produced 3 iterations ago, in the Wrap before the last Wrap
Produced 3 iteration...
Last Wrap:
Produced 1 iteration ago, in the last Wrap
Last Wrap:...
When in Step:
When in Step:
When in Wrap:
When in Wrap:
Produced 4 iterations ago, in the prev-previous Wrap
Produced 4 iteration...
Produced 2 iterations ago, in the previous Step
Produced 2 iteration...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/book/src/pickles/pickles_structure_overview.svg b/book/src/pickles/pickles_structure_overview.svg new file mode 100644 index 0000000000..ac0192c085 --- /dev/null +++ b/book/src/pickles/pickles_structure_overview.svg @@ -0,0 +1,4 @@ + + + +
Step 1:
1.
2. Wrap 0 main checks
    2.1. acc-1' was
           aggregated properly
3. Step 0 deferred checks

(2-3 may be repeated)
Step 1:...
Step 2:
1.
2. Wrap 1 main checks
   2.1. acc0' was
          aggregated properly
3. Step 1 deferred checks

(2-3 may be repeated)
Step 2:...
Wrap 1:
1. Step 1 main checks
    1.1. acc0 was
           aggregated properly
2. Wrap 0 deferred checks
Wrap 1:...
Wrap 2:
1. Step 2 main checks
    1.1. acc1 was
           aggregated properly
2. Wrap 1 deferred checks
Wrap 2:...
acc1
acc1
acc2
acc2
acc0
acc0
acc1'
acc1'
acc2'
acc2'
acc0'
acc0'
application statement
application statement...
application statement
application statement...
`\ldot...
`\ld...
`\ld...
`\ld...
`\ld...
`\ldot...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/book/src/pickles/pickles_structure_step.svg b/book/src/pickles/pickles_structure_step.svg new file mode 100644 index 0000000000..f536444c68 --- /dev/null +++ b/book/src/pickles/pickles_structure_step.svg @@ -0,0 +1,4 @@ + + + +
prev_proofs: Vec<WrapProof>
prev_proofs: Vec<WrapProof>
WrapStatement
WrapStatement
MFNStep
MFNStep
challenge_poly_commitment
challenge_poly_commitment
dlog_plonk_index
dlog_plonk_index
app_state
app_state
WrapProofState
WrapProofState
WrapDeferredValues
WrapDeferredValues
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
bulletproof_challenges
bulletproof_challenges
branch_data
branch_data
MFNWrap
MFNWrap
challenge_poly_commitment
challenge_poly_commitment
old_bulletproof_challenges
old_bulletproof_challenges
next_statement: StepStatement
next_statement: StepStatement
Vec<MFNWrap>
Vec<MFNWrap>
challenge_poly_commitment
challenge_poly_commitment
old_bulletproof_challenges
old_bulletproof_challenges
StepProofState
StepProofState
unfinalized_proofs: Vec<PerProof>
unfinalized_proofs: Vec<PerProof>
StepDeferredValues
StepDeferredValues
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
b
b
bulletproof_challenges
bulletproof_challenges
should_finalize
should_finalize
sponge_digest_before_evaluations
sponge_digest_before_evaluations
MFNStep
MFNStep
challenge_poly_commitment
challenge_poly_commitment
app_state
app_state
dlog_plonk_index
dlog_plonk_index
old_bulletproof_challenges
old_bulletproof_challenges
StepProof
StepProof
prev_evals
prev_evals
index
index
statement
statement
proof: ProverProof
proof: ProverProof
prev_challenges
prev_challenges
...
...
Prove
Prove
x
x
w
w
rule.main
r...
branch_data
branch_data
index
index
proofs_verified
proofs_ver...
rule
rule
feature_flags
feature_fl...
requests
requests
main
main
lte
lte
domains
domains
challenge_polynomial_commitments = sgs
challenge_polynomial_commitments = sgs
unfinalized_proofs
unfinalized_proofs
map (\s -> s.proof_state.messages_for_next_wrap_proof), and pad with dummies
ma...
statement_with_hashes
statement_with_hashes
x_hats
x_hats
witnesses
witnesses
expand_proof
expand_proof
This function partially reruns the previous (wrap) proof and generates data that is necessary to create a step proof about it.
1. Unpacks some previous data, e.g. "plonk", "prev_challenges", and "prev_statement_with_hashes" (returned as statement_with_hashes).
2. Computes deferred values of these wrap proofs using Wrap_deferred_values.expand_deferred.
3. Creates an oracle O from (dlog_vk, mfns.challenge_polynomial_commitments, prev_challenges, public_input, proof)
4. Computes new_bulletproof_challenges and $b$ using O.opening_prechallenges, O.r, O.zeta, O.zetaw.
5. Sets challenge_polynomial_commitment to proof.openings.proof.challenge_polynomial_commitment
6. Creates witness: PerProofWitness (that is used inside the circuit)
7. Computes combined_inner_product
8. Builds unfinalized_proofs from computed components (plonk, combined_inner_product, xi, bulletproof_challenges, b)
This function partially reruns the previous (wrap) proof and generate...
dlog_vk
dlog_vk
app_state
app_state
prev_proof
prev_proof
proof_must_verify
proof_must_verify
app_state: either next_state OR/AND return_value
app_state: either next_st...
self_dlog_plonk_index
self_dlog_plonk_index
concat all together ?
c...
compute app_state
c...
public_input
public_input
prev_challenge_polynomial_commitments: Vec<_>
prev_challenge_polynomial_commitments: Vec<_>
challenge_poly_commitment
challenge_poly_commitment
old_bulletproof_challenges
old_bulletproof_challenges
main =* rule.main
m...
previous_proof_statements are accessed through handler
which is called in step_main (Compute_prev_proof_parts)
while unpacking rule.main
p...
step_main/incrementally_verify_proof     (over Fp)
step_main/incrementally_verify_proof     (over Fp)
sponge_digest_before_evaluations_actual
sponge_digest_before_evaluations_actual
bulletproof_challenges_actual
bulletproof_challenges_actual
bulletproof_success
bulletproof_success

Partially verifies the previous step proof, deferring some computations to the next iteration. This implicitly (partially?) verifies the correct accumulation.

Similar to wrap/incrementally_verify_proof

Partially verifies the previous step proof, deferr...
assert equal
assert equal
step_verifier/finalize_other_proof     (over Fp)
step_verifier/finalize_other_proof     (over Fp)
bulletproof_challenges
bulletproof_challenges
sponge
sponge
prev_challenges
prev_challenges
(ft_eval1, evals)
(ft_eval1, evals)

This finalizes the "deferred values" coming from a previous proof over the same field. 

Returns AND of the following 4 checks:

1. Checks that and were sampled correctly. I.e., by absorbing all the evaluation openings and then squeezing.
2. Checks that the "combined inner product" value used in the elliptic curve part of the opening proof was computed correctly, in terms of the evaluation openings and the evaluation points.
3. Check that the value was computed correctly (recombined using ).
4. Perform the arithmetic checks from marlin.

This finalizes the "deferred values" coming from...
bool_checks
bool_checks
deferred_values.bulletproof_challenges()
de...
deferred_values
deferred_values
unfinalized: PerProof
unfinalized: PerProof
sponge_digest_before_evaluations
sponge_digest_before_evaluations
deferred_values
deferred_values
should_finalize
should_finalize
step_main/verify    (over Fp)
step_main/verify    (over Fp)
create new as 
Poseidon over Fp
create new as...
pack_statement
pa...
statement
statement
sg_old
sg_old
proof 
proof 
sponge_after_index
sponge_after_index
public_input
public_input
sponge
sponge
sg_old
sg_old
proof (messages + openings)
proof (messages + openings)
plonk
plonk
xi
xi
advice.b
advice.b
advice.combined_inner_product
advice.combined_inner_product
sponge_after_index
sponge_after_index
unfinalized_proof: PerProof
unfinalized_proof: PerProof
StepDeferredValues
StepDeferredValues
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
b
b
bulletproof_challenges
bulletproof_challenges
should_finalize
should_finalize
sponge_digest_before_evaluations
sponge_digest_before_evaluations
map (assert equal)
map (assert equal)
finalized
finalized
chals
chals
verified & (finalized | !should_verify)
verified & (finalized | !should_verify)
verified
verified
PerProofWitness
PerProofWitness
app_state
app_state
wrap_proof
wrap_proof
prev_proof_evals
prev_proof_evals
prev_challenge_polynomial_commitments
prev_challenge_polynomial_commitments
prev_challenges
prev_challenges
proof_state: WrapProofState
proof_state: WrapProofState
sponge_digest_before_evaluations
sponge_digest_before_evaluations
WrapDeferredValues
WrapDeferredValues
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
b
b
bulletproof_challenges
bulletproof_challenges
branch_data
branch_data
MFNWrap
MFNWrap
challenge_poly_commitment
challenge_poly_commitment
old_bulletproof_challenges
old_bulletproof_challenges
override
ov...
MFNW (Digest)
MFNW (Digest)
should_verify
should_verify
d (datas): ForStep
d (datas): ForStep
branches
branches
num_chunks
num_chunks
feature_flags
feature_flags
step_domains
step_domains
wrap_domain
wrap_domain
public_input
public_input
proofs_verifieds
proofs_verifieds
zk_rows
zk_rows
wrap_key
wrap_key
max_proofs_verified
max_proofs_verified
Create new sponge
& absorb the input
Create new sponge...
statement: StepStatement
statement: StepStatement
proof_state
proof_state
MFNStep
MFNStep
challenge_poly_commitment
challenge_poly_commitment
app_state
app_state
dlog_plonk_index
dlog_plonk_index
old_bulletproof_challenges
old_bulletproof_challenges
Hash
Ha...
step_main/verify_one    (over Fp)
step_main/verify_one    (over Fp)
hash_messages_for_next_step_proof_opt
hash_messages_for_next_step_proof_opt
previous_proofs_statements: Vec<PreviousProofStatement>
previous_proofs_statements: Vec<Prev...
proof_must_verify
proof_must_verify
public_input
public_input
proof (unused)
proof (unused)
public_output (ret_var)
public_output (ret_var)
auxiliary_output (auxiliary_var)
auxiliary_output (auxiliary_...
app_state
app_state
next_state
next_state
dlog_plonk_index
dlog_plonk_index
self_dlog_plonk_index
self_dlog_plonk_index
map (\proof_witness -> proof_witness.wrap_proof.opening.challenge_polynomial_commitment)
ma...
prevs: Vec<PerProofWitness>
prevs: Vec<PerProofWitness>
message:
prev_challenges & prev_inputs
message:...
unfinalized_proofs_unextended
unfinalized_proofs_unextended
messages_for_next_wrap_proof
messages_for_next_wrap_proof
override app_state
in prevs
ov...
'd' contains all
7 fields of 'basic'
'd...
basic: Basic
basic: Basic
wrap_domains
wrap_domains
public_input
public_input
proofs_verifieds
proofs_verifieds
feature_flags
feature_flags
num_chunks
num_chunks
zk_rows
zk_rows
step_domains
step_domains
--- Direct Inputs to step_main: ---
--- Direct Inputs to step_main: ---
bulletproof_challenges
bulletproof_challenges
concat, assert all
concat, assert all
MFNStep
MFNStep
app_state
app_state
dlog_plonk_index
dlog_plonk_index
old_bulletproof_challenges
old_bulletproof_challenges
challenge_poly_commitment
challenge_poly_commitment
hash_messages_for_next_step_proof
hash_messages_for_next_step_proof
match public_input with:
1. app_state 
2. ret_var
3. (app_state,ret_var)
match public_input with...
public_input
public_input
actual_wrap_domains
actual_wrap_domains
send back to step.ml
send back to step.ml
send back to step.ml
send back to step.ml
run application logic of previous statements
run application logic of...
rule
rule
send back to step.ml 
triggering 
compute_prev_proof_parts
send back to step.ml...
previous_proofs_statements: Vec<PreviousProofStatement>
previous_proofs_statements: Vec<PreviousProofState...
proof_must_verify
proof_must_verify
public_input = app_state
public_input = app_state
ProverProof
ProverProof
OpeningProof
OpeningProof

sg = challenge_polynomial_commitment

sg = challenge_polynomial_commitment

lr

lr

z1

z1

z2

z2

delta

delta
commitments
commitments
prev_challenges
prev_challenges
ft_eval1
ft_eval1
evals
evals
unused, see 
plonk_dlog_proof.ml/of_backend
unused, see...
prev_evals: AllEvals
prev_evals: AllEvals
evals
evals
public_input
public_input
ft_eval1
ft_eval1
b
b
old_bulletproof_challenges
old_bulletproof_challenges
sponge_digest_before_evaluations
sponge_digest_before_evaluations
actual_wrap_domains
actual_wrap_domains
(Concatenated outputs of expand_proof):
(Concatenated outputs of...
prev_evals: Vec<_>
prev_evals: Vec<_>
evals
evals
ft_eval1
ft_eval1
AllEvals
AllEvals
public_input
public_input
evals
evals
ft_eval1
ft_eval1
Assert equal ?????
Assert equa...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/book/src/pickles/pickles_structure_wrap.svg b/book/src/pickles/pickles_structure_wrap.svg new file mode 100644 index 0000000000..2dd34ffaf5 --- /dev/null +++ b/book/src/pickles/pickles_structure_wrap.svg @@ -0,0 +1,4 @@ + + + +
prev_statement: StepStatement
prev_statement: StepStatement
MFNWrap
MFNWrap
challenge_poly_commitment
challenge_poly_commitment
old_bulletproof_challenges
old_bulletproof_challenges
proof_state: StepProofState
proof_state: StepProofState
unfinalized_proofs: Vec<PerProof>
unfinalized_proofs: Vec<PerProof>
StepDeferredValues
StepDeferredValues
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
b
b
bulletproof_challenges
bulletproof_challenges
sponge_digest_before_evaluations
sponge_digest_before_evaluations
should_finalize
should_finalize
MFNStep
MFNStep
challenge_poly_commitment
challenge_poly_commitment
app_state
app_state
dlog_plonk_index
dlog_plonk_index
old_bulletproof_challenges
old_bulletproof_challenges
ProverProof
ProverProof
OpeningProof
OpeningProof

sg = challenge_polynomial_commitment

sg = challenge_polynomial_commitment

lr

lr

z1

z1

z2

z2

delta

delta
commitments
commitments
prev_challenges
prev_challenges
ft_eval1
ft_eval1
evals
evals
unused, see 
plonk_dlog_proof.ml/of_backend
unused, see...
next_statement: WrapStatement
next_statement: WrapStatement
MFNStep
MFNStep
challenge_poly_commitment
challenge_poly_commitment
old_bulletproof_challenges
old_bulletproof_challenges
dlog_plonk_index
dlog_plonk_index
app_state
app_state
WrapProofState
WrapProofState
sponge_digest_before_evaluations
sponge_digest_before_evaluations
WrapDeferredValues
WrapDeferredValues
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
b
b
bulletproof_challenges
bulletproof_challenges
branch_data
branch_data
MFNWrap
MFNWrap
challenge_poly_commitment: E_Vesta = (Fq,Fq)
challenge_poly_commitment:...
old_bulletproof_challenges: Vec<Vec<Fq>>
old_bulletproof_challenges...
Prove
Prove
WrapProof
WrapProof
statement
statement
prev_evals
prev_evals
proof: ProverProof
proof: ProverProof
prev_challenges
prev_challenges
...
...
AllEvals
AllEvals
public_input
public_input
ft_eval1
ft_eval1
evals
evals
StepStatementWithHashes
StepStatementWithHashes
StepProofState
StepProofState
unfinalized_proofs: Vec<PerProof>
unfinalized_proofs: Vec<PerProof>
MFNStep_hashed
MFNStep_hashed
MFNWrap_hashed
MFNWrap_hashed
Hash
H...
Hash
H...
deferred_values: computing deferred values
deferred_values: com...
sgs
sgs
prev_chals
prev_chals
public_input
public_input
IPA.Step.compute_challenges
I...
proof, vk, feature_flags
proof, vk, feature_fl...
x_hat_evals
x_hat_evals
digest
digest
next_acc: Vec<next_acc_elem>
next_acc: Vec<next_acc_elem>
next_acc_elem
next_acc_elem
commitment
commitment
challenges
challenges
front append with dummies
if necessary
f...
x
x
message:
prev_challenges & prev_inputs
message:...
StepProofWPublicEvals
StepProofWPublicEvals
public_evals
public_evals
proof: StepProof
proof: StepProof
prev_statement
prev_statement
prev_evals
prev_evals
proof
proof
index
index
step_vk
step_vk
actual_feature_flags
actual_feature_flags
prev_proof_state.unfinalized_proofs
p...
prev_proof_state
prev_proof_state
step_plonk_index
step_plonk_index
hash
h...
prev_step_accs
prev_step_accs
hash
h...
old_bp_challenges: Vec<Vec<Fq>>
old_bp_challenges: Vec<Vec<Fq>>
openings_proof.challenge_polynomial_commitment
o...
openings_proof
openings_proof
messages
(ProverCommitments)
messages...
WrapStatement
WrapStatement
WrapProofState
WrapProofState
MFNW_digest
MFNW_digest
WrapDeferredValues
WrapDeferredValues
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
b
b
bulletproof_challenges
bulletproof_challenges
branch_data
branch_data
sponge_digest_before_evaluations
sponge_digest_before_evaluations
MFNS (digest?) -- ignored/bound by SE
MFNS (digest?) -- ignored/bo...
assert equal
assert equal
Hash
H...
MFNWrap
MFNWrap
challenge_poly_commitment
challenge_poly_commitment
old_bulletproof_challenges
old_bulletproof_challenges
assert
assert
wrap_verifier/incrementally_verify_proof    (over Fq)
wrap_verifier/incrementally_verify_proof    (over Fq)
sponge_digest_before_evaluations_actual
sponge_digest_before_evaluations_actual
bulletproof_challenges_actual
bulletproof_challenges_actual
bulletproof_success
bulletproof_success
public_input
public_input
sponge
sponge
messages (ProverCommitments)
messages (ProverCommi...
openings_proof
openings_proof
plonk
plonk
advice.b
advice.b
advice.combined_inner_product
advice.combined_inner_product

Partially verifies the previous step proof, deferring some computations to the next iteration. This implicitly (partially?) verifies the correct accumulation.

What it does:

- Absorbs verifier index, squeezes out index_digest, and absorbs it. Absorbs sg_old (it is a Vesta curve point, so a pair of elements from a Pallas scalar field; everything else absorbed here also absorbs Vesta points). Computes x_hat (public input polynomial evaluation) using public_input and absorbs it. Absorbs messages.w_comm. Absorbs runtime tables. 

- Computes joint_combiner: absorbs lookups, and then squeezes joint_combiner. Computes lookup_table_comm using joint_combiner and a particular type of accumulator. Computes lookup_sorted.

- Squeezes beta, gamma. Absorbs some lookup data. Absorbs messages.z_comm. Squeezes alpha. Absorbs messages.t_comm. Squeezes zeta. After this point there are no more squeezes/absorptions.

- Computes ft_comm

- Generates bulletproof_challenges

- Checks bulletproof validity on bulletproof_challenges (check_bulletproof) using xi, advice, openings_proof, and the polynomials computed in place 

- Asserts that the given plonk is equal to (alpha, beta, gamma, zeta, joint_combiner, feature_flags) that was computed internally.

Partially verifies the previous step proof, deferring some c...
sg_old: E_Vesta = (Fq,Fq)
sg_old: E_Vesta = (Fq,Fq)
xi
xi
map (assert equal)
map (assert equal)
assert equal
assert equal
pack_statement()
p...
prev_statement: StepStatement
prev_statement: StepStatement
proof_state
proof_state
MFNWrap
MFNWrap
challenge_poly_commitment
challenge_poly_commitment
old_bulletproof_challenges
old_bulletproof_challenges
wrap_verifier/finalize_other_proof    (over Fq)
wrap_verifier/finalize_other_proof    (over Fq)
new_bulletproof_challenges
new_bulletproof_challenges
sponge
sponge
old_bulletproof_challenges
old_bulletproof_challenges
evals
evals

This finalizes the "deferred values" coming from a previous proof over the same field. It
1. Checks that and were sampled correctly. I.e., by absorbing all the evaluation openings and then squeezing.
2. Checks that the "combined inner product" value used in the elliptic curve part of the opening proof was computed correctly, in terms of the evaluation openings and the evaluation points.
3. Check that the value was computed correctly.
4. Perform the arithmetic checks from marlin.

This finalizes the "deferred values" coming from...
StepDeferredValues
StepDeferredValues
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
b
b
bulletproof_challenges
bulletproof_challenges
unwrap properly
u...
unfinalized_proofs: [PerProof]
unfinalized_proofs: [PerProof]
sponge_digest_before_evaluations
sponge_digest_before_evaluations
deferred_values
deferred_values
should_finalize
should_finalize
create new as Poseidon over Fq
create new as Posei...
hardcoded through step_keys during compilation
hardcoded through ste...
verification_key
verification_key
w
w
Hash
H...
Hash
H...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/book/src/pickles/pickles_structure_wrap_deferred_values.svg b/book/src/pickles/pickles_structure_wrap_deferred_values.svg new file mode 100644 index 0000000000..a9ba1bd241 --- /dev/null +++ b/book/src/pickles/pickles_structure_wrap_deferred_values.svg @@ -0,0 +1,4 @@ + + + +
if Some
i...
proof.with_public_evals
proof.with_public_evals
oracle
oracle
sgs
sgs
actual_feature_flags
actual_feature_flags
prev_challenges: Vec<Vec<Fp>>
prev_challenges: Vec<Vec<Fp>>
step_vk.domain.log_size_of_group
s...
step_vk
step_vk
public_input
public_input
proof
proof
actual_proof_verified
actual_proof_verified
fallback
f...
p_eval_1 & p_eval_2
p_eval_1 & p_eval_2
WrapDeferredValues
WrapDeferredValues
x_hat_evals
x_hat_evals
digest
digest
plonk0
plonk0
new_bulletproof_challenges
new_bulletproof_challenges
shift_value (?)
s...
b
b
proof.proof.openings.evals
proof.proof.openings.evals
proof.proof.openings.ft_eval1
proof.proof.openings.ft_eva...
b = challenge_poly zeta + (r * challenge_poly zetaw)
b = challenge_poly zeta...
o.opening_prechals
o.opening_prechals
r = o.u
r = o.u
xi = o.v
xi = o.v
zeta = o.zeta
zeta = o.zeta
zetaw = zeta * w
zetaw = zeta * w
proofs_verified
proofs_verified
domain_log2
domain_log2
O.create_with_public_evals
O.create_with_public_evals
Pow_2_roots_of_unity
Pow_2_roots_of_unity
domain
domain
scalars_env
scalars_env
endo
endo
mds
mds
zk_rows
zk_rows
field_of_hex
field_of_hex
tick_plonk_minimal
tick_plonk_minimal
tick_combined_evals
tick_combined_evals
domain
domain
proof.proof.opening.evals
proof.proof.opening.evals
static
static
wrap/deferred_values
wrap/deferred_values
x_hat
x_hat
combined_inner_product
combined_inner_product

It's a joint linearized evaluation of all our polynomials combined.

It's sponge hashed/absorbed before starting to produce first IPA prechallenges. Used in IPA in two places:
- In the prover computed and absorbed before the folding loop
- Passed to the verifier as a and then also absorbed before the loop.

Also used explicitly in pickles to compute deferred values.

In rust: 

Assuming that each polynomial in is represented by a matrix where the rows correspond to evaluation on points, and the columns represent potential segments (if a polynomial was split in several parts), it computes:

In practice, since we have two evaluation points, will be equal to 



In ocaml:

Compute ft_eval0 (by first computing for it, and using both previous proof evals and public input).

Compute (old) by recombining each set of challenges from (which is a vector of vectors).

Compute as a list of concatenated 
1) grand-old (one step behind) challenge polys evaluated on (each challenge poly is built from an element of old_bulletproof_challenges: Vec<Vec<Fp>>)
2) old (but maybe current, x_hat) public input evaluations (proof_w_public_evals.public_evals) on ,
3) old proof_w_public_evals.proof.openings.evaluations (corresponding to ProverProof.evals) on ,
4) single evaluation ft_eval0

Compute : recombine them all with

Compute and similarly but w.r.t. and ft_eval1.

Return



It's a joint linearized evaluation of all our polynomials combined....
public_input
public_input
r
r
xi
xi
zeta, zetaw
zeta, zetaw
old_bulletproof_challenges
old_bulletproof_challenges
ft_eval1
ft_eval1
env
env
evals
evals
domain
domain
plonk
plonk
xi
xi
plonk
plonk
combined_inner_product
combined_inner_product
bulletproof_challenges
bulletproof_challenges
branch_data
branch_data
b
b
challenge_poly(prechals)
challenge_poly(prechals)
o.digest_before_evaluations
o.digest_before_evaluations
Text is not SVG - cannot display
\ No newline at end of file diff --git a/book/src/fundamentals/zkbook_ips.md b/book/src/pickles/zkbook_ips.md similarity index 87% rename from book/src/fundamentals/zkbook_ips.md rename to book/src/pickles/zkbook_ips.md index deb9ceb2dc..f5ed648750 100644 --- a/book/src/fundamentals/zkbook_ips.md +++ b/book/src/pickles/zkbook_ips.md @@ -4,7 +4,7 @@ Earlier we defined zero-knowledge proofs as being proofs of a computation of a f We will now go beyond this, and try to define zero-knowledge proof systems for computations that proceed **inductively**. That is, in pieces, and potentially over different locations involving different parties in a distributed manner. -An example of an **inductive computation** would be the verification of the Mina blockchain. Each block producer, when they produce a new block, +An example of an **inductive computation** would be the verification of the Mina blockchain. Each block producer, when they produce a new block, - verifies that previous state of the chain was arrived at correctly @@ -22,7 +22,7 @@ Ok, so what are inductive SNARKs? Well, first let's describe precisely the afore ## Inductive sets -zk-SNARKs [as defined earlier](./zkbook_plonk.md) allow you to prove for efficiently computable functions $f \colon A \times W \to \mathsf{Bool}$ statements of the form +zk-SNARKs [as defined earlier](../fundamentals/zkbook_plonk.md) allow you to prove for efficiently computable functions $f \colon A \times W \to \mathsf{Bool}$ statements of the form > I know $w \colon W$ such that $f(a, w) = \mathsf{true}$ @@ -31,7 +31,7 @@ Another way of looking at this is that they let you prove membership in sets of $A_f := \{ a \colon A \mid \text{There exists } w \colon W \text{ such that} f(a, w) = \mathsf{true} \}$[^1] -These are called `NP sets`. In intuitive terms, an NP set is one in which membership can be efficiently checked given some "witness" or helper information (which is $w$). +These are called [NP sets](https://en.wikipedia.org/wiki/NP_(complexity)). In intuitive terms, an NP set is one in which membership can be efficiently checked given some "witness" or helper information (which is $w$). Inductive proof systems let you prove membership in sets that are inductively defined. An inductively defined set is one where membership can be efficiently checked given some helper information, but the computation is explicitly segmented into pieces. @@ -45,7 +45,7 @@ We will give a recursive definition of a few concepts. Making this mathematicall -The data of an **inductive rule for a type $A$** is +The data of an **inductive rule for a type $A$** is - a sequence of inductive sets $A_0, \dots, A_{n-1}$ (note the recursive reference to ) @@ -59,7 +59,7 @@ The data of an **inductive set over a type $A$** is The subset of $A$ corresponding to $R$ (which we will for now write $A_R$) is defined inductively as follows. 
-For $a \colon A$, $a \in A_R$ if and only if +For $a \colon A$, $a \in A_R$ if and only if - there is some inductive rule $r$ in the sequence $R$ with function $f \colon A \times W_r \times A_{r,0} \times \dots A_{r, n_r-1} \to \mathsf{Bool}$ such that @@ -79,4 +79,4 @@ Actually there is a distinction between an inductive set $A$, the type underlyin --- -See this blog post (TODO: copy content of blogpost here) +See [this blog post](https://zkproof.org/2020/06/08/recursive-snarks/) diff --git a/book/src/plonk/domain.md b/book/src/plonk/domain.md index 7eb4229465..4dd79da36a 100644 --- a/book/src/plonk/domain.md +++ b/book/src/plonk/domain.md @@ -1,13 +1,13 @@ # Domain -Plonk needs a domain to encode the circuit on (for each point we can encode a row/constraint/gate). +$\plonk$ needs a domain to encode the circuit on (for each point we can encode a row/constraint/gate). We choose a domain $H$ such that it is a multiplicative subgroup of the scalar field of our curve. ## 2-adicity Furthermore, as FFT (used for interpolation, [see the section on FFTs](https://o1-labs.github.io/proof-systems/fundamentals/zkbook_fft.html)) works best in domains of size $2^k$ for some $k$, we choose $H$ to be a subgroup of order $2^k$. -Since the [pasta curves](https://o1-labs.github.io/proof-systems/specs/pasta.html) both have scalar fields of prime orders $p$ and $q$, their multiplicative subgroup is of order $p-1$ and $q-1$ respectively (without the zero element). +Since the [pasta curves](https://o1-labs.github.io/proof-systems/specs/pasta.html) both have scalar fields of prime orders $p$ and $q$, their multiplicative subgroups are of order $p-1$ and $q-1$ respectively (without the zero element). [The pasta curves were generated](https://forum.zcashcommunity.com/t/noob-question-about-plonk-halo2/39098) specifically so that both $p-1$ and $q-1$ are multiples of $2^{32}$. We say that they have *2-adicity* of 32. @@ -52,11 +52,11 @@ this allows you to easily find subgroups of different sizes of powers of 2, whic ## Efficient computation of the vanishing polynomial -For each curve, we generate a domain $H$ as the set $H = {1, \omega, \omega^2, \cdots, \omega^{n-1}}$. -As we work in a multiplicative subgroup, the polynomial that vanishes in all of these points can be written and efficiently calculated as $Z_H(x) = x^{|H|} - 1$. +For each curve, we generate a domain $H$ as the set $H = \{1, \omega, \omega^2, \cdots, \omega^{n-1}\}$. +As we work in a multiplicative subgroup, the polynomial that vanishes in all of these points can be written and efficiently calculated as $Z_H(x) = x^{|H|} - 1$. This is because, by definition, $\omega^{|H|} = 1$. If $x \in H$, then $x = \omega^i$ for some $i$, and we have: $$x^{|H|} = (\omega^i)^{|H|} = (\omega^{|H|})^i = 1^i = 1$$ -This optimization is important, as without it calculating the vanishing polynomial would take a number of operations linear in $|H|$ (which is not doable in a verifier circuit, for recursion). +This optimization is important, as without it calculating the vanishing polynomial would take a number of operations linear in $|H|$ (which is not doable in a verifier circuit, for recursion). With the optimization, it takes us a logarithmic number of operations (using [exponentiation by squaring](https://en.wikipedia.org/wiki/Exponentiation_by_squaring)) to calculate the vanishing polynomial.
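To make the last point concrete, here is a minimal sketch (using the arkworks `Field` trait, as this repository does; the helper name is ours) of evaluating the vanishing polynomial in a logarithmic number of multiplications, since `pow` performs exponentiation by squaring internally:

```rust
use ark_ff::Field;

/// Evaluates Z_H(x) = x^{|H|} - 1 for a multiplicative subgroup H of size `n`.
/// `pow` uses square-and-multiply, so this costs O(log n) field multiplications
/// instead of the O(n) a naive product over all the points of H would take.
fn eval_vanishing_poly<F: Field>(n: u64, x: F) -> F {
    x.pow([n]) - F::one()
}
```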
diff --git a/book/src/plonk/inner_product.md b/book/src/plonk/inner_product.md index 544191e1e8..f2591f025b 100644 --- a/book/src/plonk/inner_product.md +++ b/book/src/plonk/inner_product.md @@ -85,7 +85,7 @@ where each $G_i$ is distinct and has an unknown discrete logarithm as well. I'll often shorten the last formula as the inner product $\langle \vec{x}, \vec{G} \rangle$ for $\vec{x} = (x_1, \cdots, x_k)$ and $\vec{G} = (G_1, \cdots, G_k)$. To reveal a commitment, simply reveal the values $x_i$. -Pedersen hashing allow commitents that are non-hiding, but binding, as you can't open them to a different value than the originally comitted one. +Pedersen hashing allows commitments that are non-hiding, but binding, as you can't open them to a different value than the originally committed one. And as you can see, adding the commitment of $x$ and $y$ gives us the commitment of $x+y$: $$xG + yG = (x+y)G$$ diff --git a/book/src/plonk/maller.md b/book/src/plonk/maller.md index b2d97c10c8..254ebb0c02 100644 --- a/book/src/plonk/maller.md +++ b/book/src/plonk/maller.md @@ -1,6 +1,6 @@ -# Maller optimization to reduce proof size +# Maller's Optimization to Reduce Proof Size -In the PLONK paper, they make use of an optimization from Mary Maller in order to reduce the proof size. +In the $\plonk$ paper, they make use of an optimization from Mary Maller in order to reduce the proof size. ## Explanation @@ -19,7 +19,7 @@ Let's see the protocol where _Prover_ wants to prove to _Verifier_ that $$\forall x \in \mathbb{F}, \; h_1(x)h_2(x) - h_3(x) = 0$$ -given commitments of $h_1, h_2, h_3$, +given commitments of $h_1, h_2, h_3$, ![maller 1](../img/maller_1.png) @@ -37,7 +37,7 @@ Note right of Verifier: verifies that \n h1(s)h2(s) - h3(s) = 0 ``` --> -A shorter proof exists. Essentially, if the verifier already has the opening `h1(s)`, they can reduce the problem to showing that +A shorter proof exists. Essentially, if the verifier already has the opening `h1(s)`, they can reduce the problem to showing that $$ \forall x \in \mathbb{F}, \; L(x) = h_1(s)h_2(x) - h_3(x) = 0$$ @@ -87,9 +87,9 @@ The problem here is that you can't multiply the commitments together without usi If you're using an inner-product-based commitment, you can't even multiply commitments. -question: where does the multiplication of commitment occurs in the pairing-based protocol of PLONK? And how come we can use bootleproof if we need that multiplication of commitment? +question: where does the multiplication of commitments occur in the pairing-based protocol of $\plonk$? And how come we can use bootleproof if we need that multiplication of commitments? -## Appendix: Original explanation from the PLONK paper +## Appendix: Original explanation from the $\plonk$ paper https://eprint.iacr.org/2019/953.pdf diff --git a/book/src/plonk/overview.md b/book/src/plonk/overview.md index 9b2426ddc7..bb088e9420 100644 --- a/book/src/plonk/overview.md +++ b/book/src/plonk/overview.md @@ -1,13 +1,13 @@ # Overview -The proof system in Mina is a variant of [PLONK](). To understand PLONK, you can refer to our [series of videos](https://www.youtube.com/watch?v=RUZcam_jrz0&list=PLBJMt6zV1c7Gh9Utg-Vng2V6EYVidTFCC) on the scheme. +The proof system in Mina is a variant of [$\plonk$](https://eprint.iacr.org/2019/953.pdf). To understand $\plonk$, you can refer to our [series of videos](https://www.youtube.com/watch?v=RUZcam_jrz0&list=PLBJMt6zV1c7Gh9Utg-Vng2V6EYVidTFCC) on the scheme. In this section we explain our variant, called **kimchi**.
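As a small illustration of the additive homomorphism of Pedersen commitments discussed in the `inner_product.md` hunk above, here is a hedged sketch (the helper and its shape are ours, not the library's API) of a non-hiding vector commitment over any arkworks group:

```rust
use ark_ec::Group;

/// Non-hiding Pedersen vector commitment: C = sum_i x_i * G_i.
/// The bases `gs` must be independent points with unknown discrete logarithms.
fn commit<G: Group>(gs: &[G], xs: &[G::ScalarField]) -> G {
    gs.iter()
        .zip(xs)
        .map(|(g, x)| *g * *x)
        .fold(G::zero(), |acc, p| acc + p)
}
```

Committing to $\vec{x}$ and to $\vec{y}$ with the same bases and adding the two results gives the commitment to $\vec{x} + \vec{y}$, which is exactly the $xG + yG = (x+y)G$ identity above.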
kimchi is not formally a zk-SNARK, as it is not succinct in the proof size. zk-SNARKs must have a $log(n)$ proof size where n is the number of gates in the circuit. In kimchi, due to the Bootleproof polynomial commitment approach, our proofs are $log(d)$ where $d$ is the maximum degree of the committed polynomials (in the order of the number of rows). In practice, our proofs are in the order of dozens of kilobytes. Note that kimchi is a zk-SNARK in terms of verification complexity ($log(n)$) and the proving complexity is quasilinear ($O(n \log n)$) – recall that the prover cannot be faster than linear. -Essentially what PLONK allows you to do is to, given a program with inputs and outputs, take a snapshot of its execution. Then, you can remove parts of the inputs, or outputs, or execution trace, and still prove to someone that the execution was performed correctly for the remaining inputs and outputs of the snapshot. +Essentially what $\plonk$ allows you to do is to, given a program with inputs and outputs, take a snapshot of its execution. Then, you can remove parts of the inputs, or outputs, or execution trace, and still prove to someone that the execution was performed correctly for the remaining inputs and outputs of the snapshot. There are a lot of steps involved in the transformation of "I know an input to this program such that when executed with these other inputs it gives this output" to "here's a sequence of operations you can execute to convince yourself of that". At the end of this chapter, you will understand that the protocol boils down to filling a number of tables (illustrated in the diagram below): tables that specify the circuit, and tables that describe an execution trace of the circuit (given some secret and public inputs). diff --git a/book/src/plonk/plookup.md b/book/src/plonk/plookup.md index 23466337e0..7268a6c8e2 100644 --- a/book/src/plonk/plookup.md +++ b/book/src/plonk/plookup.md @@ -1,8 +1,15 @@ -# Plookup +# $\plookup$ -Plookup allows us to check if witness values belong to a look up table. This is usualy useful for reducing the number of constraints needed for bit-wise operations. So in the rest of this document we'll use the XOR table as an example. +$\plookup$ allows us to check if witness values belong to a lookup table. This is usually useful for reducing the number of constraints needed for bit-wise operations. So in the rest of this document we'll use the XOR table as an example. + + +| A | B | Y | +| --- | --- | --- | +| 0 | 0 | 0 | +| 0 | 1 | 1 | +| 1 | 0 | 1 | +| 1 | 1 | 0 | -![XOR table](../img/xor.png) First, let's define some terms: diff --git a/book/src/plonk/zkpm.md b/book/src/plonk/zkpm.md index 5ab2acf37f..4820a55095 100644 --- a/book/src/plonk/zkpm.md +++ b/book/src/plonk/zkpm.md @@ -53,16 +53,15 @@ $$ **Modified permutation checks.** To recall, the permutation check was originally as follows.
For all $h \in H$, -- $L_1(h)(Z(h) - 1) = 0$ -- $$ - Z(h)[(a(h) + \beta h + \gamma)(b(h) + \beta k_1 h + \gamma)(c(h) + \beta k_2 h + \gamma)] \\ - = Z(\omega h)[(a(h) + \beta S_{\sigma1}(h) + \gamma)(b(h) + \beta S_{\sigma2}(h) + \gamma)(c(h) + \beta S_{\sigma3}(h) + \gamma)] - $$ +* $L_1(h)(Z(h) - 1) = 0$ +* $$Z(h)[(a(h) + \beta h + \gamma)(b(h) + \beta k_1 h + \gamma)(c(h) + \beta k_2 h + \gamma)] \\ + = Z(\omega h)[(a(h) + \beta S_{\sigma1}(h) + \gamma)(b(h) + \beta S_{\sigma2}(h) + \gamma)(c(h) + \beta S_{\sigma3}(h) + \gamma)]$$ + The modified permutation checks, which ensure that the check is performed on all the values except the last $k$ elements in the witness polynomials, are as follows. -- For all $h \in H$, $L_1(h)(Z(h) - 1) = 0$ -- For all $h \in \blue{H\setminus \{h_{n-k}, \ldots, h_n\}}$, $$\begin{aligned} & Z(h)[(a(h) + \beta h + \gamma)(b(h) + \beta k_1 h + \gamma)(c(h) + \beta k_2 h + \gamma)] \\ +* For all $h \in H$, $L_1(h)(Z(h) - 1) = 0$ +* For all $h \in \blue{H\setminus \{h_{n-k}, \ldots, h_n\}}$, $$\begin{aligned} & Z(h)[(a(h) + \beta h + \gamma)(b(h) + \beta k_1 h + \gamma)(c(h) + \beta k_2 h + \gamma)] \\ &= Z(\omega h)[(a(h) + \beta S_{\sigma1}(h) + \gamma)(b(h) + \beta S_{\sigma2}(h) + \gamma)(c(h) + \beta S_{\sigma3}(h) + \gamma)] \end{aligned}$$ * For all $h \in H$, $L_{n-k}(h)(Z(h) - 1) = 0$ diff --git a/book/src/rfcs/keccak.md b/book/src/rfcs/keccak.md deleted file mode 100644 index 03174df193..0000000000 --- a/book/src/rfcs/keccak.md +++ /dev/null @@ -1 +0,0 @@ -# RFC 7: Keccak diff --git a/book/src/snarky/booleans.md b/book/src/snarky/booleans.md new file mode 100644 index 0000000000..ace944ed4d --- /dev/null +++ b/book/src/snarky/booleans.md @@ -0,0 +1,73 @@ +# Booleans + +Booleans are a good example of a [snarky variable](./vars.md#snarky-vars). + +```rust +pub struct Boolean<F>(FieldVar<F>); + +impl<F> SnarkyType<F> for Boolean<F> +where + F: PrimeField, +{ + type Auxiliary = (); + + type OutOfCircuit = bool; + + const SIZE_IN_FIELD_ELEMENTS: usize = 1; + + fn to_cvars(&self) -> (Vec<FieldVar<F>>, Self::Auxiliary) { + (vec![self.0.clone()], ()) + } + + fn from_cvars_unsafe(cvars: Vec<FieldVar<F>>, _aux: Self::Auxiliary) -> Self { + assert_eq!(cvars.len(), Self::SIZE_IN_FIELD_ELEMENTS); + Self(cvars[0].clone()) + } + + fn check(&self, cs: &mut RunState<F>) { + // TODO: annotation? + cs.assert_(Some("boolean check"), vec![BasicSnarkyConstraint::Boolean(self.0.clone())]); + } + + fn deserialize(&self) -> (Self::OutOfCircuit, Self::Auxiliary) { + todo!() + } + + fn serialize(out_of_circuit: Self::OutOfCircuit, aux: Self::Auxiliary) -> Self { + todo!() + } + + fn constraint_system_auxiliary() -> Self::Auxiliary { + todo!() + } + + fn value_to_field_elements(x: &Self::OutOfCircuit) -> (Vec<F>, Self::Auxiliary) { + todo!() + } + + fn value_of_field_elements(x: (Vec<F>, Self::Auxiliary)) -> Self::OutOfCircuit { + todo!() + } +} +``` + +## Check + +The `check()` function is simply constraining the `FieldVar` $x$ to be either $0$ or $1$ using the following constraint: + +$$x ( x - 1) = 0$$ + +It is trivial to use the [double generic gate](../specs/kimchi.md#double-generic-gate) for this.
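To make that concrete, here is a toy sketch (our own types, not the kimchi API) of how the boolean check fits a single generic-gate row, whose equation has the shape $c_l \cdot l + c_r \cdot r + c_o \cdot o + c_m \cdot (l \cdot r) + c_c = 0$:

```rust
use ark_ff::Field;

/// Toy model (not the kimchi API) of one generic-gate row:
/// enforces cl*l + cr*r + co*o + cm*(l*r) + cc = 0.
struct GenericRow<F> {
    l: F,
    r: F,
    o: F,
    /// [cl, cr, co, cm, cc]
    coeffs: [F; 5],
}

/// Encode the boolean check x*(x - 1) = 0 as one generic row:
/// put x in both `l` and `r`, and pick cm = 1 and cl = -1 (rest 0),
/// so the gate equation becomes x^2 - x = 0.
fn boolean_row<F: Field>(x: F) -> GenericRow<F> {
    GenericRow {
        l: x,
        r: x,
        o: F::zero(),
        coeffs: [-F::one(), F::zero(), F::zero(), F::one(), F::zero()],
    }
}
```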
+ +## And + +$$x \land y = x \times y$$ + +## Not + +$$\sim x = 1 - x$$ + +## Or + +* $\sim x \land \sim y = b$ +* $x \lor y = \sim b$ diff --git a/book/src/snarky/circuit-generation.md b/book/src/snarky/circuit-generation.md new file mode 100644 index 0000000000..bec22cce0e --- /dev/null +++ b/book/src/snarky/circuit-generation.md @@ -0,0 +1,29 @@ +# Circuit generation + +In circuit generation mode, the `has_witness` field of `RunState` is set to the default `CircuitGeneration`, and the program of the user is run to completion. + +During the execution, the different snarky functions called on `RunState` will create [internal variables](./vars.md) as well as constraints. + +## Creation of variables + +[Variables](./vars.md) can be created via the `compute()` function, which takes two arguments: + +* A `TypeCreation` toggle, which is either set to `Checked` or `Unsafe`. We will describe this below. +* A closure representing the actual computation behind the variable. This computation will only take place when real values are computed, and can be non-deterministic (e.g. random, or external values provided by the user). Note that the closure takes one argument: a `WitnessGeneration`, a structure that allows you to read the runtime values of any variables that were previously created in your program. + +The `compute()` function also needs a type hint to understand what type of [snarky type](./vars.md#snarky-vars) it is creating. + +It then performs the following steps: + +* creates enough [`FieldVar`](./vars.md#circuit-vars) to hold the value to be created +* retrieves the auxiliary data needed to create the snarky type (TODO: explain auxiliary data) and creates the [`snarky variable`](./vars.md#snarky-vars) out of the `FieldVar`s and the auxiliary data +* if the `TypeCreation` is set to `Checked`, calls the `check()` function on the snarky type (which will constrain the value created); if it is set to `Unsafe`, does nothing (in which case we're trusting that the value cannot be malformed; this is mostly used internally and it is highly likely that users directly making use of `Unsafe` are writing bugs) + +```admonish +At this point we only created variables to hold future values, and made sure that they are constrained. +The actual values will fill the room created by the `FieldVar` only during the [witness generation](./witness-generation.md). +``` + +## Constraints + +All other functions exposed by the API are basically here to operate on variables and create constraints in doing so. diff --git a/book/src/snarky/kimchi-backend.md b/book/src/snarky/kimchi-backend.md new file mode 100644 index 0000000000..7ab2e38bb6 --- /dev/null +++ b/book/src/snarky/kimchi-backend.md @@ -0,0 +1,271 @@ +# Kimchi Backend + +![](https://i.imgur.com/KmKU5Pl.jpg) + +Underneath the snarky wrapper (in `snarky/checked_runner.rs`) lies what we used to call the `plonk_constraint_system` or `kimchi_backend` in `snarky/constraint_system.rs`. + +```admonish +It is good to note that we're planning on removing this abstract separation between the snarky wrapper and the constraint system. +``` + +The logic in the kimchi backend serves two purposes: + +* **Circuit generation**. It is the logic that adds gates to our list of gates (representing the circuit). For most of these gates, the variables used are passed to the backend by the snarky wrapper, but some of them are created by the backend itself (see more in the [variables section](#variables)). +* **Witness generation**.
It is the logic that creates the witness. + +One can also perform two additional operations once the constraint system has been compiled: + +* Generate the prover and verifier index for the system. +* Get a hash of the constraint system (this includes the circuit, the number of public inputs) (TODO: verify that this is true) (TODO: what else should be in that hash? a version of snarky and a version of kimchi?). + +## A circuit + +A circuit is either being built, or has been constructed during a circuit generation phase: + +```rust +enum Circuit<F> +where + F: PrimeField, +{ + /** A circuit still being written. */ + Unfinalized(Vec<CircuitGate<F>>), + /** Once finalized, a circuit is represented as a digest + and a list of gates that corresponds to the circuit. + */ + Compiled([u8; 32], Vec<CircuitGate<F>>), +} +``` + +## State + +The state of the kimchi backend looks like this: + +```rust +pub struct SnarkyConstraintSystem<Field> +where + Field: PrimeField, +{ + /// A counter used to track variables + /// (similar to the one in the snarky wrapper) + next_internal_var: usize, + + /// Instruction on how to compute each internal variable + /// (as a linear combination of other variables). + /// Used during witness generation. + internal_vars: HashMap<InternalVar, (Option<Vec<(Field, V)>>, Option<Field>)>, + + /// The symbolic execution trace table. + /// Each cell is a variable that takes a value during witness generation. + /// (if not set, it will take the value 0). + rows: Vec<Vec<Option<V>>>, + + /// The circuit once compiled + gates: Circuit<Field>, + + /// The row to use the next time we add a constraint. + // TODO: I think we can delete this + next_row: usize, + + /// The size of the public input + /// (which fills the first rows of our constraint system). + public_input_size: Option<usize>, + + // omitted values... +} +``` + +## Variables + +In the backend, there are two types of variables: + +```rust +enum V { + /// An external variable + /// (generated by snarky, via [exists]). + External(usize), + + /// An internal variable is generated to hold an intermediate value, + /// (e.g. in reducing linear combinations to single PLONK positions). + Internal(InternalVar), +} +``` + +Internal variables are basically a `usize` pointing to a hashmap in the state. + +That hashmap tells you how to compute the internal variable during witness generation: it is always a linear combination of other variables (and a constant). + +## Circuit generation + +During circuit generation, the snarky wrapper will make calls to the `add_constraint()` or `add_basic_snarky_constraint` function of the kimchi backend, specifying what gate to use and what variables to use in that gate. + +At this point, the snarky wrapper might have some variables that are not yet tracked as such (with a counter). +Rather, they are constants, or they are a combination of other variables. +You can see that as a small AST representing how to compute a variable. +(See the [variables section](./vars.md#circuit-vars) for more details). + +For this reason, they can hide a number of operations that haven't been constrained yet. +It is the role of the `add_constraint` logic to enforce that at this point constants, as well as linear combinations or scalings of variables, are encoded in the circuit. +This is done by adding enough generic gates (using the `reduce_lincom()` or `reduce_to_var()` functions). + +```admonish +This is a remnant of an optimization targeting R1CS (in which additions are for free).
+An issue with this approach is the following: imagine that two circuit variables are created from the same circuit variable, and that the original circuit variable contained a long AST; then both variables might end up creating the same constraints to convert that AST. +Currently, snarkyjs and pickles expose a `seal()` function that allows you to reduce this issue, at the cost of some manual work and mental tracking from the developer. +We should probably get rid of this, while making sure that we can continue to optimize generic gates +(in some cases you can merge two generic gates in one (TODO: give an example of where that can happen)). +Another solution is to keep track of what was reduced, and reuse previous reductions (similar to how we handle constants). +``` + +It is during this "reducing" step that internal variables (known only to the kimchi backend) are created. + +```admonish +The process is quite safe, as the kimchi backend cannot use the snarky wrapper variables directly (which are of type `FieldVar`). +Since the expected format (see the [variables section](#variables)) is a number (of type `usize`), the only way to convert a non-tracked variable (constant, or scale, or linear combination) is to reduce it (and in the process constraining its value). +``` + +Depending on the gate being used, several constraints might be added via the `add_row()` function which does three things: + +1. figure out if there's any wiring to be done +2. add a gate to our list of gates (representing the circuit) +3. add the variables to our _symbolic_ execution trace table (symbolic in the sense that nothing has values yet) + +This process happens as the circuit is "parsed" and the constraint functions of the kimchi backend are called. + +This does not lead to a finalized circuit; see the next section to see how that is done. + +(TODO: ideally this should happen in the same step) + +## Finalization of the circuit + +So far we've only talked about adding specific constraints to the circuit, but not about how public inputs are handled. + +The `finalization()` function of the kimchi backend does the following: + +* add as many generic rows as there are public inputs. +* construct the permutation +* compute a cache of the circuit (TODO: this is so unnecessary) +* and other things that are not that important + +## Witness generation + +Witness generation happens by taking the finalized state (in the `compute_witness()` function) with a callback that can be used to retrieve the values of external variables (public input and public output). + +The algorithm follows these steps using the symbolic execution table we built during circuit generation: + +1. initialize the execution trace table with zeros +2. go through the rows related to the public input and set the left-most column values to the ones obtained by the callback. +3. go through the other rows and compute the value of the variables left in the table + +Variables in step 3 should either: + +* be absent (`None`) and evaluate to the default value 0 +* point to an external variable, in which case the closure passed can be used to retrieve the value +* be an internal variable, in which case the value is computed by evaluating the AST that was used to create it. + +## Permutation + +The permutation is used to wire cells of the execution trace table (specifically, cells belonging to the first 7 columns). +It is also known as "copy constraints".
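As a rough illustration of step 3 of the witness generation algorithm above, here is a sketch (hypothetical types and shapes; the real backend stores this information in `internal_vars`) of evaluating one internal variable as a linear combination of already-computed variables:

```rust
use ark_ff::Field;
use std::collections::HashMap;

/// Sketch: an internal variable is a linear combination sum_i c_i * v_i
/// plus an optional constant, over variables whose values are known.
fn eval_internal_var<F: Field>(
    lincom: &[(F, usize)],       // (coefficient, index of a computed variable)
    constant: Option<F>,
    values: &HashMap<usize, F>,  // values filled in so far
) -> F {
    let mut acc = constant.unwrap_or(F::zero());
    for (coeff, var) in lincom {
        acc += *coeff * values[var];
    }
    acc
}
```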
+ +```admonish +In snarky, the permutation is represented differently from kimchi, and thus needs to be converted to kimchi's format before a proof can be created. +TODO: merge the representations +``` + +We use the permutation in ingenious ways to optimize circuits. +For example, we use it to encode each constant once, and wire it to the places where it is used. +Another example is that we use it to assert equality between two cells. + +## Implementation details + +There are two aspects to the implementation of the permutation: the first is a hashmap of equivalence classes, which is used to track all the positions of a variable; the second is a [union find](https://en.wikipedia.org/wiki/Disjoint-set_data_structure) data structure used to link variables that are equivalent (we'll talk about that below). + +The two data structures are in the kimchi backend's state: + +```rust +pub struct SnarkyConstraintSystem<Field> +where + Field: PrimeField, +{ + equivalence_classes: HashMap<V, Vec<Position>>, + union_finds: disjoint_set::DisjointSet<V>, + // omitted fields... +} +``` + +### equivalence classes + +As said previously, during circuit generation a symbolic execution trace table is created. It should look a bit like this (if there were only 3 columns and 4 rows): + +| | 0 | 1 | 2 | | :-: | :-: | :-: | :-:| | 0 | v1 | v1 | | | 1 | | v2 | | | 2 | | v2 | | | 3 | | | v1 | + +From that, it should be clear that all the cells containing the variable `v1` should be connected, +and all the cells containing the variable `v2` should be as well. + +The format that the permutation expects is a [cycle](https://en.wikipedia.org/wiki/Cyclic_permutation): a list of cells where each cell is linked to the next, the last one wrapping around and linking to the first one. + +For example, a cycle for the `v1` variable could be: + +``` +(0, 0) -> (0, 1) +(0, 1) -> (3, 2) +(3, 2) -> (0, 0) +``` + +During circuit generation, a hashmap (called `equivalence_classes`) is used to track all the positions (row and column) of each variable. + +During finalization, all the different cycles are created by looking at all the variables existing in the hashmap. + +### Union finds + +Sometimes, we know that two variables will have equivalent values due to an `assert_equal()` being called to link them. +Since we link two variables together, they need to be part of the same cycle, and as such we need to be able to detect that to construct correct cycles. + +To do this, we use a [union find](https://en.wikipedia.org/wiki/Disjoint-set_data_structure) data structure, which allows us to easily find the unions of equivalent variables. + +When an `assert_equal()` is called, we link the two variables together using the `union_finds` data structure. + +During finalization, when we create the cycles, we use the `union_finds` data structure to find the equivalent variables. +We then create a new equivalence classes hashmap to merge the keys (variables) that are in the same set. +This is done before using the equivalence classes hashmap to construct the cycles. + +## Validation during witness generation + +We want to make sure that the program runs to completion, with the given input, before we start computing the proof (which is expensive). + +On top of that, we also want to tell the user what line in their circuit has rejected their input. Like a normal program. + +This is configurable, via an `eval_constraints` boolean in the state, and is set to `true` by default. + +To perform this validation, we do two things: + +**Constants**. In functions that assert, we always check if we're dealing with constants first.
If we are dealing with constants, we don't create any constraints, and we directly check if the computation is correct. For example: + +```rust +pub fn assert_equals( + &self, + state: &mut RunState<F>, + other: &FieldVar<F>, + ) -> SnarkyResult<()> { + match (self, other) { + (FieldVar::Constant(x), FieldVar::Constant(y)) => { + if x == y { + Ok(()) + } else { + Err(/* ... */) + } + } + _ => state.add_constraint(/* ... */) + } + } +``` + +**Runtime values**. During witness generation, actual values are passed to the circuit and can likewise fail an "asserting" function. In these cases, we directly check the computation via the gates. + +What does this mean? For example, when we add a generic constraint during witness generation (generic constraints back a number of asserting functions), we check that the actual values satisfy the generic gate equation. + +Note that not all gates are behind asserting functions, for example the Poseidon gate does not have "bad" inputs, and as such we don't do any checks on it. If for some reason the inputs are bad, it means that the user did this on purpose, and the proof will fail. diff --git a/book/src/snarky/snarky-wrapper.md b/book/src/snarky/snarky-wrapper.md new file mode 100644 index 0000000000..9f652da76b --- /dev/null +++ b/book/src/snarky/snarky-wrapper.md @@ -0,0 +1,70 @@ +# Snarky wrapper + +Snarky, as of today, is constructed as two parts: + +* a snarky wrapper, which is explained in this document +* a backend underneath that wrapper, explained in the [kimchi backend section](./kimchi-backend.md) + +```admonish +This separation exists for legacy reasons, and ideally we should merge the two into a single library. +``` + +The snarky wrapper mostly exists in `checked_runner.rs`, and has the following state: + +```rust +pub struct RunState<F> +where + F: PrimeField, +{ + /// The constraint system used to build the circuit. + /// If not set, the constraint system is not built. + system: Option<SnarkyConstraintSystem<F>>, + + /// The public input of the circuit used in witness generation. + // TODO: can we merge public_input and private_input? + public_input: Vec<F>, + + // TODO: we could also just store `usize` here + pub(crate) public_output: Vec<FieldVar<F>>, + + /// The private input of the circuit used in witness generation. Still not sure what that is, or why we care about this. + private_input: Vec<F>, + + /// If set, the witness generation will check if the constraints are satisfied. + /// This is useful to simulate running the circuit and return an error if an assertion fails. + eval_constraints: bool, + + /// The number of public inputs. + num_public_inputs: usize, + + /// A counter used to track variables (this includes public inputs) as they're being created. + next_var: usize, + + /// Indication that we're running the witness generation (as opposed to the circuit creation). + mode: Mode, +} +``` + +The wrapper is designed to be used in different ways, depending on the fields set. + +```admonish +Ideally, we would like to only run this once and obtain a result that's an immutable compiled artifact. +Currently, `public_input`, `private_input`, `eval_constraints`, `next_var`, and `mode` all need to be mutable. +In the future these should be passed as arguments to functions, and should not exist in the state. +``` + +## Public output + +The support for public output is implemented as kind of a hack. + +When the developer writes a circuit, they have to specify the type of the public output.
+ +This allows the API to save enough room at the end of the public input, and store the variables used in the public output in the state. + +When the API calls the circuit written by the developer, it expects the public output (as a snarky type) to be returned by the function. +The compilation or proving API that ends up calling that function can thus obtain the variables of the public output. +With that in hand, the API can continue to write the circuit to enforce an equality constraint between these variables being returned and the public output variable that it had previously stored in the state. + +Essentially, the kimchi backend will turn this into as many wirings as there are `FieldVar`s in the public output. + +During witness generation, we need a way to modify the witness once we know the values of the public output. +As the public output `FieldVar`s were generated from the snarky wrapper (and not from the kimchi backend), the snarky wrapper should know their values after running the given circuit. diff --git a/book/src/snarky/vars.md b/book/src/snarky/vars.md new file mode 100644 index 0000000000..41fabe82fd --- /dev/null +++ b/book/src/snarky/vars.md @@ -0,0 +1,135 @@ +# Vars + +In this section we will introduce two types of variables: + +* Circuit vars, or `FieldVar`s, which are low-level variables representing field elements. +* Snarky vars, which are high-level variables that users can use to create more meaningful programs. + +## Circuit vars + +In snarky, we first define circuit variables (TODO: rename Field variable?) which represent field elements in a circuit. +These circuit variables, or cvars, can be represented differently in the system: + +```rust +pub enum FieldVar<F> +where + F: PrimeField, +{ + /// A constant. + Constant(F), + + /// A variable that can be referred to via a `usize`. + Var(usize), + + /// The addition of two other [FieldVar]s. + Add(Box<FieldVar<F>>, Box<FieldVar<F>>), + + /// Scaling of a [FieldVar]. + Scale(F, Box<FieldVar<F>>), +} +``` + +One can see a FieldVar as an AST, where two atoms exist: a `Var(usize)` which represents a private input, and a `Constant(F)` which represents a constant. +Anything else represents combinations of these two atoms. + +### Constants + +Note that a circuit variable does not represent a value that has been constrained in the circuit (yet). +This is why we need to know if a cvar is a constant, so that we can avoid constraining it too early. +For example, the following code does not encode 2 or 1 in the circuit, but will encode 3: + +```rust +let x: FieldVar<F> = state.exists(|_| 2) + state.exists(|_| 1); +state.assert_eq(x, y); // 3 and y will be encoded in the circuit +``` + +whereas the following code will encode all variables: + +```rust +let x = y + y; +let one: FieldVar<F> = state.exists(|_| 1); +assert_eq(x, one); +``` + +### Non-constants + +Right after being created, a `FieldVar` is not constrained yet, and needs to be constrained by the application. +That is unless the application wants the `FieldVar` to be a constant that will not need to be constrained (see previous example) or the application wants the `FieldVar` to be a random value (unlikely) (TODO: we should add a "rand" function for that). + +In any case, a circuit variable which is not a constant has a value that is not known yet at circuit-generation time. +In some situations, we might not want to constrain the + + +### When do variables get constrained? + +In general, a circuit variable only gets constrained by an assertion call like `assert` or `assert_equals`.
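To illustrate, here is a simplified sketch (ours, not the exact library implementation) of how addition on the `FieldVar` enum above can build an AST node rather than a constraint, folding constants eagerly:

```rust
use ark_ff::PrimeField;
use std::ops::Add;

// Simplified sketch of addition on FieldVar (see the enum above):
// no constraint is emitted; we either fold constants or record an AST node.
impl<F: PrimeField> Add for FieldVar<F> {
    type Output = FieldVar<F>;

    fn add(self, other: Self) -> Self {
        match (self, other) {
            // two constants fold into a new constant: nothing is encoded
            (FieldVar::Constant(x), FieldVar::Constant(y)) => FieldVar::Constant(x + y),
            // otherwise, just remember the addition for later reduction
            (x, y) => FieldVar::Add(Box::new(x), Box::new(y)),
        }
    }
}
```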
+ +When variables are added together, or scaled, they do not directly get constrained. +This is due to optimizations targeting R1CS (which we don't support anymore) that were implemented in the original snarky library, and that we have kept in snarky-rs. + +Imagine the following example: + +```rust +let y = x1 + x2 + x3 +.... ; +let z = y + 3; +assert_eq(y, 6); +assert_eq(z, 7); +``` + +The first two lines will not create constraints, but simply create minimal ASTs that track all of the additions. + +Both assert calls will then reduce the variables to a single circuit variable, creating the same constraints twice. + +For this reason, there's a function `seal()` defined in pickles and snarkyjs. (TODO: more about `seal()`, and why is it not in snarky?) (TODO: remove the R1CS optimization) + +## Snarky vars + +Handling `FieldVar`s can be cumbersome, as they can only represent a single field element. +We might want to represent values that are either in a smaller range (e.g. [booleans](./booleans.md)) or that are made out of several `FieldVar`s. + +For this, snarky's API exposes the following trait, which allows users to define their own types: + +```rust +pub trait SnarkyType<F>: Sized +where + F: PrimeField, +{ + /// ? + type Auxiliary; + + /// The equivalent type outside of the circuit. + type OutOfCircuit; + + const SIZE_IN_FIELD_ELEMENTS: usize; + + fn to_cvars(&self) -> (Vec<FieldVar<F>>, Self::Auxiliary); + + fn from_cvars_unsafe(cvars: Vec<FieldVar<F>>, aux: Self::Auxiliary) -> Self; + + fn check(&self, cs: &mut RunState<F>); + + fn deserialize(&self) -> (Self::OutOfCircuit, Self::Auxiliary); + + fn serialize(out_of_circuit: Self::OutOfCircuit, aux: Self::Auxiliary) -> Self; + + fn constraint_system_auxiliary() -> Self::Auxiliary; + + fn value_to_field_elements(x: &Self::OutOfCircuit) -> (Vec<F>, Self::Auxiliary); + + fn value_of_field_elements(x: (Vec<F>, Self::Auxiliary)) -> Self::OutOfCircuit; +} +``` + +Such types are always handled as `OutOfCircuit` types (e.g. `bool`) by the users, and as a type implementing `SnarkyType` by snarky (e.g. [`Boolean`](./booleans.md)). +Thus, the user can pass them to snarky in two ways: + +**As public inputs**. In this case they will be serialized into field elements for snarky before [witness-generation](./witness-generation.md) (via the `value_to_field_elements()` function). + +**As private inputs**. In this case, they must be created using the `compute()` function with a closure returning an `OutOfCircuit` value by the user. +The call to `compute()` will need to have some type hint, for snarky to understand what `SnarkyType` it is creating. +This is because the relationship is currently only one-way: a `SnarkyType` knows what out-of-circuit type it relates to, but the reverse is not true. +(TODO: should we implement that though?) + +A `SnarkyType` always implements a `check()` function, which is called by snarky when `compute()` is called to create such a type. +The `check()` function is responsible for creating the constraints that sanitize the newly-created `SnarkyType` (and its underlying `FieldVar`s). +For example, creating a boolean would make sure that the underlying `FieldVar` is either 0 or 1. diff --git a/book/src/specs/kimchi.md b/book/src/specs/kimchi.md index aaa124ea4d..82e4c6acd3 100644 --- a/book/src/specs/kimchi.md +++ b/book/src/specs/kimchi.md @@ -329,7 +329,7 @@ Similarly to the generic gate, each value taking part in a lookup can be scaled The lookup functionality is an opt-in feature of kimchi that can be used by custom gates.
From the user's perspective, not using any gates that make use of lookups means that the feature will be disabled and there will be no overhead to the protocol. -Refer to the [lookup RFC](../kimchi/lookup.md) for an overview of the lookup feature. +Please refer to the [lookup RFC](../kimchi/lookup.md) for an overview of the lookup feature. In this section, we describe the tables kimchi supports, as well as the different lookup selectors (and their associated queries) @@ -367,6 +367,12 @@ The range check table is a single-column table containing the numbers from 0 to This is used to check that the value fits in 12 bits. +##### Runtime tables + +Another type of lookup table has been suggested in the [Extended Lookup Tables](../kimchi/extended-lookup-tables.md) document. + + + #### The Lookup Selectors **XorSelector**. Performs 4 queries to the XOR lookup table. @@ -817,12 +823,11 @@ of the above efficient algorithm for Montgomery curves $b\cdot y^2 = x^3 + a \cd ```ignore Acc := [2]T for i = n-1 ... 0: - Q := (r_i == 1) ? T : -T + Q := (k_{i + 1} == 1) ? T : -T Acc := Acc + (Q + Acc) -return (d_0 == 0) ? Q - P : Q +return (k_0 == 0) ? Acc - P : Acc ``` - The layout of the witness requires 2 rows. The i-th row will be a `VBSM` gate whereas the next row will be a `ZERO` gate. @@ -1346,25 +1351,25 @@ which is doable with the constraints in a `RangeCheck0` gate. Since our current row is almost empty, we can use it to perform the range check within the same gate. Then, using the following layout and assuming that the gate has a coefficient storing the value $2^{rot}$, which is publicly known -| Gate | `Rot64` | `RangeCheck0` | -| ------ | ------------------- | ---------------- | -| Column | `Curr` | `Next` | -| ------ | ------------------- | ---------------- | -| 0 | copy `word` |`shifted` | -| 1 | copy `rotated` | 0 | -| 2 | `excess` | 0 | -| 3 | `bound_limb0` | `shifted_limb0` | -| 4 | `bound_limb1` | `shifted_limb1` | -| 5 | `bound_limb2` | `shifted_limb2` | -| 6 | `bound_limb3` | `shifted_limb3` | -| 7 | `bound_crumb0` | `shifted_crumb0` | -| 8 | `bound_crumb1` | `shifted_crumb1` | -| 9 | `bound_crumb2` | `shifted_crumb2` | -| 10 | `bound_crumb3` | `shifted_crumb3` | -| 11 | `bound_crumb4` | `shifted_crumb4` | -| 12 | `bound_crumb5` | `shifted_crumb5` | -| 13 | `bound_crumb6` | `shifted_crumb6` | -| 14 | `bound_crumb7` | `shifted_crumb7` | +| Gate | `Rot64` | `RangeCheck0` gadgets (designer's duty) | +| ------ | ------------------- | --------------------------------------------------------- | +| Column | `Curr` | `Next` | `Next` + 1 | `Next`+ 2, if needed | +| ------ | ------------------- | ---------------- | --------------- | -------------------- | +| 0 | copy `word` |`shifted` | copy `excess` | copy `word` | +| 1 | copy `rotated` | 0 | 0 | 0 | +| 2 | `excess` | 0 | 0 | 0 | +| 3 | `bound_limb0` | `shifted_limb0` | `excess_limb0` | `word_limb0` | +| 4 | `bound_limb1` | `shifted_limb1` | `excess_limb1` | `word_limb1` | +| 5 | `bound_limb2` | `shifted_limb2` | `excess_limb2` | `word_limb2` | +| 6 | `bound_limb3` | `shifted_limb3` | `excess_limb3` | `word_limb3` | +| 7 | `bound_crumb0` | `shifted_crumb0` | `excess_crumb0` | `word_crumb0` | +| 8 | `bound_crumb1` | `shifted_crumb1` | `excess_crumb1` | `word_crumb1` | +| 9 | `bound_crumb2` | `shifted_crumb2` | `excess_crumb2` | `word_crumb2` | +| 10 | `bound_crumb3` | `shifted_crumb3` | `excess_crumb3` | `word_crumb3` | +| 11 | `bound_crumb4` | `shifted_crumb4` | `excess_crumb4` | `word_crumb4` | +| 12 | `bound_crumb5` | `shifted_crumb5` | `excess_crumb5` |
`word_crumb5` | +| 13 | `bound_crumb6` | `shifted_crumb6` | `excess_crumb6` | `word_crumb6` | +| 14 | `bound_crumb7` | `shifted_crumb7` | `excess_crumb7` | `word_crumb7` | In Keccak, rotations are performed over a 5x5 matrix state of w-bit words in each cell. The values used to perform the rotation are fixed, public, and known in advance, according to the following table, @@ -1544,7 +1549,7 @@ Meaning, In this section we specify the setup that goes into creating two indexes from a circuit: -* A [*prover index*](#prover-index), necessary for the prover to to create proofs. +* A [*prover index*](#prover-index), necessary for the prover to create proofs. * A [*verifier index*](#verifier-index), necessary for the verifier to verify proofs. ```admonish @@ -1711,7 +1716,8 @@ pub struct ProverIndex<G: KimchiCurve, OpeningProof: OpenProof<G>> { /// The symbolic linearization of our circuit, which can compile to concrete types once certain values are learned in the protocol. #[serde(skip)] - pub linearization: Linearization<Vec<PolishToken<G::ScalarField>>>, + pub linearization: + Linearization<Vec<PolishToken<G::ScalarField, Column>>, Column>, /// The mapping between powers of alpha and constraints #[serde(skip)] @@ -1857,7 +1863,8 @@ pub struct VerifierIndex<G: KimchiCurve, OpeningProof: OpenProof<G>> { pub lookup_index: Option<LookupVerifierIndex<G>>, #[serde(skip)] - pub linearization: Linearization<Vec<PolishToken<G::ScalarField>>>, + pub linearization: + Linearization<Vec<PolishToken<G::ScalarField, Column>>, Column>, /// The mapping between powers of alpha and constraints #[serde(skip)] pub powers_of_alpha: Alphas<G::ScalarField>, @@ -1867,7 +1874,7 @@ pub struct VerifierIndex<G: KimchiCurve, OpeningProof: OpenProof<G>> { ## Proof Construction & Verification -Originally, kimchi is based on an interactive protocol that was transformed into a non-interactive one using the [Fiat-Shamir](https://o1-labs.github.io/mina-book/crypto/plonk/fiat_shamir.html) transform. +Originally, kimchi is based on an interactive protocol that was transformed into a non-interactive one using the [Fiat-Shamir](https://o1-labs.github.io/proof-systems/plonk/fiat_shamir.html) transform. For this reason, it can be useful to visualize the high-level interactive protocol before the transformation: ```mermaid @@ -1960,8 +1967,10 @@ pub struct PointEvaluations<Evals> { // TODO: this should really be vectors here, perhaps create another type for chunked evaluations? /// Polynomial evaluations contained in a `ProverProof`.
-/// - **Chunked evaluations** `Field` is instantiated with vectors with a length that equals the length of the chunk -/// - **Non chunked evaluations** `Field` is instantiated with a field, so they are single-sized#[serde_as] +/// - **Chunked evaluations** `Field` is instantiated with vectors with a length +/// that equals the length of the chunk +/// - **Non chunked evaluations** `Field` is instantiated with a field, so they +/// are single-sized#[serde_as] #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ProofEvaluations<Evals> { @@ -1972,7 +1981,8 @@ pub struct ProofEvaluations<Evals> { /// permutation polynomial pub z: Evals, /// permutation polynomials - /// (PERMUTS-1 evaluations because the last permutation is only used in commitment form) + /// (PERMUTS-1 evaluations because the last permutation is only used in + /// commitment form) pub s: [Evals; PERMUTS - 1], /// coefficient polynomials pub coefficients: [Evals; COLUMNS], @@ -1982,11 +1992,13 @@ pub struct ProofEvaluations<Evals> { pub poseidon_selector: Evals, /// evaluation of the elliptic curve addition selector polynomial pub complete_add_selector: Evals, - /// evaluation of the elliptic curve variable base scalar multiplication selector polynomial + /// evaluation of the elliptic curve variable base scalar multiplication + /// selector polynomial pub mul_selector: Evals, /// evaluation of the endoscalar multiplication selector polynomial pub emul_selector: Evals, - /// evaluation of the endoscalar multiplication scalar computation selector polynomial + /// evaluation of the endoscalar multiplication scalar computation selector + /// polynomial pub endomul_scalar_selector: Evals, // Optional gates @@ -2022,7 +2034,8 @@ pub struct ProofEvaluations<Evals> { pub lookup_gate_lookup_selector: Option<Evals>, /// evaluation of the RangeCheck range check pattern selector polynomial pub range_check_lookup_selector: Option<Evals>, - /// evaluation of the ForeignFieldMul range check pattern selector polynomial + /// evaluation of the ForeignFieldMul range check pattern selector + /// polynomial pub foreign_field_mul_lookup_selector: Option<Evals>, } @@ -2054,7 +2067,8 @@ pub struct ProverCommitments<G: AffineCurve> { pub lookup: Option<LookupCommitments<G>>, } -/// The proof that the prover creates from a [ProverIndex](super::prover_index::ProverIndex) and a `witness`. +/// The proof that the prover creates from a +/// [ProverIndex](super::prover_index::ProverIndex) and a `witness`. #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(bound = "G: ark_serialize::CanonicalDeserialize + ark_serialize::CanonicalSerialize")] pub struct ProverProof<G: AffineCurve, OpeningProof> { @@ -2072,7 +2086,8 @@ pub struct ProverProof<G: AffineCurve, OpeningProof> { /// Two evaluations over a number of committed polynomials pub evals: ProofEvaluations<PointEvaluations<Vec<G::ScalarField>>>, - /// Required evaluation for [Maller's optimization](https://o1-labs.github.io/mina-book/crypto/plonk/maller_15.html#the-evaluation-of-l) + /// Required evaluation for [Maller's + /// optimization](https://o1-labs.github.io/proof-systems/kimchi/maller_15.html#the-evaluation-of-l) #[serde_as(as = "o1_utils::serialization::SerdeAs")] pub ft_eval1: G::ScalarField, @@ -2159,7 +2174,7 @@ The prover then follows the following steps to create the proof: * If queries involve a lookup table with multiple columns then squeeze the Fq-Sponge to obtain the joint combiner challenge $j'$, otherwise set the joint combiner challenge $j'$ to $0$.
- * Derive the scalar joint combiner $j$ from $j'$ using the endomorphism (TOOD: specify) + * Derive the scalar joint combiner $j$ from $j'$ using the endomorphism (TODO: specify) * If multiple lookup tables are involved, set the `table_id_combiner` as the $j^i$ with $i$ the maximum width of any used table. Essentially, this is to add a last column of table ids to the concatenated lookup tables. @@ -2192,7 +2207,7 @@ The prover then follows the following steps to create the proof: and by then dividing the resulting polynomial with the vanishing polynomial $Z_H$. TODO: specify the split of the permutation polynomial into perm and bnd? 1. commit (hiding) to the quotient polynomial $t$ -1. Absorb the the commitment of the quotient polynomial with the Fq-Sponge. +1. Absorb the commitment of the quotient polynomial with the Fq-Sponge. 1. Sample $\zeta'$ with the Fq-Sponge. 1. Derive $\zeta$ from $\zeta'$ using the endomorphism (TODO: specify) 1. If lookup is used, evaluate the following polynomials at $\zeta$ and $\zeta \omega$: @@ -2216,13 +2231,13 @@ The prover then follows the following steps to create the proof: $$(f_0(x), f_1(x), f_2(x), \ldots)$$ - TODO: do we want to specify more on that? It seems unecessary except for the t polynomial (or if for some reason someone sets that to a low value) + TODO: do we want to specify more on that? It seems unnecessary except for the t polynomial (or if for some reason someone sets that to a low value) 1. Evaluate the same polynomials without chunking them (so that each polynomial should correspond to a single value this time). 1. Compute the ft polynomial. - This is to implement [Maller's optimization](https://o1-labs.github.io/mina-book/crypto/plonk/maller_15.html). + This is to implement [Maller's optimization](https://o1-labs.github.io/proof-systems/kimchi/maller_15.html). 1. construct the blinding part of the ft polynomial commitment - [see this section](https://o1-labs.github.io/mina-book/crypto/plonk/maller_15.html#evaluation-proof-and-blinding-factors) + [see this section](https://o1-labs.github.io/proof-systems/kimchi/maller_15.html#evaluation-proof-and-blinding-factors) 1. Evaluate the ft polynomial at $\zeta\omega$ only. 1. Setup the Fr-Sponge 1. Squeeze the Fq-sponge and absorb the result with the Fr-Sponge. @@ -2273,7 +2288,7 @@ We define two helper algorithms below, used in the batch verification of proofs. We run the following algorithm: -1. Setup the Fq-Sponge. +1. Setup the Fq-Sponge. This sponge mostly absorbs group 1. Absorb the digest of the VerifierIndex. 1. Absorb the commitments of the previous challenges with the Fq-sponge. 1. Absorb the commitment of the public input polynomial with the Fq-Sponge. @@ -2295,7 +2310,7 @@ We run the following algorithm: 1. Absorb the commitment to the quotient polynomial $t$ into the argument. 1. Sample $\zeta'$ with the Fq-Sponge. 1. Derive $\zeta$ from $\zeta'$ using the endomorphism (TODO: specify). -1. Setup the Fr-Sponge. +1. Setup the Fr-Sponge. This sponge absorbs elements from 1. Squeeze the Fq-sponge and absorb the result with the Fr-Sponge. 1. Absorb the previous recursion challenges. 1. Compute evaluations for the previous recursion challenges. @@ -2310,9 +2325,9 @@ We run the following algorithm: * poseidon selector * the 15 register/witness * 6 sigmas evaluations (the last one is not evaluated) -1. Sample $v'$ with the Fr-Sponge. +1. Sample the "polyscale" $v'$ with the Fr-Sponge. 1. Derive $v$ from $v'$ using the endomorphism (TODO: specify). -1. Sample $u'$ with the Fr-Sponge. +1. 
Sample the "evalscale" $u'$ with the Fr-Sponge. 1. Derive $u$ from $u'$ using the endomorphism (TODO: specify). 1. Create a list of all polynomials that have an evaluation proof. 1. Compute the evaluation of $ft(\zeta)$. diff --git a/book/src/specs/poseidon.md b/book/src/specs/poseidon.md index f39f989131..7c1efaa57c 100644 --- a/book/src/specs/poseidon.md +++ b/book/src/specs/poseidon.md @@ -15,24 +15,24 @@ We might want to piggy back on the [zcash poseidon spec](https://github.com/C2SP We define a base sponge, and a scalar sponge. Both must be instantiated when verifying a proof (this is due to recursion-support baked in Kimchi). -External users of kimchi (or pickles) are most likely to interact with a wrap proof (see the [pickles specification](./pickles.md)). +External users of kimchi (or pickles) are most likely to interact with a wrap proof (see the [pickles specification](./pickles.md)). As such, the sponges they need to instantiate are most likely to be instantiated with: * Poseidon-Fp for base sponge -* Poseidon-Fq for the scalar sponge +* Poseidon-Fq for the scalar sponge ### Base Sponge * `new(params) -> BaseSponge`. Creates a new base sponge. * `BaseSponge.absorb(field_element)`. Absorbs a field element by calling the underlying sponge `absorb` function. * `BaseSponge.absorb_point(point)`. Absorbs an elliptic curve point. If the point is the point at infinity, absorb two zeros. Otherwise, absorb the x and y coordinates with two calls to `absorb`. -* `BaseSponge.absorb_scalar(field_element_of_scalar_field)`. Absorbs a scalar. +* `BaseSponge.absorb_scalar(field_element_of_scalar_field)`. Absorbs a scalar. * If the scalar field is smaller than the base field (e.g. Fp is smaller than Fq), then the scalar is casted to a field element of the base field and absorbed via `absorb`. * Otherwise, the value is split between its least significant bit and the rest. Then both values are casted to field elements of the base field and absorbed via `absorb` (the high bits first, then the low bit). * `BaseSponge.digest() -> field_element`. The `squeeze` function of the underlying sponge function is called and the first field element is returned. -* `BaseSponge.digest_scalar() -> field_element_of_scalar_field`. -* `BaseSponge.challenge // TODO: specify`. -* `BaseSponge.challenge_fq // TODO: specify`. +* `BaseSponge.digest_scalar() -> field_element_of_scalar_field`. +* `BaseSponge.challenge // TODO: specify`. +* `BaseSponge.challenge_fq // TODO: specify`. 
### Scalar Sponge @@ -64,7 +64,7 @@ def apply_mds(state): n[1] = state[0] * mds[1][0] + state[1] * mds[1][1] + state[2] * mds[1][2] n[2] = state[0] * mds[2][0] + state[1] * mds[2][1] + state[2] * mds[2][2] return n - + # a round in the permutation def apply_round(round, state): # sbox @@ -89,9 +89,9 @@ def permutation(state): state[1] += constant[1] state[2] += constant[2] round_offset = 1 - + for round in range(round_offset, ROUNDS + round_offset): - apply_round(round, state) + apply_round(round, state) ``` ### Sponge @@ -115,7 +115,7 @@ def absorb(sponge, field_element): elif sponge.offset == RATE: sponge.state = permutation(sponge.state) sponge.offset = 0 - + # absorb by adding to the state sponge.state[sponge.offset] += field_element sponge.offset += 1 @@ -157,7 +157,7 @@ You can find the MDS matrix and round constants we use in [/poseidon/src/pasta/f ## Test vectors -We have test vectors contained in [/poseidon/src/tests/test_vectors/kimchi.json](https://github.com/o1-labs/proof-systems/tree/master/poseidon/src/tests/test_vectors/kimchi.json). +We have test vectors contained in [/poseidon/tests/test_vectors/kimchi.json](https://github.com/o1-labs/proof-systems/tree/master/poseidon/tests/test_vectors/kimchi.json). ## Pointers to the OCaml code diff --git a/book/src/specs/urs.md b/book/src/specs/urs.md index 83dfa635c1..dd63174520 100644 --- a/book/src/specs/urs.md +++ b/book/src/specs/urs.md @@ -2,7 +2,7 @@ ```admonish Note that the URS is called SRS in our codebase at the moment. -SRS is incorrect, as it is used in the presence of a trusted setup (which kimchi does not have). +SRS is incorrect, as it is used in the presence of a trusted setup (which kimchi does not have). This needs to be fixed. ``` @@ -68,7 +68,7 @@ TODO: specify the encoding ### Vesta -**`Gs`**. +**`Gs`**. 
``` G0 = (x=121C4426885FD5A9701385AAF8D43E52E7660F1FC5AFC5F6468CC55312FC60F8, y=21B439C01247EA3518C5DDEB324E4CB108AF617780DDF766D96D3FD8AB028B70) diff --git a/circuit-construction/Cargo.toml b/circuit-construction/Cargo.toml index 2141e692de..fdd047984b 100644 --- a/circuit-construction/Cargo.toml +++ b/circuit-construction/Cargo.toml @@ -15,32 +15,32 @@ bench = false # needed for criterion (https://bheisler.github.io/criterion.rs/bo [dependencies] ark-ff.workspace = true -ark-ec.workspace = true +ark-ec.workspace = true ark-poly.workspace = true -ark-serialize.workspace = true -blake2.workspace = true -num-derive.workspace = true +ark-serialize.workspace = true +blake2.workspace = true +num-derive.workspace = true num-traits.workspace = true -itertools.workspace = true -rand.workspace = true -rand_core.workspace = true -rayon.workspace = true -rmp-serde.workspace = true -serde.workspace = true +itertools.workspace = true +rand.workspace = true +rand_core.workspace = true +rayon.workspace = true +rmp-serde.workspace = true +serde.workspace = true serde_with.workspace = true -thiserror.workspace = true +thiserror.workspace = true poly-commitment.workspace = true -groupmap.workspace = true -mina-curves.workspace = true -o1-utils.workspace = true +groupmap.workspace = true +mina-curves.workspace = true +o1-utils.workspace = true mina-poseidon.workspace = true -kimchi.workspace = true +kimchi.workspace = true [dev-dependencies] proptest.workspace = true -proptest-derive.workspace = true -colored.workspace = true +proptest-derive.workspace = true +colored.workspace = true # benchmarks criterion.workspace = true diff --git a/circuit-construction/src/tests/example_proof.rs b/circuit-construction/src/tests/example_proof.rs index dded30b22b..ffa08a3527 100644 --- a/circuit-construction/src/tests/example_proof.rs +++ b/circuit-construction/src/tests/example_proof.rs @@ -48,12 +48,11 @@ const PUBLIC_INPUT_LENGTH: usize = 3; #[test] fn test_example_circuit() { - use mina_curves::pasta::Pallas; - use mina_curves::pasta::Vesta; + use mina_curves::pasta::{Pallas, Vesta}; // create SRS let srs = { - let mut srs = SRS::::create(1 << 7); // 2^7 = 128 - srs.get_lagrange_basis_from_domain_size(Radix2EvaluationDomain::new(srs.g.len())); + let srs = SRS::::create(1 << 7); // 2^7 = 128 + srs.get_lagrange_basis(Radix2EvaluationDomain::new(srs.g.len()).unwrap()); Arc::new(srs) }; diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 0000000000..66b530223b --- /dev/null +++ b/clippy.toml @@ -0,0 +1,6 @@ +# Default is 7. However, it seems to be arbitrary. Some protocols do require +# many parameters, like polynomial commitment scheme when we want to make an +# opening. +# Note that we should not use too much this threshold, and stay clean in the +# choice of our routine parameters. 
+too-many-arguments-threshold = 8 \ No newline at end of file diff --git a/curves/Cargo.toml b/curves/Cargo.toml index 2116b66c86..bd2bbd6673 100644 --- a/curves/Cargo.toml +++ b/curves/Cargo.toml @@ -12,10 +12,11 @@ license = "Apache-2.0" [dependencies] ark-ec.workspace = true ark-ff.workspace = true +num-bigint.workspace = true [dev-dependencies] -rand.workspace = true +rand.workspace = true ark-test-curves.workspace = true ark-algebra-test-templates.workspace = true ark-serialize.workspace = true -ark-std.workspace = true +ark-std.workspace = true \ No newline at end of file diff --git a/curves/src/pasta/curves/mod.rs b/curves/src/pasta/curves/mod.rs index 76aa40ef47..2adf2de2b8 100644 --- a/curves/src/pasta/curves/mod.rs +++ b/curves/src/pasta/curves/mod.rs @@ -1,5 +1,2 @@ pub mod pallas; pub mod vesta; - -#[cfg(test)] -mod tests; diff --git a/curves/src/pasta/curves/tests.rs b/curves/src/pasta/curves/tests.rs deleted file mode 100644 index 40a57c2091..0000000000 --- a/curves/src/pasta/curves/tests.rs +++ /dev/null @@ -1,5 +0,0 @@ -use crate::pasta::{ProjectivePallas, ProjectiveVesta}; -use ark_algebra_test_templates::*; - -test_group!(g1; ProjectivePallas; sw); -test_group!(g2; ProjectiveVesta; sw); diff --git a/curves/src/pasta/fields/mod.rs b/curves/src/pasta/fields/mod.rs index 158eae558d..84be91fdba 100644 --- a/curves/src/pasta/fields/mod.rs +++ b/curves/src/pasta/fields/mod.rs @@ -43,6 +43,3 @@ pub trait SquareRootField: Field { /// Sets `self` to be the square root of `self`, if it exists. fn sqrt_in_place(&mut self) -> Option<&mut Self>; } - -#[cfg(test)] -mod tests; diff --git a/curves/tests/pasta_curves.rs b/curves/tests/pasta_curves.rs new file mode 100644 index 0000000000..8cf829fa74 --- /dev/null +++ b/curves/tests/pasta_curves.rs @@ -0,0 +1,75 @@ +use std::str::FromStr; + +use ark_algebra_test_templates::*; +use mina_curves::pasta::{Fp, Pallas, ProjectivePallas, ProjectiveVesta}; +use num_bigint::BigUint; + +test_group!(g1; ProjectivePallas; sw); +test_group!(g2; ProjectiveVesta; sw); + +#[test] +fn test_regression_vesta_biguint_into_returns_canonical_representation() { + // This regression test is to ensure that the BigUint::into() impl for Vesta + // returns the canonical representation of the field element and not the + // montgomery representation. + // Commit: 0863042ca383c6578564e166724e9dbc66da19bf + let p_x = Fp::from_str("1").unwrap(); + let p_y = Fp::from_str( + "12418654782883325593414442427049395787963493412651469444558597405572177144507", + ) + .unwrap(); + let p1 = Pallas::new_unchecked(p_x, p_y); + let p_x_biguint: BigUint = p1.x.into(); + let p_y_biguint: BigUint = p1.y.into(); + + assert_eq!(p_x_biguint, BigUint::from_str("1").unwrap()); + assert_eq!( + p_y_biguint, + BigUint::from_str( + "12418654782883325593414442427049395787963493412651469444558597405572177144507", + ) + .unwrap() + ); +} + +#[test] +fn test_regression_vesta_addition_affine() { + // This regression test is to ensure that the addition of two points in + // affine coordinates using the inplace operator `+` works correctly. 
+ // Commit: 0863042ca383c6578564e166724e9dbc66da19bf + let p1_x = Fp::from_str("1").unwrap(); + let p1_y = Fp::from_str( + "12418654782883325593414442427049395787963493412651469444558597405572177144507", + ) + .unwrap(); + let p1 = Pallas::new_unchecked(p1_x, p1_y); + + let p2_x = Fp::from_str( + "20444556541222657078399132219657928148671392403212669005631716460534733845831", + ) + .unwrap(); + let p2_y = Fp::from_str( + "12418654782883325593414442427049395787963493412651469444558597405572177144507", + ) + .unwrap(); + let p2 = Pallas::new_unchecked(p2_x, p2_y); + + // The type annotation ensures we have a point with affine coordinates, + // relying on implicit conversion if the addition outputs a point in a + // different coordinates set. + let p3: Pallas = (p1 + p2).into(); + + let expected_p3_x = BigUint::from_str( + "8503465768106391777493614032514048814691664078728891710322960303815233784505", + ) + .unwrap(); + let expected_p3_y = BigUint::from_str( + "16529367526445723262478303825122581175399563069290091271396079358777790485830", + ) + .unwrap(); + + let p3_x_biguint: BigUint = p3.x.into(); + let p3_y_biguint: BigUint = p3.y.into(); + assert_eq!(expected_p3_x, p3_x_biguint); + assert_eq!(expected_p3_y, p3_y_biguint); +} diff --git a/curves/src/pasta/fields/tests.rs b/curves/tests/pasta_fields.rs similarity index 70% rename from curves/src/pasta/fields/tests.rs rename to curves/tests/pasta_fields.rs index 0489cfc4cf..85a1f0aa6e 100644 --- a/curves/src/pasta/fields/tests.rs +++ b/curves/tests/pasta_fields.rs @@ -1,5 +1,5 @@ -use crate::pasta::fields::{Fp as Fr, Fq}; use ark_algebra_test_templates::*; +use mina_curves::pasta::fields::{Fp as Fr, Fq}; test_field!(fq; Fq; mont_prime_field); test_field!(fr; Fr; mont_prime_field); diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000..b08466f985 --- /dev/null +++ b/flake.nix @@ -0,0 +1,37 @@ +{ + description = "proof-system prerequisites"; + + inputs = { + naersk.url = "github:nix-community/naersk/master"; + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, utils, naersk }: + utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { inherit system; }; + naersk-lib = pkgs.callPackage naersk { }; + in + { + defaultPackage = naersk-lib.buildPackage ./.; + devShell = with pkgs; mkShell { + buildInputs = [ + # rust inputs + rustup + rust-analyzer + libiconv # Needed for macOS for rustup + + # ocaml inputs + opam + + # go inputs + gopls + go + + ]; + RUST_SRC_PATH = rustPlatform.rustLibSrc; + }; + } + ); +} \ No newline at end of file diff --git a/folding/Cargo.toml b/folding/Cargo.toml new file mode 100644 index 0000000000..240ec67bf5 --- /dev/null +++ b/folding/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "folding" +version = "0.1.0" +repository = "https://github.com/o1-labs/proof-systems" +homepage = "https://o1-labs.github.io/proof-systems/" +documentation = "https://o1-labs.github.io/proof-systems/rustdoc/" +edition = "2021" +license = "Apache-2.0" + +[dependencies] +ark-serialize.workspace = true +o1-utils.workspace = true +itertools.workspace = true +log.workspace = true +kimchi.workspace = true +poly-commitment.workspace = true +groupmap.workspace = true +mina-curves.workspace = true +mina-poseidon.workspace = true +num-bigint.workspace = true +num-integer.workspace = true +num-traits.workspace = true +rmp-serde.workspace = true +serde_json.workspace = true +serde.workspace = true +serde_with.workspace = true 
+strum.workspace = true
+strum_macros.workspace = true
+ark-poly.workspace = true
+ark-ff.workspace = true
+ark-ec.workspace = true
+rand.workspace = true
+rayon.workspace = true
+thiserror.workspace = true
+derivative = "2"
+
+[dev-dependencies]
+ark-bn254.workspace = true
\ No newline at end of file
diff --git a/folding/src/checker.rs b/folding/src/checker.rs
new file mode 100644
index 0000000000..fdbb24ebb4
--- /dev/null
+++ b/folding/src/checker.rs
@@ -0,0 +1,261 @@
+//! A kind of pseudo-prover: it evaluates the expressions over the witness and
+//! checks, row by row, that each one evaluates to zero.
+
+use crate::{
+    expressions::{FoldingColumnTrait, FoldingCompatibleExpr, FoldingCompatibleExprInner},
+    instance_witness::Instance,
+    ExpExtension, FoldingConfig, Radix2EvaluationDomain, RelaxedInstance, RelaxedWitness,
+};
+use ark_ec::AffineRepr;
+use ark_ff::{Field, Zero};
+use ark_poly::Evaluations;
+use kimchi::circuits::{expr::Variable, gate::CurrOrNext};
+use std::ops::Index;
+
+#[cfg(not(test))]
+use log::debug;
+#[cfg(test)]
+use std::println as debug;
+
+// 1. We continue by defining a generic type of columns and selectors.
+// The selectors can be seen as additional (public) columns that are not part of
+// the witness.
+// The column must implement the trait [Hash] as it will be used by internal
+// structures of the library.
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
+pub enum Column {
+    X(usize),
+    Selector(usize),
+}
+
+// 2. We implement the trait [FoldingColumnTrait] that allows us to distinguish
+// between the public and private inputs, often called the "instances" and the
+// "witnesses".
+// By default, we consider that the columns are all witness values and selectors
+// are public.
+impl FoldingColumnTrait for Column {
+    fn is_witness(&self) -> bool {
+        match self {
+            Column::X(_) => true,
+            Column::Selector(_) => false,
+        }
+    }
+}
+
+// 3. We define different traits that can be used generically by the folding
+// examples.
+// They can be used by "pseudo-provers".
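To see how a downstream circuit is expected to plug into this pattern, here is a minimal standalone sketch (not part of this diff) of a user-defined column type implementing the `FoldingColumnTrait` shown above; the `folding::expressions` path comes from this diff, while `MyColumn` and its variants are hypothetical names for illustration:

```rust
// Illustrative sketch only; assumes the `folding` crate added in this diff.
use folding::expressions::FoldingColumnTrait;

// A toy circuit with witness columns and one fixed (public) column.
// `MyColumn` is a hypothetical downstream type, not part of this diff.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum MyColumn {
    Witness(usize), // private values: degree 1, randomised during folding
    Fixed,          // public/fixed value: degree 0, never randomised
}

impl FoldingColumnTrait for MyColumn {
    fn is_witness(&self) -> bool {
        // Only witness columns take part in the randomisation; the trait's
        // default `degree()` derives Degree::One/Degree::Zero from this answer.
        matches!(self, MyColumn::Witness(_))
    }
}
```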
+ +pub struct Provider { + pub instance: C::Instance, + pub witness: C::Witness, +} + +impl Provider { + pub fn new(instance: C::Instance, witness: C::Witness) -> Self { + Self { instance, witness } + } +} + +pub struct ExtendedProvider { + pub inner_provider: Provider, + pub instance: RelaxedInstance<::Curve, ::Instance>, + pub witness: RelaxedWitness<::Curve, ::Witness>, +} + +impl ExtendedProvider { + pub fn new( + instance: RelaxedInstance, + witness: RelaxedWitness, + ) -> Self { + let inner_provider = { + let instance = instance.extended_instance.instance.clone(); + let witness = witness.extended_witness.witness.clone(); + Provider::new(instance, witness) + }; + Self { + inner_provider, + instance, + witness, + } + } +} + +pub trait Provide { + fn resolve( + &self, + inner: FoldingCompatibleExprInner, + domain: Radix2EvaluationDomain<::ScalarField>, + ) -> Vec<::ScalarField>; +} + +impl Provide for Provider +where + C::Witness: Index< + C::Column, + Output = Evaluations< + ::ScalarField, + Radix2EvaluationDomain<::ScalarField>, + >, + >, + C::Witness: Index< + C::Selector, + Output = Evaluations< + ::ScalarField, + Radix2EvaluationDomain<::ScalarField>, + >, + >, + C::Instance: Index::ScalarField>, +{ + fn resolve( + &self, + inner: FoldingCompatibleExprInner, + domain: Radix2EvaluationDomain<::ScalarField>, + ) -> Vec<::ScalarField> { + let domain_size = domain.size as usize; + match inner { + FoldingCompatibleExprInner::Constant(c) => { + vec![c; domain_size] + } + FoldingCompatibleExprInner::Challenge(chal) => { + let v = self.instance[chal]; + vec![v; domain_size] + } + FoldingCompatibleExprInner::Cell(var) => { + let Variable { col, row } = var; + + let col = &self.witness[col].evals; + + let mut col = col.clone(); + //check this, while not relevant in this case I think it should be right rotation + if let CurrOrNext::Next = row { + col.rotate_left(1); + } + col + } + FoldingCompatibleExprInner::Extensions(_) => { + panic!("not handled here"); + } + } + } +} + +impl Provide for ExtendedProvider +where + C::Witness: Index< + C::Column, + Output = Evaluations< + ::ScalarField, + Radix2EvaluationDomain<::ScalarField>, + >, + >, + C::Witness: Index< + C::Selector, + Output = Evaluations< + ::ScalarField, + Radix2EvaluationDomain<::ScalarField>, + >, + >, + C::Instance: Index::ScalarField>, +{ + fn resolve( + &self, + inner: FoldingCompatibleExprInner, + domain: Radix2EvaluationDomain<::ScalarField>, + ) -> Vec<::ScalarField> { + match inner { + FoldingCompatibleExprInner::Extensions(ext) => match ext { + ExpExtension::U => { + let u = self.instance.u; + let domain_size = domain.size as usize; + vec![u; domain_size] + } + ExpExtension::Error => self.witness.error_vec.evals.clone(), + ExpExtension::ExtendedWitness(i) => self + .witness + .extended_witness + .extended + .get(&i) + .unwrap() + .evals + .clone(), + ExpExtension::Alpha(i) => { + let alpha = self + .instance + .extended_instance + .instance + .get_alphas() + .get(i) + .unwrap(); + let domain_size = domain.size as usize; + vec![alpha; domain_size] + } + ExpExtension::Selector(s) => { + let col = &self.inner_provider.witness[s].evals; + col.clone() + } + }, + e => self.inner_provider.resolve(e, domain), + } + } +} + +pub trait Checker: Provide { + fn check_rec( + &self, + exp: FoldingCompatibleExpr, + domain: Radix2EvaluationDomain<::ScalarField>, + ) -> Vec<::ScalarField> { + let e2 = exp.clone(); + let res = match exp { + FoldingCompatibleExpr::Atom(inner) => self.resolve(inner, domain), + FoldingCompatibleExpr::Double(e) 
=> { + let v = self.check_rec(*e, domain); + v.into_iter().map(|x| x.double()).collect() + } + FoldingCompatibleExpr::Square(e) => { + let v = self.check_rec(*e, domain); + v.into_iter().map(|x| x.square()).collect() + } + FoldingCompatibleExpr::Add(e1, e2) => { + let v1 = self.check_rec(*e1, domain); + let v2 = self.check_rec(*e2, domain); + v1.into_iter().zip(v2).map(|(a, b)| a + b).collect() + } + FoldingCompatibleExpr::Sub(e1, e2) => { + let v1 = self.check_rec(*e1, domain); + let v2 = self.check_rec(*e2, domain); + v1.into_iter().zip(v2).map(|(a, b)| a - b).collect() + } + FoldingCompatibleExpr::Mul(e1, e2) => { + let v1 = self.check_rec(*e1, domain); + let v2 = self.check_rec(*e2, domain); + v1.into_iter().zip(v2).map(|(a, b)| a * b).collect() + } + FoldingCompatibleExpr::Pow(e, exp) => { + let v = self.check_rec(*e, domain); + v.into_iter().map(|x| x.pow([exp])).collect() + } + }; + debug!("exp: {:?}", e2); + debug!("res: [\n"); + for e in res.iter() { + debug!("{e}\n"); + } + debug!("]"); + res + } + + fn check( + &self, + exp: &FoldingCompatibleExpr, + domain: Radix2EvaluationDomain<::ScalarField>, + ) { + let res = self.check_rec(exp.clone(), domain); + for (i, row) in res.iter().enumerate() { + if !row.is_zero() { + panic!("check in row {i} failed, {row} != 0"); + } + } + } +} diff --git a/folding/src/columns.rs b/folding/src/columns.rs new file mode 100644 index 0000000000..fd08561785 --- /dev/null +++ b/folding/src/columns.rs @@ -0,0 +1,39 @@ +//! This module contains description of the additional columns used by our +//! folding scheme implementation. The columns are the base layer of the folding +//! scheme as they describe the basic expressiveness of the system. + +use crate::FoldingConfig; +use ark_ec::AffineRepr; +use derivative::Derivative; +use kimchi::circuits::expr::Variable; + +/// Describes the additional columns. It is parametrized by a configuration for +/// the folding scheme, described in the trait [FoldingConfig]. For instance, +/// the configuration describes the initial columns of the circuit, the +/// challenges and the underlying field. +#[derive(Derivative)] +#[derivative( + Hash(bound = "C: FoldingConfig"), + Debug(bound = "C: FoldingConfig"), + Clone(bound = "C: FoldingConfig"), + PartialEq(bound = "C: FoldingConfig"), + Eq(bound = "C: FoldingConfig") +)] +pub enum ExtendedFoldingColumn { + /// The variables of the initial circuit, without quadraticization and not + /// homogeonized. + Inner(Variable), + /// For the extra columns added by the module `quadraticization`. + WitnessExtended(usize), + /// The error term introduced in the "relaxed" instance. + Error, + /// A constant value in our expression + Constant(::ScalarField), + /// A challenge used by the PIOP or the folding scheme. + Challenge(C::Challenge), + /// A list of randomizer to combine expressions + Alpha(usize), + /// A "virtual" selector that can be used to activate/deactivate expressions + /// while folding/accumulating multiple expressions. + Selector(C::Selector), +} diff --git a/folding/src/decomposable_folding.rs b/folding/src/decomposable_folding.rs new file mode 100644 index 0000000000..6d52557c0f --- /dev/null +++ b/folding/src/decomposable_folding.rs @@ -0,0 +1,207 @@ +//! This variant of folding is designed to efficiently handle cases where +//! certain assumptions about the witness can be made. +//! Specifically, an alternative is provided such that the scheme is created +//! from a set of list of constraints, each set associated with a particular +//! 
selector, as opposed to a single list of constraints. + +use crate::{ + columns::ExtendedFoldingColumn, + error_term::{compute_error, ExtendedEnv}, + expressions::{ExpExtension, FoldingCompatibleExpr, FoldingCompatibleExprInner, FoldingExp}, + instance_witness::{RelaxableInstance, RelaxablePair, RelaxedInstance, RelaxedWitness}, + BaseField, FoldingConfig, FoldingOutput, FoldingScheme, ScalarField, +}; +use ark_poly::{Evaluations, Radix2EvaluationDomain}; +use mina_poseidon::FqSponge; +use poly_commitment::{PolyComm, SRS}; +use std::collections::BTreeMap; + +pub struct DecomposableFoldingScheme<'a, CF: FoldingConfig> { + inner: FoldingScheme<'a, CF>, +} + +impl<'a, CF: FoldingConfig> DecomposableFoldingScheme<'a, CF> { + /// Creates a new folding scheme for decomposable circuits. + /// It takes as input: + /// - a set of constraints, each associated with a particular selector; + /// - a list of common constraints, that are applied to every instance + /// regardless of the selector (can be empty); + /// - a structured reference string; + /// - a domain; + /// - a structure of the associated folding configuration. + /// The function uses the normal `FoldingScheme::new()` function to create + /// the decomposable scheme, using for that the concatenation of the + /// constraints associated with each selector multiplied by the selector, + /// and the common constraints. + /// This product is performed with + /// `FoldingCompatibleExprInner::Extensions(ExpExtension::Selector(s))`. + pub fn new( + // constraints with a dynamic selector + constraints: BTreeMap>>, + // constraints to be applied to every single instance regardless of selectors + common_constraints: Vec>, + srs: &'a CF::Srs, + domain: Radix2EvaluationDomain>, + structure: &CF::Structure, + ) -> (Self, FoldingCompatibleExpr) { + let constraints = constraints + .into_iter() + .flat_map(|(s, exps)| { + exps.into_iter().map(move |exp| { + let s = FoldingCompatibleExprInner::Extensions(ExpExtension::Selector(s)); + let s = Box::new(FoldingCompatibleExpr::Atom(s)); + FoldingCompatibleExpr::Mul(s, Box::new(exp)) + }) + }) + .chain(common_constraints) + .collect(); + let (inner, exp) = FoldingScheme::new(constraints, srs, domain, structure); + (DecomposableFoldingScheme { inner }, exp) + } + + /// Return the number of additional columns added by quadraticization + pub fn get_number_of_additional_columns(&self) -> usize { + self.inner.get_number_of_additional_columns() + } + + #[allow(clippy::type_complexity)] + /// folding with a selector will assume that only the selector in question + /// is enabled (i.e. set to 1) in all rows, and any other selector is 0 over + /// all rows. 
+ /// If that is not the case, providing `None` will fold without assumptions + pub fn fold_instance_witness_pair( + &self, + a: A, + b: B, + selector: Option, + fq_sponge: &mut Sponge, + ) -> FoldingOutput + where + A: RelaxablePair, + B: RelaxablePair, + Sponge: FqSponge, CF::Curve, ScalarField>, + { + let scheme = &self.inner; + let a = a.relax(&scheme.zero_vec); + let b = b.relax(&scheme.zero_vec); + + let u = (a.0.u, b.0.u); + + let (left_instance, left_witness) = a; + let (right_instance, right_witness) = b; + let env = ExtendedEnv::new( + &scheme.structure, + [left_instance, right_instance], + [left_witness, right_witness], + scheme.domain, + selector, + ); + + let env = env.compute_extension(&scheme.extended_witness_generator, scheme.srs); + let error = compute_error(&scheme.expression, &env, u); + let error_evals = error.map(|e| Evaluations::from_vec_and_domain(e, scheme.domain)); + + let error_commitments = error_evals + .iter() + .map(|e| scheme.srs.commit_evaluations_non_hiding(scheme.domain, e)) + .collect::>(); + let error_commitments: [PolyComm; 2] = error_commitments.try_into().unwrap(); + + let error = error_evals.into_iter().map(|e| e.evals).collect::>(); + let error: [Vec<_>; 2] = error.try_into().unwrap(); + + // sanity check to verify that we only have one commitment in polycomm + // (i.e. domain = poly size) + assert_eq!(error_commitments[0].len(), 1); + assert_eq!(error_commitments[1].len(), 1); + + let t0 = &error_commitments[0].get_first_chunk(); + let t1 = &error_commitments[1].get_first_chunk(); + + let to_absorb = env.to_absorb(t0, t1); + fq_sponge.absorb_fr(&to_absorb.0); + fq_sponge.absorb_g(&to_absorb.1); + + let challenge = fq_sponge.challenge(); + + let ( + [relaxed_extended_left_instance, relaxed_extended_right_instance], + [relaxed_extended_left_witness, relaxed_extended_right_witness], + ) = env.unwrap(); + + let folded_instance = RelaxedInstance::combine_and_sub_cross_terms( + // FIXME: remove clone + relaxed_extended_left_instance.clone(), + relaxed_extended_right_instance.clone(), + challenge, + &error_commitments, + ); + + let folded_witness = RelaxedWitness::combine_and_sub_cross_terms( + relaxed_extended_left_witness, + relaxed_extended_right_witness, + challenge, + error, + ); + FoldingOutput { + folded_instance, + folded_witness, + t_0: error_commitments[0].clone(), + t_1: error_commitments[1].clone(), + relaxed_extended_left_instance, + relaxed_extended_right_instance, + to_absorb, + } + } + + /// Fold two relaxable instances into a relaxed instance. + /// It is parametrized by two different types `A` and `B` that represent + /// "relaxable" instances to be able to fold a normal and "already relaxed" + /// instance. + pub fn fold_instance_pair( + &self, + a: A, + b: B, + error_commitments: [PolyComm; 2], + fq_sponge: &mut Sponge, + ) -> RelaxedInstance + where + A: RelaxableInstance, + B: RelaxableInstance, + Sponge: FqSponge, CF::Curve, ScalarField>, + { + let a: RelaxedInstance = a.relax(); + let b: RelaxedInstance = b.relax(); + + // sanity check to verify that we only have one commitment in polycomm + // (i.e. 
domain = poly size)
+        assert_eq!(error_commitments[0].len(), 1);
+        assert_eq!(error_commitments[1].len(), 1);
+
+        let to_absorb = {
+            let mut left = a.to_absorb();
+            let right = b.to_absorb();
+            left.0.extend(right.0);
+            left.1.extend(right.1);
+            left.1.extend([
+                error_commitments[0].get_first_chunk(),
+                error_commitments[1].get_first_chunk(),
+            ]);
+            left
+        };
+
+        fq_sponge.absorb_fr(&to_absorb.0);
+        fq_sponge.absorb_g(&to_absorb.1);
+
+        let challenge = fq_sponge.challenge();
+
+        RelaxedInstance::combine_and_sub_cross_terms(a, b, challenge, &error_commitments)
+    }
+}
+
+pub(crate) fn check_selector<C: FoldingConfig>(exp: &FoldingExp<C>) -> Option<&C::Selector> {
+    match exp {
+        FoldingExp::Atom(ExtendedFoldingColumn::Selector(s)) => Some(s),
+        _ => None,
+    }
+}
diff --git a/folding/src/error_term.rs b/folding/src/error_term.rs
new file mode 100644
index 0000000000..be5f163adb
--- /dev/null
+++ b/folding/src/error_term.rs
@@ -0,0 +1,467 @@
+//! This module contains the functions used to compute the error terms, as
+//! described in the [top-level documentation of the expressions
+//! module](crate::expressions).
+
+use crate::{
+    columns::ExtendedFoldingColumn,
+    decomposable_folding::check_selector,
+    eval_leaf::EvalLeaf,
+    expressions::{Degree, FoldingExp, IntegratedFoldingExpr, Sign},
+    quadraticization::ExtendedWitnessGenerator,
+    FoldingConfig, FoldingEnv, Instance, RelaxedInstance, RelaxedWitness, ScalarField,
+};
+use ark_ff::{Field, One, Zero};
+use ark_poly::{Evaluations, Radix2EvaluationDomain};
+use kimchi::circuits::expr::Variable;
+use poly_commitment::{PolyComm, SRS};
+
+// FIXME: for optimisation: as the values are not necessarily Fp elements and
+// are relatively small, we could get rid of the scalar field objects and use
+// bigints directly, applying the modulus only when needed.
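Before the implementation, a toy numerical check of the identity this module computes may help. This is a self-contained sketch over a toy prime, not the crate's field types: for a degree-2 monomial `P(a, b) = a * b`, folding witnesses `(a, b)` and `(a', b')` with a challenge `r` gives `P(a + r a', b + r b') = P(a, b) + r (a b' + a' b) + r^2 P(a', b')`, and the middle term `a b' + a' b` is the per-row cross term that the error computation below accumulates (weighted, in the real code, by the `α` combiners and powers of `u`):

```rust
// Standalone sketch using plain u128 arithmetic modulo a toy prime instead of
// a real scalar field.
const P: u128 = 101; // toy prime

// Per-row cross term of the monomial a * b for two witness rows.
fn cross_term_row(a: u128, b: u128, a2: u128, b2: u128) -> u128 {
    (a * b2 + a2 * b) % P
}

fn main() {
    let (a, b) = (3, 5); // left witness row
    let (a2, b2) = (7, 11); // right witness row
    let r = 13; // folding challenge

    // P evaluated on the folded row...
    let folded = ((a + r * a2) % P) * ((b + r * b2) % P) % P;
    // ...equals P(left) + r * cross_term + r^2 * P(right).
    let expected =
        (a * b + r * cross_term_row(a, b, a2, b2) + r * r % P * (a2 * b2 % P)) % P;
    assert_eq!(folded, expected);
    println!("cross term per row: {}", cross_term_row(a, b, a2, b2));
}
```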
+ +/// This type refers to the two instances to be folded +#[derive(Clone, Copy)] +pub enum Side { + Left = 0, + Right = 1, +} + +impl Side { + pub fn other(self) -> Self { + match self { + Side::Left => Side::Right, + Side::Right => Side::Left, + } + } +} + +/// Evaluates the expression in the provided side +pub(crate) fn eval_sided<'a, C: FoldingConfig>( + exp: &FoldingExp, + env: &'a ExtendedEnv, + side: Side, +) -> EvalLeaf<'a, ScalarField> { + use FoldingExp::*; + + match exp { + Atom(col) => env.col(col, side), + Double(e) => { + let col = eval_sided(e, env, side); + col.map(Field::double, |f| { + Field::double_in_place(f); + }) + } + Square(e) => { + let col = eval_sided(e, env, side); + col.map(Field::square, |f| { + Field::square_in_place(f); + }) + } + Add(e1, e2) => eval_sided(e1, env, side) + eval_sided(e2, env, side), + Sub(e1, e2) => eval_sided(e1, env, side) - eval_sided(e2, env, side), + Mul(e1, e2) => { + //this assumes to some degree that selectors don't multiply each other + let selector = check_selector(e1) + .or(check_selector(e2)) + .zip(env.enabled_selector()) + .map(|(s1, s2)| s1 == s2); + match selector { + Some(false) => { + let zero_vec = vec![ScalarField::::zero(); env.domain.size as usize]; + EvalLeaf::Result(zero_vec) + } + Some(true) | None => { + let d1 = e1.folding_degree(); + let d2 = e2.folding_degree(); + let e1 = match d1 { + Degree::Two => eval_sided(e1, env, side), + _ => eval_exp_error(e1, env, side), + }; + let e2 = match d2 { + Degree::Two => eval_sided(e2, env, side), + _ => eval_exp_error(e2, env, side), + }; + e1 * e2 + } + } + } + Pow(e, i) => match i { + 0 => EvalLeaf::Const(ScalarField::::one()), + 1 => eval_sided(e, env, side), + i => { + let err = eval_sided(e, env, side); + let mut acc = err.clone(); + for _ in 1..*i { + acc = acc * err.clone() + } + acc + } + }, + } +} + +pub(crate) fn eval_exp_error<'a, C: FoldingConfig>( + exp: &FoldingExp, + env: &'a ExtendedEnv, + side: Side, +) -> EvalLeaf<'a, ScalarField> { + use FoldingExp::*; + + match exp { + Atom(col) => env.col(col, side), + Double(e) => { + let col = eval_exp_error(e, env, side); + col.map(Field::double, |f| { + Field::double_in_place(f); + }) + } + Square(e) => match exp.folding_degree() { + Degree::Two => { + let cross = eval_exp_error(e, env, side) * eval_exp_error(e, env, side.other()); + cross.map(Field::double, |f| { + Field::double_in_place(f); + }) + } + _ => { + let e = eval_exp_error(e, env, side); + e.map(Field::square, |f| { + Field::square_in_place(f); + }) + } + }, + Add(e1, e2) => eval_exp_error(e1, env, side) + eval_exp_error(e2, env, side), + Sub(e1, e2) => eval_exp_error(e1, env, side) - eval_exp_error(e2, env, side), + Mul(e1, e2) => { + //this assumes to some degree that selectors don't multiply each other + let selector = check_selector(e1) + .or(check_selector(e2)) + .zip(env.enabled_selector()) + .map(|(s1, s2)| s1 == s2); + match selector { + Some(false) => { + let zero_vec = vec![ScalarField::::zero(); env.domain.size as usize]; + EvalLeaf::Result(zero_vec) + } + Some(true) | None => match (exp.folding_degree(), e1.folding_degree()) { + (Degree::Two, Degree::One) => { + let first = + eval_exp_error(e1, env, side) * eval_exp_error(e2, env, side.other()); + let second = + eval_exp_error(e1, env, side.other()) * eval_exp_error(e2, env, side); + first + second + } + _ => eval_exp_error(e1, env, side) * eval_exp_error(e2, env, side), + }, + } + } + Pow(_, 0) => EvalLeaf::Const(ScalarField::::one()), + Pow(e, 1) => eval_exp_error(e, env, side), + Pow(e, 2) 
=> match (exp.folding_degree(), e.folding_degree()) { + (Degree::Two, Degree::One) => { + let first = eval_exp_error(e, env, side) * eval_exp_error(e, env, side.other()); + let second = eval_exp_error(e, env, side.other()) * eval_exp_error(e, env, side); + first + second + } + _ => { + let err = eval_exp_error(e, env, side); + err.clone() * err + } + }, + Pow(e, i) => match exp.folding_degree() { + Degree::Zero => { + let e = eval_exp_error(e, env, side); + // TODO: Implement `pow` here for efficiency + let mut acc = e.clone(); + for _ in 1..*i { + acc = acc * e.clone(); + } + acc + } + _ => panic!("degree over 2"), + }, + } +} + +/// Computes the error terms of a folding/homogeneous expression. +/// The extended environment contains all the evaluations of the columns, +/// including the ones added by the quadraticization process. +/// `u` is the variables used to homogeneize the expression. +/// The output is a pair of error terms. To see how it is computed, see the +/// [top-level documentation of the expressions module](crate::expressions). +pub(crate) fn compute_error( + exp: &IntegratedFoldingExpr, + env: &ExtendedEnv, + u: (ScalarField, ScalarField), +) -> [Vec>; 2] { + // FIXME: for speed, use inplace operations, and avoid cloning and + // allocating a new element. + // An allocation can cost a third of the time required for an addition and a + // 9th for a multiplication on the scalar field + // Indirections are also costly, so we should avoid them as much as + // possible, and inline code. + let (ul, ur) = (u.0, u.1); + let u_cross = ul * ur; + let zero_vec = vec![ScalarField::::zero(); env.domain.size as usize]; + let zero = || EvalLeaf::Result(zero_vec.clone()); + + let alphas_l = env + .get_relaxed_instance(Side::Left) + .extended_instance + .instance + .get_alphas(); + let alphas_r = env + .get_relaxed_instance(Side::Right) + .extended_instance + .instance + .get_alphas(); + + let t_0 = { + let t_0 = (zero(), zero()); + let (l, r) = exp.degree_0.iter().fold(t_0, |(l, r), (exp, sign, alpha)| { + //could be left or right, doesn't matter for constant terms + let exp = eval_exp_error(exp, env, Side::Left); + let alpha_l = alphas_l.get(*alpha).expect("alpha not present"); + let alpha_r = alphas_r.get(*alpha).expect("alpha not present"); + let left = exp.clone() * alpha_l; + let right = exp * alpha_r; + match sign { + Sign::Pos => (l + left, r + right), + Sign::Neg => (l - left, r - right), + } + }); + let cross2 = u_cross.double(); + let e0 = l.clone() * cross2 + r.clone() * ul.square(); + let e1 = r * cross2 + l * ur.square(); + (e0, e1) + }; + + let t_1 = { + let t_1 = (zero(), zero(), zero()); + let (l, cross, r) = exp + .degree_1 + .iter() + .fold(t_1, |(l, cross, r), (exp, sign, alpha)| { + let expl = eval_exp_error(exp, env, Side::Left); + let expr = eval_exp_error(exp, env, Side::Right); + let alpha_l = alphas_l.get(*alpha).expect("alpha not present"); + let alpha_r = alphas_r.get(*alpha).expect("alpha not present"); + let expr_cross = expl.clone() * alpha_r + expr.clone() * alpha_l; + let left = expl * alpha_l; + let right = expr * alpha_r; + match sign { + Sign::Pos => (l + left, cross + expr_cross, r + right), + Sign::Neg => (l - left, cross - expr_cross, r - right), + } + }); + let e0 = cross.clone() * ul + l * ur; + let e1 = cross.clone() * ur + r * ul; + (e0, e1) + }; + let t_2 = (zero(), zero()); + let t_2 = exp.degree_2.iter().fold(t_2, |(l, r), (exp, sign, alpha)| { + let expl = eval_sided(exp, env, Side::Left); + let expr = eval_sided(exp, env, Side::Right); + 
//left or right matter in some way, but not at the top level call + let cross = eval_exp_error(exp, env, Side::Left); + let alpha_l = alphas_l.get(*alpha).expect("alpha not present"); + let alpha_r = alphas_r.get(*alpha).expect("alpha not present"); + let left = expl * alpha_r + cross.clone() * alpha_l; + let right = expr * alpha_l + cross * alpha_r; + match sign { + Sign::Pos => (l + left, r + right), + Sign::Neg => (l - left, r - right), + } + }); + let t = [t_1, t_2] + .into_iter() + .fold(t_0, |(tl, tr), (txl, txr)| (tl + txl, tr + txr)); + + match t { + (EvalLeaf::Result(l), EvalLeaf::Result(r)) => [l, r], + _ => unreachable!(), + } +} + +/// An extended environment contains the evaluations of all the columns, including +/// the ones added by the quadraticization process. It also contains the +/// the two instances and witnesses that are being folded. +/// The domain is required to define the polynomial size of the evaluations of +/// the error terms. +pub(crate) struct ExtendedEnv { + inner: CF::Env, + instances: [RelaxedInstance; 2], + witnesses: [RelaxedWitness; 2], + domain: Radix2EvaluationDomain>, + selector: Option, +} + +impl ExtendedEnv { + pub fn new( + structure: &CF::Structure, + // maybe better to have some structure extended or something like that + instances: [RelaxedInstance; 2], + witnesses: [RelaxedWitness; 2], + domain: Radix2EvaluationDomain>, + selector: Option, + ) -> Self { + let inner_instances = [ + &instances[0].extended_instance.instance, + &instances[1].extended_instance.instance, + ]; + let inner_witnesses = [ + &witnesses[0].extended_witness.witness, + &witnesses[1].extended_witness.witness, + ]; + let inner = ::new(structure, inner_instances, inner_witnesses); + Self { + inner, + instances, + witnesses, + domain, + selector, + } + } + + pub fn enabled_selector(&self) -> Option<&CF::Selector> { + self.selector.as_ref() + } + + #[allow(clippy::type_complexity)] + pub fn unwrap( + self, + ) -> ( + [RelaxedInstance; 2], + [RelaxedWitness; 2], + ) { + let Self { + instances, + witnesses, + .. 
+        } = self;
+        (instances, witnesses)
+    }
+
+    pub fn get_relaxed_instance(&self, side: Side) -> &RelaxedInstance<CF::Curve, CF::Instance> {
+        &self.instances[side as usize]
+    }
+
+    pub fn get_relaxed_witness(&self, side: Side) -> &RelaxedWitness<CF::Curve, CF::Witness> {
+        &self.witnesses[side as usize]
+    }
+
+    pub fn col(&self, col: &ExtendedFoldingColumn<CF>, side: Side) -> EvalLeaf<ScalarField<CF>> {
+        use EvalLeaf::Col;
+        use ExtendedFoldingColumn::*;
+        let relaxed_instance = self.get_relaxed_instance(side);
+        let relaxed_witness = self.get_relaxed_witness(side);
+        let alphas = relaxed_instance.extended_instance.instance.get_alphas();
+        match col {
+            Inner(Variable { col, row }) => Col(self.inner.col(*col, *row, side)),
+            WitnessExtended(i) => Col(&relaxed_witness
+                .extended_witness
+                .extended
+                .get(i)
+                .expect("extended column not present")
+                .evals),
+            Error => panic!("shouldn't happen"),
+            Constant(c) => EvalLeaf::Const(*c),
+            Challenge(chall) => EvalLeaf::Const(self.inner.challenge(*chall, side)),
+            Alpha(i) => {
+                let alpha = alphas.get(*i).expect("alpha not present");
+                EvalLeaf::Const(alpha)
+            }
+            Selector(s) => Col(self.inner.selector(s, side)),
+        }
+    }
+
+    pub fn col_try(&self, col: &ExtendedFoldingColumn<CF>, side: Side) -> bool {
+        use ExtendedFoldingColumn::*;
+        let relaxed_witness = self.get_relaxed_witness(side);
+        match col {
+            WitnessExtended(i) => relaxed_witness.extended_witness.extended.get(i).is_some(),
+            Error => panic!("shouldn't happen"),
+            Inner(_) | Constant(_) | Challenge(_) | Alpha(_) | Selector(_) => true,
+        }
+    }
+
+    pub fn add_witness_evals(&mut self, i: usize, evals: Vec<ScalarField<CF>>, side: Side) {
+        let (_instance, relaxed_witness) = match side {
+            Side::Left => (&self.instances[0], &mut self.witnesses[0]),
+            Side::Right => (&self.instances[1], &mut self.witnesses[1]),
+        };
+        let evals = Evaluations::from_vec_and_domain(evals, self.domain);
+        relaxed_witness.extended_witness.add_witness_evals(i, evals);
+    }
+
+    pub fn needs_extension(&self, side: Side) -> bool {
+        !match side {
+            Side::Left => self.witnesses[0].extended_witness.is_extended(),
+            Side::Right => self.witnesses[1].extended_witness.is_extended(),
+        }
+    }
+
+    /// Computes the extended witness columns and the corresponding commitments,
+    /// updating the inner instance/witness pairs.
+    pub fn compute_extension(
+        self,
+        witness_generator: &ExtendedWitnessGenerator<CF>,
+        srs: &CF::Srs,
+    ) -> Self {
+        let env = self;
+        let env = witness_generator.compute_extended_witness(env, Side::Left);
+        let env = witness_generator.compute_extended_witness(env, Side::Right);
+        let env = env.compute_extended_commitments(srs, Side::Left);
+        env.compute_extended_commitments(srs, Side::Right)
+    }
+
+    // FIXME: use references to avoid indirect copying/cloning.
+    /// Computes the commitments of the columns added by quadraticization, for
+    /// the given side.
+    /// The commitments are added to the instance, in the same order for both
+    /// sides.
+    /// Note that this function is only going to be called on the left instance
+    /// once. When we fold the second time, the left instance will already be
+    /// relaxed and will have the extended columns.
+    /// Therefore, the blinder is always the one provided by the user, and it is
+    /// saved in the field `blinder` in the case of a relaxed instance that has
+    /// been built from a non-relaxed one.
+ fn compute_extended_commitments(mut self, srs: &CF::Srs, side: Side) -> Self { + let (relaxed_instance, relaxed_witness) = match side { + Side::Left => (&mut self.instances[0], &self.witnesses[0]), + Side::Right => (&mut self.instances[1], &self.witnesses[1]), + }; + + // FIXME: use parallelisation + let blinder = PolyComm::new(vec![relaxed_instance.blinder]); + for (expected_i, (i, wit)) in relaxed_witness.extended_witness.extended.iter().enumerate() { + // in case any where to be missing for some reason + assert_eq!(*i, expected_i); + // Blinding the commitments to support the case the witness is zero. + // The IVC circuit expects to have non-zero commitments. + let commit = srs + .commit_evaluations_custom(self.domain, wit, &blinder) + .unwrap() + .commitment; + relaxed_instance.extended_instance.extended.push(commit) + } + // FIXME: maybe returning a value is not necessary as it does inplace operations. + // It implies copying on the stack and possibly copy multiple times. + self + } + + /// Return the list of scalars and commitments to be absorbed, by + /// concatenating the ones of the left with the ones of the right instance + pub(crate) fn to_absorb( + &self, + t0: &CF::Curve, + t1: &CF::Curve, + ) -> (Vec>, Vec) { + let mut left = self.instances[0].to_absorb(); + let right = self.instances[1].to_absorb(); + + left.0.extend(right.0); + left.1.extend(right.1); + left.1.extend([t0, t1]); + left + } +} diff --git a/folding/src/eval_leaf.rs b/folding/src/eval_leaf.rs new file mode 100644 index 0000000000..691d3fbc54 --- /dev/null +++ b/folding/src/eval_leaf.rs @@ -0,0 +1,140 @@ +#[derive(Clone, Debug)] +/// Result of a folding expression evaluation. +pub enum EvalLeaf<'a, F> { + Const(F), + Col(&'a [F]), // slice will suffice + Result(Vec), +} + +impl<'a, F: std::fmt::Display> std::fmt::Display for EvalLeaf<'a, F> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let slice = match self { + EvalLeaf::Const(c) => { + write!(f, "Const: {}", c)?; + return Ok(()); + } + EvalLeaf::Col(a) => a, + EvalLeaf::Result(a) => a.as_slice(), + }; + writeln!(f, "[")?; + for e in slice.iter() { + writeln!(f, "{e}")?; + } + write!(f, "]")?; + Ok(()) + } +} + +impl<'a, F: std::ops::Add + Clone> std::ops::Add for EvalLeaf<'a, F> { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self::bin_op(|a, b| a + b, self, rhs) + } +} + +impl<'a, F: std::ops::Sub + Clone> std::ops::Sub for EvalLeaf<'a, F> { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Self::bin_op(|a, b| a - b, self, rhs) + } +} + +impl<'a, F: std::ops::Mul + Clone> std::ops::Mul for EvalLeaf<'a, F> { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + Self::bin_op(|a, b| a * b, self, rhs) + } +} + +impl<'a, F: std::ops::Mul + Clone> std::ops::Mul for EvalLeaf<'a, F> { + type Output = Self; + + fn mul(self, rhs: F) -> Self { + self * Self::Const(rhs) + } +} + +impl<'a, F: Clone> EvalLeaf<'a, F> { + pub fn map F, I: Fn(&mut F)>(self, map: M, in_place: I) -> Self { + use EvalLeaf::*; + match self { + Const(c) => Const(map(&c)), + Col(col) => { + let res = col.iter().map(map).collect(); + Result(res) + } + Result(mut col) => { + for cell in col.iter_mut() { + in_place(cell); + } + Result(col) + } + } + } + + fn bin_op F>(f: M, a: Self, b: Self) -> Self { + use EvalLeaf::*; + match (a, b) { + (Const(a), Const(b)) => Const(f(a, b)), + (Const(a), Col(b)) => { + let res = b.iter().map(|b| f(a.clone(), b.clone())).collect(); + Result(res) + } + (Col(a), Const(b)) => { + let res = 
a.iter().map(|a| f(a.clone(), b.clone())).collect(); + Result(res) + } + (Col(a), Col(b)) => { + let res = (a.iter()) + .zip(b.iter()) + .map(|(a, b)| f(a.clone(), b.clone())) + .collect(); + Result(res) + } + (Result(mut a), Const(b)) => { + for a in a.iter_mut() { + *a = f(a.clone(), b.clone()) + } + Result(a) + } + (Const(a), Result(mut b)) => { + for b in b.iter_mut() { + *b = f(a.clone(), b.clone()) + } + Result(b) + } + (Result(mut a), Col(b)) => { + for (a, b) in a.iter_mut().zip(b.iter()) { + *a = f(a.clone(), b.clone()) + } + Result(a) + } + (Col(a), Result(mut b)) => { + for (a, b) in a.iter().zip(b.iter_mut()) { + *b = f(a.clone(), b.clone()) + } + Result(b) + } + (Result(mut a), Result(b)) => { + for (a, b) in a.iter_mut().zip(b.into_iter()) { + *a = f(a.clone(), b) + } + Result(a) + } + } + } + + pub fn unwrap(self) -> Vec + where + F: Clone, + { + match self { + EvalLeaf::Col(res) => res.to_vec(), + EvalLeaf::Result(res) => res, + EvalLeaf::Const(_) => panic!("Attempted to unwrap a constant"), + } + } +} diff --git a/folding/src/expressions.rs b/folding/src/expressions.rs new file mode 100644 index 0000000000..03ed1a840e --- /dev/null +++ b/folding/src/expressions.rs @@ -0,0 +1,1155 @@ +//! Implement a library to represent expressions/multivariate polynomials that +//! can be used with folding schemes like +//! [Nova](https://eprint.iacr.org/2021/370). +//! +//! We do enforce expressions to be degree `2` maximum to apply our folding +//! scheme. +//! +//! Before folding, we do suppose that each expression has been reduced to +//! degree `2` using [crate::quadraticization]. +//! +//! The library introduces different types of expressions: +//! - [FoldingCompatibleExpr]: an expression that can be used with folding. It +//! aims to be an intermediate representation from +//! [kimchi::circuits::expr::Expr]. It can be printed in a human-readable way +//! using the trait [ToString]. +//! - [FoldingExp]: an internal representation of a folded expression. +//! - [IntegratedFoldingExpr]: a simplified expression with all terms separated +//! +//! When using the library, the user should: +//! - Convert an expression from [kimchi::circuits::expr::Expr] into a +//! [FoldingCompatibleExpr] using the trait [From]. +//! - Convert a list of [FoldingCompatibleExpr] into a [IntegratedFoldingExpr] +//! using the function [folding_expression]. +//! +//! The user can also choose to build a structure [crate::FoldingScheme] from a +//! list of [FoldingCompatibleExpr]. +//! +//! As a reminder, after we reduce to degree 2, the multivariate polynomial +//! `P(X_{1}, ..., X_{n})` describing the NP relation will be +//! "relaxed" in another polynomial of the form `P_relaxed(X_{1}, ..., X_{n}, u)`. +//! First, we decompose the polynomial `P` in its monomials of degree `0`, `1` and `2`: +//! ```text +//! P(X_{1}, ..., X_{n}) = ∑_{i} f_{i, 0}(X_{1}, ..., X_{n}) + +//! ∑_{i} f_{i, 1}(X_{1}, ..., X_{n}) + +//! ∑_{i} f_{i, 2}(X_{1}, ..., X_{n}) +//! ``` +//! where `f_{i, 0}` is a monomial of degree `0`, `f_{i, 1}` is a monomial of degree +//! `1` and `f_{i, 2}` is a monomial of degree `2`. +//! For instance, for the polynomial `P(X_{1}, X_{2}, X_{3}) = X_{1} * X_{2} + +//! (1 - X_{3})`, we have: +//! ```text +//! f_{0, 0}(X_{1}, X_{2}, X_{3}) = 1 +//! f_{0, 1}(X_{1}, X_{2}, X_{3}) = -X_{3} +//! f_{0, 2}(X_{1}, X_{2}, X_{3}) = X_{1} * X_{2} +//! ``` +//! Then, we can relax the polynomial `P` in `P_relaxed` by adding a new +//! variable `u` in the following way: +//! - For the monomials `f_{i, 0}`, i.e. 
the monomials of degree `0`, we add `u^2` +//! to the expression. +//! - For the monomials `f_{i, 1}`, we add `u` to the expression. +//! - For the monomials `f_{i, 2}`, we keep the expression as is. +//! +//! For the polynomial `P(X_{1}, X_{2}, X_{3}) = X_{1} * X_{2} + (1 - X_{3})`, we have: +//! ```text +//! P_relaxed(X_{1}, X_{2}, X_{3}, u) = X_{1} * X_{2} + u (u - X_{3}) +//! ``` +//! +//! From the relaxed form of the polynomial, we can "fold" multiple instances of +//! the NP relation by randomising it into a single instance by adding an error +//! term `E`. +//! For instance, for the polynomial `P_relaxed(X_{1}, X_{2}, X_{3}, u) = X_{1} * +//! X_{2} + u (u - X_{3})`, +//! for two instances `(X_{1}, X_{2}, X_{3}, u)` and `(X_{1}', X_{2}', X_{3}', +//! u')`, we can fold them into a single instance by coining a random value `r`: +//! ```text +//! X''_{1} = X_{1} + r X_{1}' +//! X''_{2} = X_{2} + r X_{2}' +//! X''_{3} = X_{3} + r X_{3}' +//! u'' = u + r u' +//! ``` +//! Computing the polynomial `P_relaxed(X''_{1}, X''_{2}, X''_{3}, u'')` will +//! give: +//! ```text +//! (X_{1} + r X'_{1}) (X_{2} + r X'_{2}) \ +//! + (u + r u') [(u + r u') - (X_{3} + r X'_{3})] +//! ``` +//! which can be simplified into: +//! ```text +//! P_relaxed(X_{1}, X_{2}, X_{3}, u) + P_relaxed(r X_{1}', r X_{2}', r X_{3}', r u') +//! + r [u (u' - X_{3}) + u' (u - X_{3})] + r [X_{1} X_{2}' + X_{2} X_{1}'] +//! \---------------------------------/ \----------------------------------/ +//! cross terms of monomials of degree 1 cross terms of monomials of degree 2 +//! and degree 0 +//! ``` +//! The error term `T` (or "cross term") is the last term of the expression, +//! multiplied by `r`. +//! More generally, the error term is the sum of all monomials introduced by +//! the "cross terms" of the instances. For example, if there is a monomial of +//! degree 2 like `X_{1} * X_{2}`, it introduces the cross terms +//! `r X_{1} X_{2}' + r X_{2} X_{1}'`. For a monomial of degree 1, for example +//! `u X_{1}`, it introduces the cross terms `r u X_{1}' + r u' X_{1}`. +//! +//! Note that: +//! ```text +//! P_relaxed(r X_{1}', r X_{2}', r X_{3}', r u') +//! = r^2 P_relaxed(X_{1}', X_{2}', X_{3}', u') +//! ``` +//! and `P_relaxed` is of degree `2`. More +//! precisely, `P_relaxed` is homogenous. And that is the main idea of folding: +//! the "relaxation" of a polynomial means we make it homogenous for a certain +//! degree `d` by introducing the new variable `u`, and introduce the concept of +//! "error terms" that will englobe the "cross-terms". The prover takes care of +//! computing the cross-terms and commit to them. +//! +//! While folding, we aggregate the error terms of all instances into a single +//! error term, E. +//! In our example, if we have a folded instance with the non-zero +//! error terms `E_{1}` and `E_{2}`, we have: +//! ```text +//! E = E_{1} + r T + E_{2} +//! ``` +//! +//! ## Aggregating constraints +//! +//! The library also provides a way to fold NP relations described by a list of +//! multi-variate polynomials, like we usually have in a zkSNARK circuit. +//! +//! In PlonK, we aggregate all the polynomials into a single polynomial by +//! coining a random value `α`. For instance, if we have two polynomials `P` and +//! `Q` describing our computation in a zkSNARK circuit, we usually use the +//! randomized polynomial `P + α Q` (used to build the quotient polynomial in +//! PlonK). +//! +//! More generally, if for each row, our computation is constrained by the polynomial +//! 
list `[P_{1}, P_{2}, ..., P_{n}]`, we can aggregate them into a single +//! polynomial `P_{agg} = ∑_{i} α^{i} P_{i}`. Multiplying by the α terms +//! consequently increases the overall degree of the expression. +//! +//! In particular, when we reduce a polynomial to degree 2, we have this case +//! where the circuit is described by a list of polynomials and we aggregate +//! them into a single polynomial. +//! +//! For instance, if we have two polynomials `P(X_{1}, X_{2}, X_{3})` and +//! `Q(X_{1}, X_{2}, X_{3})` such that: +//! ```text +//! P(X_{1}, X_{2}, X_{3}) = X_{1} * X_{2} + (1 - X_{3}) +//! Q(X_{1}, X_{2}, X_{3}) = X_{1} + X_{2} +//! ``` +//! +//! The relaxed form of the polynomials are: +//! ```text +//! P_relaxed(X_{1}, X_{2}, X_{3}, u) = X_{1} * X_{2} + u (u - X_{3}) +//! Q_relaxed(X_{1}, X_{2}, X_{3}, u) = u X_{1} + u X_{2} +//! ``` +//! +//! We start by coining `α_{1}` and `α_{2}` and we compute the polynomial +//! `P'(X_{1}, X_{2}, X_{3}, u, α_{1})` and `Q'(X_{1}, X_{2}, X_{3}, α_{2})` such that: +//! ```text +//! P'(X_{1}, X_{2}, X_{3}, u, α_{1}) = α_{1} P_relaxed(X_{1}, X_{2}, X_{3}, u) +//! = α_{1} (X_{1} * X_{2} + u (u - X_{3})) +//! = α_{1} X_{1} * X_{2} + α_{1} u^2 - α_{1} u X_{3} +//! Q'(X_{1}, X_{2}, X_{3}, u, α_{2}) = α_{2} Q_relaxed(X_{1}, X_{2}, X_{3}, u) +//! = α_{2} (u X_{1} + u X_{2}) +//! = α_{2} u X_{1} + α_{2} u X_{2} +//! ``` +//! and we want to fold the multivariate polynomial S defined over six +//! variables: +//! ```text +//! S(X_{1}, X_{2}, X_{3}, u, α_{1}, α_{2}) +//! = P'(X_{1}, X_{2}, X_{3}, u, α_{1}) + Q'(X_{1}, X_{2}, X_{3}, u, α_{2})`. +//! = α_{1} X_{1} X_{2} + +//! α_{1} u^2 - +//! α_{1} u X_{3} + +//! α_{2} u X_{1} + +//! α_{2} u X_{2} +//! ``` +//! +//! Note that we end up with everything of the same degree, which is `3` in this +//! case. The variables `α_{1}` and `α_{2}` increase the degree of the +//! homogeneous expressions by one. +//! +//! For two given instances `(X_{1}, X_{2}, X_{3}, u, α_{1}, α_{2})` and +//! `(X_{1}', X_{2}', X_{3}', u', α_{1}', α_{2}')`, we coin a random value `r` and we compute: +//! ```text +//! X''_{1} = X_{1} + r X'_{1} +//! X''_{2} = X_{2} + r X'_{2} +//! X''_{3} = X_{3} + r X'_{3} +//! u'' = u + r u' +//! α''_{1} = α_{1} + r α'_{1} +//! α''_{2} = α_{2} + r α'_{2} +//! ``` +//! +//! From there, we compute the evaluations of the polynomial S at the point +//! `S(X''_{1}, X''_{2}, X''_{3}, u'', α''_{1}, α''_{2})`, which gives: +//! ```text +//! S(X_{1}, X_{2}, X_{3}, u, α_{1}, α_{2}) +//! + S(r X'_{1}, r X'_{2}, r X'_{3}, r u', r α'_{1}, r α'_{2}) +//! + r T_{0} +//! + r^2 T_{1} +//! ``` +//! where `T_{0}` (respectively `T_{1}`) are cross terms that are multiplied by +//! `r` (respectively `r^2`). More precisely, for `T_{0}` we have: +//! ```text +//! T_{0} = a_{1} X_{1} X'{2} + +//! X_{2} (α_{1} X'_{1} + α'_{1} X_{1}) + +//! // we repeat for a_{1} u^{2}, ... as described below +//! ``` +//! We must see each monomial as a polynomial P(X, Y, Z) of degree 3, and the +//! cross-term for each monomial will be, for (X', Y', Z') and (X, Y, Z): +//! ```text +//! X Y Z' + Z (X Y' + X' Y) +//! ``` +//! +//! As for the degree`2` case described before, we notice that the polynomial S +//! is homogeneous of degree 3, i.e. +//! ```text +//! S(r X'_{1}, r X'_{2}, r X'_{3}, r u', r α'_{1}, r α'_{2}) +//! = r^3 S(X'_{1}, X'_{2}, X'_{3}, u', α'_{1}, α'_{2}) +//! ``` +//! +//! ## Fiat-Shamir challenges, interactive protocols and lookup arguments +//! +//! 
Until now, we have described a way to fold multi-variate polynomials, which +//! is mostly a generalization of [Nova](https://eprint.iacr.org/2021/370) for +//! any multi-variate polynomial. +//! However, we did not describe how it can be used to describe and fold +//! interactive protocols based on polynomials, like PlonK. We do suppose the +//! interactive protocol can be made non-interactive by using the Fiat-Shamir +//! transformation. +//! +//! To fold interactive protocols, our folding scheme must also support +//! Fiat-Shamir challenges. This implementation handles this by representing +//! challenges as new variables in the polynomial describing the NP relation. +//! The challenges are then aggregated in the same way as the other variables. +//! +//! For instance, let's consider the additive +//! lookup/logup argument. For a detailed description of the protocol, see [the +//! online +//! documentation](https://o1-labs.github.io/proof-systems/rustdoc/kimchi_msm/logup/index.html). +//! We will suppose we have only one table `T` and Alice wants to prove to Bob +//! that she knows that all evaluations of `f(X)` is in `t(X)`. The additive +//! lookup argument is described by the polynomial equation: +//! ```text +//! β + f(x) = m(x) (β + t(x)) +//! ``` +//! where β is the challenge, `f(x)` is the polynomial whose evaluations describe +//! the value Alice wants to prove to Bob that is in the table, `m(x)` is +//! the polynomial describing the multiplicities, and `t(x)` is the +//! polynomial describing the (fixed) table. +//! +//! The equation can be described by the multi-variate polynomial `LOGUP`: +//! ```text +//! LOGUP(β, F, M, T) = β + F - M (β + T) +//! ``` +//! +//! The relaxed/homogeneous version of the polynomial LOGUP is: +//! ```text +//! LOGUP_relaxed(β, F, M, T, u) = u β + u F - M (β + T) +//! ``` +//! +//! Folding this polynomial means that we will coin a random value `r`, and we compute: +//! ```text +//! β'' = β + r β' +//! F'' = F + r F' +//! M'' = M + r M' +//! T'' = T + r T' +//! u'' = u + r u' +//! ``` +//! +//! ## Supporting polynomial commitment blinders +//! +//! The library also supports polynomial commitment blinders. The blinding +//! factors are represented as new variables in the polynomial describing the NP +//! relation. The blinding factors are then aggregated in the same way as the +//! other variables. +//! We want to support blinders in the polynomial commitment scheme to avoid +//! committing to the zero zero polynomial. Using a blinder, we can always +//! suppose that our elliptic curves points are not the point at infinity. +//! The library handles the blinding factors as variables in each instance. +//! +//! When doing the final proof, the blinder factor that will need to be used is +//! the one from the final relaxed instance. + +use crate::{ + columns::ExtendedFoldingColumn, + quadraticization::{quadraticize, ExtendedWitnessGenerator, Quadraticized}, + FoldingConfig, ScalarField, +}; +use ark_ec::AffineRepr; +use ark_ff::One; +use derivative::Derivative; +use itertools::Itertools; +use kimchi::circuits::{ + berkeley_columns::BerkeleyChallengeTerm, + expr::{ConstantExprInner, ConstantTerm, ExprInner, Operations, Variable}, + gate::CurrOrNext, +}; +use num_traits::Zero; + +/// Describe the degree of a constraint. 
+/// As described in the [top level documentation](super::expressions), we only +/// support constraints with degree up to `2` +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Degree { + Zero, + One, + Two, +} + +impl std::ops::Add for Degree { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + use Degree::*; + match (self, rhs) { + (_, Two) | (Two, _) => Two, + (_, One) | (One, _) => One, + (Zero, Zero) => Zero, + } + } +} + +impl std::ops::Mul for &Degree { + type Output = Degree; + + fn mul(self, rhs: Self) -> Self::Output { + use Degree::*; + match (self, rhs) { + (Zero, other) | (other, Zero) => *other, + (One, One) => Two, + _ => panic!("The folding library does support only expressions of degree `2` maximum"), + } + } +} + +pub trait FoldingColumnTrait: Copy + Clone { + fn is_witness(&self) -> bool; + + /// Return the degree of the column + /// - `0` if the column is a constant + /// - `1` if the column will take part of the randomisation (see [top level + /// documentation](super::expressions) + fn degree(&self) -> Degree { + match self.is_witness() { + true => Degree::One, + false => Degree::Zero, + } + } +} + +/// Extra expressions that can be created by folding +#[derive(Derivative)] +#[derivative( + Clone(bound = "C: FoldingConfig"), + Debug(bound = "C: FoldingConfig"), + PartialEq(bound = "C: FoldingConfig") +)] +pub enum ExpExtension { + /// The variable `u` used to make the polynomial homogenous + U, + /// The error term + Error, + /// Additional columns created by quadraticization + ExtendedWitness(usize), + /// The random values `α_{i}` used to aggregate constraints + Alpha(usize), + /// Represent a dynamic selector, in the case of using decomposable folding + Selector(C::Selector), +} + +/// Components to be used to convert multivariate polynomials into "compatible" +/// multivariate polynomials that will be translated to folding expressions. +#[derive(Derivative)] +#[derivative( + Clone(bound = "C: FoldingConfig"), + PartialEq(bound = "C: FoldingConfig"), + Debug(bound = "C: FoldingConfig") +)] +pub enum FoldingCompatibleExprInner { + Constant(::ScalarField), + Challenge(C::Challenge), + Cell(Variable), + /// extra nodes created by folding, should not be passed to folding + Extensions(ExpExtension), +} + +/// Compatible folding expressions that can be used with folding schemes. +/// An expression from [kimchi::circuits::expr::Expr] can be converted into a +/// [FoldingCompatibleExpr] using the trait [From]. +/// From there, an expression of type [IntegratedFoldingExpr] can be created +/// using the function [folding_expression]. +#[derive(Derivative)] +#[derivative( + Clone(bound = "C: FoldingConfig"), + PartialEq(bound = "C: FoldingConfig"), + Debug(bound = "C: FoldingConfig") +)] +pub enum FoldingCompatibleExpr { + Atom(FoldingCompatibleExprInner), + Pow(Box, u64), + Add(Box, Box), + Sub(Box, Box), + Mul(Box, Box), + Double(Box), + Square(Box), +} + +impl std::ops::Add for FoldingCompatibleExpr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self::Add(Box::new(self), Box::new(rhs)) + } +} + +impl std::ops::Sub for FoldingCompatibleExpr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Self::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl std::ops::Mul for FoldingCompatibleExpr { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + Self::Mul(Box::new(self), Box::new(rhs)) + } +} + +/// Implement a human-readable version of a folding compatible expression. 
+// FIXME: use Display instead, to follow the recommendation of the trait. +impl ToString for FoldingCompatibleExpr { + fn to_string(&self) -> String { + match self { + FoldingCompatibleExpr::Atom(c) => match c { + FoldingCompatibleExprInner::Constant(c) => { + if c.is_zero() { + "0".to_string() + } else { + c.to_string() + } + } + FoldingCompatibleExprInner::Challenge(c) => { + format!("{:?}", c) + } + FoldingCompatibleExprInner::Cell(cell) => { + let Variable { col, row } = cell; + let next = match row { + CurrOrNext::Curr => "", + CurrOrNext::Next => " * ω", + }; + format!("Col({:?}){}", col, next) + } + FoldingCompatibleExprInner::Extensions(e) => match e { + ExpExtension::U => "U".to_string(), + ExpExtension::Error => "E".to_string(), + ExpExtension::ExtendedWitness(i) => { + format!("ExWit({})", i) + } + ExpExtension::Alpha(i) => format!("α_{i}"), + ExpExtension::Selector(s) => format!("Selec({:?})", s), + }, + }, + FoldingCompatibleExpr::Double(e) => { + format!("2 {}", e.to_string()) + } + FoldingCompatibleExpr::Square(e) => { + format!("{} ^ 2", e.to_string()) + } + FoldingCompatibleExpr::Add(e1, e2) => { + format!("{} + {}", e1.to_string(), e2.to_string()) + } + FoldingCompatibleExpr::Sub(e1, e2) => { + format!("{} - {}", e1.to_string(), e2.to_string()) + } + FoldingCompatibleExpr::Mul(e1, e2) => { + format!("({}) ({})", e1.to_string(), e2.to_string()) + } + FoldingCompatibleExpr::Pow(_, _) => todo!(), + } + } +} + +/// Internal expression used for folding. +/// A "folding" expression is a multivariate polynomial like defined in +/// [kimchi::circuits::expr] with the following differences. +/// - No constructors related to zero-knowledge or lagrange basis (i.e. no +/// constructors related to the PIOP) +/// - The variables includes a set of columns that describes the initial circuit +/// shape, with additional columns strictly related to the folding scheme (error +/// term, etc). +// TODO: renamed in "RelaxedExpression"? +#[derive(Derivative)] +#[derivative( + Hash(bound = "C:FoldingConfig"), + Debug(bound = "C:FoldingConfig"), + Clone(bound = "C:FoldingConfig"), + PartialEq(bound = "C:FoldingConfig"), + Eq(bound = "C:FoldingConfig") +)] +pub enum FoldingExp { + Atom(ExtendedFoldingColumn), + Pow(Box, u64), + Add(Box, Box), + Mul(Box, Box), + Sub(Box, Box), + Double(Box), + Square(Box), +} + +impl std::ops::Add for FoldingExp { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self::Add(Box::new(self), Box::new(rhs)) + } +} + +impl std::ops::Sub for FoldingExp { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Self::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl std::ops::Mul for FoldingExp { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + Self::Mul(Box::new(self), Box::new(rhs)) + } +} + +impl FoldingExp { + pub fn double(self) -> Self { + Self::Double(Box::new(self)) + } +} + +/// Converts an expression "compatible" with folding into a folded expression. +// TODO: use "into"? 
+// FIXME: add independent tests +// FIXME: test independently the behavior of pow_to_mul, and explain only why 8 +// maximum +impl FoldingCompatibleExpr { + pub fn simplify(self) -> FoldingExp { + use FoldingExp::*; + match self { + FoldingCompatibleExpr::Atom(atom) => match atom { + FoldingCompatibleExprInner::Constant(c) => Atom(ExtendedFoldingColumn::Constant(c)), + FoldingCompatibleExprInner::Challenge(c) => { + Atom(ExtendedFoldingColumn::Challenge(c)) + } + FoldingCompatibleExprInner::Cell(col) => Atom(ExtendedFoldingColumn::Inner(col)), + FoldingCompatibleExprInner::Extensions(ext) => { + match ext { + // TODO: this shouldn't be allowed, but is needed for now to add + // decomposable folding without many changes, it should be + // refactored at some point in the future + ExpExtension::Selector(s) => Atom(ExtendedFoldingColumn::Selector(s)), + _ => { + panic!("this should only be created by folding itself") + } + } + } + }, + FoldingCompatibleExpr::Double(exp) => Double(Box::new((*exp).simplify())), + FoldingCompatibleExpr::Square(exp) => Square(Box::new((*exp).simplify())), + FoldingCompatibleExpr::Add(e1, e2) => { + let e1 = Box::new(e1.simplify()); + let e2 = Box::new(e2.simplify()); + Add(e1, e2) + } + FoldingCompatibleExpr::Sub(e1, e2) => { + let e1 = Box::new(e1.simplify()); + let e2 = Box::new(e2.simplify()); + Sub(e1, e2) + } + FoldingCompatibleExpr::Mul(e1, e2) => { + let e1 = Box::new(e1.simplify()); + let e2 = Box::new(e2.simplify()); + Mul(e1, e2) + } + FoldingCompatibleExpr::Pow(e, p) => Self::pow_to_mul(e.simplify(), p), + } + } + + fn pow_to_mul(exp: FoldingExp, p: u64) -> FoldingExp + where + C::Column: Clone, + C::Challenge: Clone, + { + use FoldingExp::*; + let e = Box::new(exp); + let e_2 = Box::new(Square(e.clone())); + match p { + 2 => *e_2, + 3 => Mul(e, e_2), + 4..=8 => { + let e_4 = Box::new(Square(e_2.clone())); + match p { + 4 => *e_4, + 5 => Mul(e, e_4), + 6 => Mul(e_2, e_4), + 7 => Mul(e, Box::new(Mul(e_2, e_4))), + 8 => Square(e_4), + _ => unreachable!(), + } + } + _ => panic!("unsupported"), + } + } + + /// Maps variable (column index) in expression using the `mapper` + /// function. Can be used to modify (remap) the indexing of + /// columns after the expression is built. + pub fn map_variable( + self, + mapper: &(dyn Fn(Variable) -> Variable), + ) -> FoldingCompatibleExpr { + use FoldingCompatibleExpr::*; + match self { + FoldingCompatibleExpr::Atom(atom) => match atom { + FoldingCompatibleExprInner::Cell(col) => { + Atom(FoldingCompatibleExprInner::Cell((mapper)(col))) + } + atom => Atom(atom), + }, + FoldingCompatibleExpr::Double(exp) => Double(Box::new(exp.map_variable(mapper))), + FoldingCompatibleExpr::Square(exp) => Square(Box::new(exp.map_variable(mapper))), + FoldingCompatibleExpr::Add(e1, e2) => { + let e1 = Box::new(e1.map_variable(mapper)); + let e2 = Box::new(e2.map_variable(mapper)); + Add(e1, e2) + } + FoldingCompatibleExpr::Sub(e1, e2) => { + let e1 = Box::new(e1.map_variable(mapper)); + let e2 = Box::new(e2.map_variable(mapper)); + Sub(e1, e2) + } + FoldingCompatibleExpr::Mul(e1, e2) => { + let e1 = Box::new(e1.map_variable(mapper)); + let e2 = Box::new(e2.map_variable(mapper)); + Mul(e1, e2) + } + FoldingCompatibleExpr::Pow(e, p) => Pow(Box::new(e.map_variable(mapper)), p), + } + } + + /// Map all quad columns into regular witness columns. 
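+    /// For instance, with a (hypothetical) column layout that appends the
+    /// quadraticization columns after `n` regular witness columns:
+    /// ```ignore
+    /// let expr = expr.flatten_quad_columns(&|i| Variable {
+    ///     col: Column::Witness(n + i),
+    ///     row: CurrOrNext::Curr,
+    /// });
+    /// ```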
+ pub fn flatten_quad_columns( + self, + mapper: &(dyn Fn(usize) -> Variable), + ) -> FoldingCompatibleExpr { + use FoldingCompatibleExpr::*; + match self { + FoldingCompatibleExpr::Atom(atom) => match atom { + FoldingCompatibleExprInner::Extensions(ExpExtension::ExtendedWitness(i)) => { + Atom(FoldingCompatibleExprInner::Cell((mapper)(i))) + } + atom => Atom(atom), + }, + FoldingCompatibleExpr::Double(exp) => { + Double(Box::new(exp.flatten_quad_columns(mapper))) + } + FoldingCompatibleExpr::Square(exp) => { + Square(Box::new(exp.flatten_quad_columns(mapper))) + } + FoldingCompatibleExpr::Add(e1, e2) => { + let e1 = Box::new(e1.flatten_quad_columns(mapper)); + let e2 = Box::new(e2.flatten_quad_columns(mapper)); + Add(e1, e2) + } + FoldingCompatibleExpr::Sub(e1, e2) => { + let e1 = Box::new(e1.flatten_quad_columns(mapper)); + let e2 = Box::new(e2.flatten_quad_columns(mapper)); + Sub(e1, e2) + } + FoldingCompatibleExpr::Mul(e1, e2) => { + let e1 = Box::new(e1.flatten_quad_columns(mapper)); + let e2 = Box::new(e2.flatten_quad_columns(mapper)); + Mul(e1, e2) + } + FoldingCompatibleExpr::Pow(e, p) => Pow(Box::new(e.flatten_quad_columns(mapper)), p), + } + } +} + +impl FoldingExp { + /// Compute the degree of a folding expression. + /// Only constants are of degree `0`, the rest is of degree `1`. + /// An atom of degree `1` means that the atom is going to be randomised as + /// described in the [top level documentation](super::expressions). + pub(super) fn folding_degree(&self) -> Degree { + use Degree::*; + match self { + FoldingExp::Atom(ex_col) => match ex_col { + ExtendedFoldingColumn::Inner(col) => col.col.degree(), + ExtendedFoldingColumn::WitnessExtended(_) => One, + ExtendedFoldingColumn::Error => One, + ExtendedFoldingColumn::Constant(_) => Zero, + ExtendedFoldingColumn::Challenge(_) => One, + ExtendedFoldingColumn::Alpha(_) => One, + ExtendedFoldingColumn::Selector(_) => One, + }, + FoldingExp::Double(e) => e.folding_degree(), + FoldingExp::Square(e) => &e.folding_degree() * &e.folding_degree(), + FoldingExp::Mul(e1, e2) => &e1.folding_degree() * &e2.folding_degree(), + FoldingExp::Add(e1, e2) | FoldingExp::Sub(e1, e2) => { + e1.folding_degree() + e2.folding_degree() + } + FoldingExp::Pow(_, 0) => Zero, + FoldingExp::Pow(e, 1) => e.folding_degree(), + FoldingExp::Pow(e, i) => { + let degree = e.folding_degree(); + let mut acc = degree; + for _ in 1..*i { + acc = &acc * °ree; + } + acc + } + } + } + + /// Convert a folding expression into a compatible one. 
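+    /// `Pow` nodes are expanded into repeated multiplications on the way back:
+    /// ```text
+    /// Pow(e, 0) => Constant(1)
+    /// Pow(e, 3) => Mul(e, Mul(e, e))
+    /// ```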
+ fn into_compatible(self) -> FoldingCompatibleExpr { + use FoldingCompatibleExpr::*; + use FoldingCompatibleExprInner::*; + match self { + FoldingExp::Atom(c) => match c { + ExtendedFoldingColumn::Inner(col) => Atom(Cell(col)), + ExtendedFoldingColumn::WitnessExtended(i) => { + Atom(Extensions(ExpExtension::ExtendedWitness(i))) + } + ExtendedFoldingColumn::Error => Atom(Extensions(ExpExtension::Error)), + ExtendedFoldingColumn::Constant(c) => Atom(Constant(c)), + ExtendedFoldingColumn::Challenge(c) => Atom(Challenge(c)), + ExtendedFoldingColumn::Alpha(i) => Atom(Extensions(ExpExtension::Alpha(i))), + ExtendedFoldingColumn::Selector(s) => Atom(Extensions(ExpExtension::Selector(s))), + }, + FoldingExp::Double(exp) => Double(Box::new(exp.into_compatible())), + FoldingExp::Square(exp) => Square(Box::new(exp.into_compatible())), + FoldingExp::Add(e1, e2) => { + let e1 = Box::new(e1.into_compatible()); + let e2 = Box::new(e2.into_compatible()); + Add(e1, e2) + } + FoldingExp::Sub(e1, e2) => { + let e1 = Box::new(e1.into_compatible()); + let e2 = Box::new(e2.into_compatible()); + Sub(e1, e2) + } + FoldingExp::Mul(e1, e2) => { + let e1 = Box::new(e1.into_compatible()); + let e2 = Box::new(e2.into_compatible()); + Mul(e1, e2) + } + // TODO: Replace with `Pow` + FoldingExp::Pow(_, 0) => Atom(Constant(::ScalarField::one())), + FoldingExp::Pow(e, 1) => e.into_compatible(), + FoldingExp::Pow(e, i) => { + let e = e.into_compatible(); + let mut acc = e.clone(); + for _ in 1..i { + acc = Mul(Box::new(e.clone()), Box::new(acc)) + } + acc + } + } + } +} + +/// Used to encode the sign of a term in a polynomial. +// FIXME: is it really needed? +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Sign { + Pos, + Neg, +} + +impl std::ops::Neg for Sign { + type Output = Self; + + fn neg(self) -> Self { + match self { + Sign::Pos => Sign::Neg, + Sign::Neg => Sign::Pos, + } + } +} + +/// A term of a polynomial +/// For instance, in the polynomial `3 X_{1} X_{2} + 2 X_{3}`, the terms are +/// `3 X_{1} X_{2}` and `2 X_{3}`. +/// The sign is used to encode the sign of the term at the expression level. +/// It is used to split a polynomial in its terms/monomials of degree `0`, `1` +/// and `2`. +#[derive(Derivative)] +#[derivative(Debug, Clone(bound = "C: FoldingConfig"))] +pub struct Term { + pub exp: FoldingExp, + pub sign: Sign, +} + +impl Term { + fn double(self) -> Self { + let Self { exp, sign } = self; + let exp = FoldingExp::Double(Box::new(exp)); + Self { exp, sign } + } +} + +impl std::ops::Mul for &Term { + type Output = Term; + + fn mul(self, rhs: Self) -> Self::Output { + let sign = if self.sign == rhs.sign { + Sign::Pos + } else { + Sign::Neg + }; + let exp = FoldingExp::Mul(Box::new(self.exp.clone()), Box::new(rhs.exp.clone())); + Term { exp, sign } + } +} + +impl std::ops::Neg for Term { + type Output = Self; + + fn neg(self) -> Self::Output { + Term { + sign: -self.sign, + ..self + } + } +} + +/// A value of type [IntegratedFoldingExpr] is the result of the split of a +/// polynomial in its monomials of degree `0`, `1` and `2`. +/// It is used to compute the error terms. For an example, have a look at the +/// [top level documentation](super::expressions). 
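+/// As a small illustration, a constraint like `w1 * w2 - c * w1` (with `w1`,
+/// `w2` witness columns of degree `1` and `c` a constant) splits into the
+/// degree-2 monomial `w1 * w2` and the degree-1 monomial `c * w1`, the latter
+/// carrying a negative sign.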
+#[derive(Derivative)] +#[derivative( + Debug(bound = "C: FoldingConfig"), + Clone(bound = "C: FoldingConfig"), + Default(bound = "C: FoldingConfig") +)] +pub struct IntegratedFoldingExpr { + // (exp,sign,alpha) + pub(super) degree_0: Vec<(FoldingExp, Sign, usize)>, + pub(super) degree_1: Vec<(FoldingExp, Sign, usize)>, + pub(super) degree_2: Vec<(FoldingExp, Sign, usize)>, +} + +impl IntegratedFoldingExpr { + /// Combines constraints into single expression + pub fn final_expression(self) -> FoldingCompatibleExpr { + use FoldingCompatibleExpr::*; + /// TODO: should use powers of alpha + use FoldingCompatibleExprInner::*; + let Self { + degree_0, + degree_1, + degree_2, + } = self; + let [d0, d1, d2] = [degree_0, degree_1, degree_2] + .map(|exps| { + let init = + FoldingExp::Atom(ExtendedFoldingColumn::Constant(ScalarField::::zero())); + exps.into_iter().fold(init, |acc, (exp, sign, alpha)| { + let exp = FoldingExp::Mul( + Box::new(exp), + Box::new(FoldingExp::Atom(ExtendedFoldingColumn::Alpha(alpha))), + ); + match sign { + Sign::Pos => FoldingExp::Add(Box::new(acc), Box::new(exp)), + Sign::Neg => FoldingExp::Sub(Box::new(acc), Box::new(exp)), + } + }) + }) + .map(|e| e.into_compatible()); + let u = || Box::new(Atom(Extensions(ExpExtension::U))); + let u2 = || Box::new(Square(u())); + let d0 = FoldingCompatibleExpr::Mul(Box::new(d0), u2()); + let d1 = FoldingCompatibleExpr::Mul(Box::new(d1), u()); + let d2 = Box::new(d2); + let exp = FoldingCompatibleExpr::Add(Box::new(d0), Box::new(d1)); + let exp = FoldingCompatibleExpr::Add(Box::new(exp), d2); + FoldingCompatibleExpr::Add( + Box::new(exp), + Box::new(Atom(Extensions(ExpExtension::Error))), + ) + } +} + +pub fn extract_terms(exp: FoldingExp) -> Box>> { + use FoldingExp::*; + let exps: Box>> = match exp { + exp @ Atom(_) => Box::new( + [Term { + exp, + sign: Sign::Pos, + }] + .into_iter(), + ), + Double(exp) => Box::new(extract_terms(*exp).map(Term::double)), + Square(exp) => { + let terms = extract_terms(*exp).collect_vec(); + let mut combinations = Vec::with_capacity(terms.len() ^ 2); + for t1 in terms.iter() { + for t2 in terms.iter() { + combinations.push(t1 * t2) + } + } + Box::new(combinations.into_iter()) + } + Add(e1, e2) => { + let e1 = extract_terms(*e1); + let e2 = extract_terms(*e2); + Box::new(e1.chain(e2)) + } + Sub(e1, e2) => { + let e1 = extract_terms(*e1); + let e2 = extract_terms(*e2).map(|t| -t); + Box::new(e1.chain(e2)) + } + Mul(e1, e2) => { + let e1 = extract_terms(*e1).collect_vec(); + let e2 = extract_terms(*e2).collect_vec(); + let mut combinations = Vec::with_capacity(e1.len() * e2.len()); + for t1 in e1.iter() { + for t2 in e2.iter() { + combinations.push(t1 * t2) + } + } + Box::new(combinations.into_iter()) + } + Pow(_, 0) => Box::new( + [Term { + exp: FoldingExp::Atom(ExtendedFoldingColumn::Constant( + ::ScalarField::one(), + )), + sign: Sign::Pos, + }] + .into_iter(), + ), + Pow(e, 1) => extract_terms(*e), + Pow(e, mut i) => { + let e = extract_terms(*e).collect_vec(); + let mut acc = e.clone(); + // Could do this inplace, but it's more annoying to write + while i > 2 { + let mut combinations = Vec::with_capacity(e.len() * acc.len()); + for t1 in e.iter() { + for t2 in acc.iter() { + combinations.push(t1 * t2) + } + } + acc = combinations; + i -= 1; + } + Box::new(acc.into_iter()) + } + }; + exps +} + +/// Convert a list of folding compatible expression into the folded form. 
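+/// A sketch of the intended use (types elided for brevity):
+/// ```ignore
+/// let constraints: Vec<FoldingCompatibleExpr<C>> = ...;
+/// let (integrated, extended_witness_generator, added_columns) =
+///     folding_expression(constraints);
+/// let final_constraint = integrated.final_expression();
+/// ```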
+pub fn folding_expression( + exps: Vec>, +) -> (IntegratedFoldingExpr, ExtendedWitnessGenerator, usize) { + let simplified_expressions = exps.into_iter().map(|exp| exp.simplify()).collect_vec(); + let ( + Quadraticized { + original_constraints: expressions, + extra_constraints: extra_expressions, + extended_witness_generator, + }, + added_columns, + ) = quadraticize(simplified_expressions); + let mut terms = vec![]; + let mut alpha = 0; + // Alpha is always increased, equal to the total number of + // expressions. We could optimise it and only assign increasing + // alphas in "blocks" that depend on selectors. This would make + // #alphas equal to the expressions in the biggest block (+ some + // columns common for all blocks of the circuit). + for exp in expressions.into_iter() { + terms.extend(extract_terms(exp).map(|term| (term, alpha))); + alpha += 1; + } + for exp in extra_expressions.into_iter() { + terms.extend(extract_terms(exp).map(|term| (term, alpha))); + alpha += 1; + } + let mut integrated = IntegratedFoldingExpr::default(); + for (term, alpha) in terms.into_iter() { + let Term { exp, sign } = term; + let degree = exp.folding_degree(); + let t = (exp, sign, alpha); + match degree { + Degree::Zero => integrated.degree_0.push(t), + Degree::One => integrated.degree_1.push(t), + Degree::Two => integrated.degree_2.push(t), + } + } + (integrated, extended_witness_generator, added_columns) +} + +// CONVERSIONS FROM EXPR TO FOLDING COMPATIBLE EXPRESSIONS + +impl From> + for FoldingCompatibleExprInner +where + Config::Curve: AffineRepr, + Config::Challenge: From, +{ + fn from(expr: ConstantExprInner) -> Self { + match expr { + ConstantExprInner::Challenge(chal) => { + FoldingCompatibleExprInner::Challenge(chal.into()) + } + ConstantExprInner::Constant(c) => match c { + ConstantTerm::Literal(f) => FoldingCompatibleExprInner::Constant(f), + ConstantTerm::EndoCoefficient | ConstantTerm::Mds { row: _, col: _ } => { + panic!("When special constants are involved, don't forget to simplify the expression before.") + } + }, + } + } +} + +impl> + From, Col>> + for FoldingCompatibleExprInner +where + Config::Curve: AffineRepr, + Config::Challenge: From, +{ + // TODO: check if this needs some special treatment for Extensions + fn from(expr: ExprInner, Col>) -> Self { + match expr { + ExprInner::Constant(cexpr) => cexpr.into(), + ExprInner::Cell(col) => FoldingCompatibleExprInner::Cell(col), + ExprInner::UnnormalizedLagrangeBasis(_) => { + panic!("UnnormalizedLagrangeBasis should not be used in folding expressions") + } + ExprInner::VanishesOnZeroKnowledgeAndPreviousRows => { + panic!("VanishesOnZeroKnowledgeAndPreviousRows should not be used in folding expressions") + } + } + } +} + +impl> + From, Col>>> + for FoldingCompatibleExpr +where + Config::Curve: AffineRepr, + Config::Challenge: From, +{ + fn from(expr: Operations, Col>>) -> Self { + match expr { + Operations::Atom(inner) => FoldingCompatibleExpr::Atom(inner.into()), + Operations::Add(x, y) => { + FoldingCompatibleExpr::Add(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Mul(x, y) => { + FoldingCompatibleExpr::Mul(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Sub(x, y) => { + FoldingCompatibleExpr::Sub(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Double(x) => FoldingCompatibleExpr::Double(Box::new((*x).into())), + Operations::Square(x) => FoldingCompatibleExpr::Square(Box::new((*x).into())), + Operations::Pow(e, p) => FoldingCompatibleExpr::Pow(Box::new((*e).into()), p), + _ => 
panic!("Operation not supported in folding expressions"), + } + } +} + +impl> + From>> for FoldingCompatibleExpr +where + Config::Curve: AffineRepr, + Config::Challenge: From, +{ + fn from(expr: Operations>) -> Self { + match expr { + Operations::Add(x, y) => { + FoldingCompatibleExpr::Add(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Mul(x, y) => { + FoldingCompatibleExpr::Mul(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Sub(x, y) => { + FoldingCompatibleExpr::Sub(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Double(x) => FoldingCompatibleExpr::Double(Box::new((*x).into())), + Operations::Square(x) => FoldingCompatibleExpr::Square(Box::new((*x).into())), + Operations::Pow(e, p) => FoldingCompatibleExpr::Pow(Box::new((*e).into()), p), + _ => panic!("Operation not supported in folding expressions"), + } + } +} + +impl> + From>, Col>>> + for FoldingCompatibleExpr +where + Config::Curve: AffineRepr, + Config::Challenge: From, +{ + fn from( + expr: Operations>, Col>>, + ) -> Self { + match expr { + Operations::Atom(inner) => match inner { + ExprInner::Constant(op) => match op { + // The constant expressions nodes are considered as top level + // expressions in folding + Operations::Atom(inner) => FoldingCompatibleExpr::Atom(inner.into()), + Operations::Add(x, y) => { + FoldingCompatibleExpr::Add(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Mul(x, y) => { + FoldingCompatibleExpr::Mul(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Sub(x, y) => { + FoldingCompatibleExpr::Sub(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Double(x) => FoldingCompatibleExpr::Double(Box::new((*x).into())), + Operations::Square(x) => FoldingCompatibleExpr::Square(Box::new((*x).into())), + Operations::Pow(e, p) => FoldingCompatibleExpr::Pow(Box::new((*e).into()), p), + _ => panic!("Operation not supported in folding expressions"), + }, + ExprInner::Cell(col) => { + FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(col)) + } + ExprInner::UnnormalizedLagrangeBasis(_) => { + panic!("UnnormalizedLagrangeBasis should not be used in folding expressions") + } + ExprInner::VanishesOnZeroKnowledgeAndPreviousRows => { + panic!("VanishesOnZeroKnowledgeAndPreviousRows should not be used in folding expressions") + } + }, + Operations::Add(x, y) => { + FoldingCompatibleExpr::Add(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Mul(x, y) => { + FoldingCompatibleExpr::Mul(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Sub(x, y) => { + FoldingCompatibleExpr::Sub(Box::new((*x).into()), Box::new((*y).into())) + } + Operations::Double(x) => FoldingCompatibleExpr::Double(Box::new((*x).into())), + Operations::Square(x) => FoldingCompatibleExpr::Square(Box::new((*x).into())), + Operations::Pow(e, p) => FoldingCompatibleExpr::Pow(Box::new((*e).into()), p), + _ => panic!("Operation not supported in folding expressions"), + } + } +} diff --git a/folding/src/instance_witness.rs b/folding/src/instance_witness.rs new file mode 100644 index 0000000000..52626a84c0 --- /dev/null +++ b/folding/src/instance_witness.rs @@ -0,0 +1,515 @@ +//! This module defines a list of traits and structures that are used by the +//! folding scheme. +//! The folding library is built over generic traits like [Instance] and +//! [Witness] that defines the the NP relation R. +//! +//! This module describes 3 different types of instance/witness pairs: +//! - [Instance] and [Witness]: the original instance and witness. 
These are the
+//!   ones that the user must provide.
+//! - [ExtendedInstance] and [ExtendedWitness]: the instance and witness
+//!   extended by quadraticization.
+//! - [RelaxedInstance] and [RelaxedWitness]: the instance and witness related
+//!   to the relaxed/homogeneous polynomials.
+//!
+//! Note that [Instance], [ExtendedInstance] and [RelaxedInstance] are
+//! supposed to be used to encapsulate the public inputs and challenges. It is
+//! the common information the prover and verifier have.
+//! [Witness], [ExtendedWitness] and [RelaxedWitness] are supposed to be used
+//! to encapsulate the private inputs. For instance, it is the evaluations of
+//! the polynomials.
+//!
+//! A generic trait [Foldable] is defined to combine two objects of the same
+//! type using a challenge.
+
+// FIXME: for optimisation, as values are not necessarily Fp elements and are
+// relatively small, we could get rid of the scalar field objects, and only use
+// bigints, applying the modulus only when needed.
+
+use crate::{Alphas, Evals};
+use ark_ff::Field;
+use num_traits::One;
+use poly_commitment::commitment::{CommitmentCurve, PolyComm};
+use std::collections::BTreeMap;
+
+pub trait Foldable<F: Field> {
+    /// Combine two objects 'a' and 'b' into a new object using the challenge.
+    // FIXME: rename in fold2
+    fn combine(a: Self, b: Self, challenge: F) -> Self;
+}
+
+pub trait Instance<G: CommitmentCurve>: Sized + Foldable<G::ScalarField> {
+    /// This method returns the scalars and commitments that must be absorbed by
+    /// the sponge. It is not supposed to do any absorption itself, and the user
+    /// is responsible for calling the sponge absorb methods with the elements
+    /// returned by this method.
+    /// When called on a [RelaxedInstance], elements will be returned in the
+    /// following order, for given instances L and R:
+    /// ```text
+    /// scalar = L.to_absorb().0 | L.u | R.to_absorb().0 | R.u
+    /// points_l = L.to_absorb().1 | L.extended | L.error // where extended is the commitments to the extra columns
+    /// points_r = R.to_absorb().1 | R.extended | R.error // where extended is the commitments to the extra columns
+    /// t_0 and t_1 are the first and second error terms
+    /// points = points_l | points_r | t_0 | t_1
+    /// ```
+    /// A user implementing the IVC circuit should absorb the elements in the
+    /// following order:
+    /// ```text
+    /// sponge.absorb_fr(scalar); // absorb the scalar elements
+    /// sponge.absorb_g(points); // absorb the commitments
+    /// ```
+    /// This is the order used by the folding library in the method
+    /// `fold_instance_witness_pair`.
+    /// From there, a challenge can be coined using:
+    /// ```text
+    /// let challenge_r = sponge.challenge();
+    /// ```
+    fn to_absorb(&self) -> (Vec<G::ScalarField>, Vec<G>);
+
+    /// Returns the alphas values for the instance
+    fn get_alphas(&self) -> &Alphas<G::ScalarField>;
+
+    /// Return the blinder that can be used while committing to polynomials.
+    fn get_blinder(&self) -> G::ScalarField;
+}
+
+pub trait Witness<G: CommitmentCurve>: Sized + Foldable<G::ScalarField> {}
+
+// -- Structures that consist of extending the original instance and witness
+// -- with the extra columns added by quadraticization.
+
+impl<G: CommitmentCurve, W: Witness<G>> ExtendedWitness<G, W> {
+    /// This method returns an extended witness which is defined as the witness
+    /// itself, followed by an empty BTreeMap.
+    /// The map will later be filled by the quadraticization witness generator.
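+    /// The resulting value is simply:
+    /// ```ignore
+    /// ExtendedWitness { witness, extended: BTreeMap::new() }
+    /// ```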
+ fn extend(witness: W) -> ExtendedWitness { + let extended = BTreeMap::new(); + ExtendedWitness { witness, extended } + } +} + +impl> ExtendedInstance { + /// This method returns an extended instance which is defined as the instance itself, + /// followed by an empty vector. + fn extend(instance: I) -> ExtendedInstance { + ExtendedInstance { + instance, + extended: vec![], + } + } +} + +// -- Extended witness +/// This structure represents a witness extended with extra columns that are +/// added by quadraticization +#[derive(Clone, Debug)] +pub struct ExtendedWitness> { + /// This is the original witness, without quadraticization + pub witness: W, + /// Extra columns added by quadraticization to lower the degree of + /// expressions to 2 + pub extended: BTreeMap>, +} + +impl> Foldable for ExtendedWitness { + fn combine(a: Self, b: Self, challenge: ::ScalarField) -> Self { + let Self { + witness: witness1, + extended: ex1, + } = a; + let Self { + witness: witness2, + extended: ex2, + } = b; + // We fold the original witness + let witness = W::combine(witness1, witness2, challenge); + // And we fold the columns created by quadraticization. + // W <- W1 + c W2 + let extended = ex1 + .into_iter() + .zip(ex2) + .map(|(a, b)| { + let (i, mut evals) = a; + assert_eq!(i, b.0); + evals + .evals + .iter_mut() + .zip(b.1.evals) + .for_each(|(a, b)| *a += b * challenge); + (i, evals) + }) + .collect(); + Self { witness, extended } + } +} + +impl> Witness for ExtendedWitness {} + +impl> ExtendedWitness { + pub(crate) fn add_witness_evals(&mut self, i: usize, evals: Evals) { + self.extended.insert(i, evals); + } + + /// Return true if the no extra columns are added by quadraticization + /// + /// Can be used to know if the extended witness columns are already + /// computed, to avoid overriding them + pub fn is_extended(&self) -> bool { + !self.extended.is_empty() + } +} + +// -- Extended instance +/// An extended instance is an instance that has been extended with extra +/// columns by quadraticization. +/// The original instance is stored in the `instance` field. +/// The extra columns are stored in the `extended` field. +/// For instance, if the original instance has `n` columns, and the system is +/// described by a degree 3 polynomial, an additional column will be added, and +/// `extended` will contain `1` commitment. +// FIXME: We should forbid cloning, for memory footprint. +#[derive(PartialEq, Eq, Clone)] +pub struct ExtendedInstance> { + /// The original instance. + pub instance: I, + /// Commitments to the extra columns added by quadraticization + pub extended: Vec>, +} + +impl> Foldable for ExtendedInstance { + fn combine(a: Self, b: Self, challenge: ::ScalarField) -> Self { + let Self { + instance: instance1, + extended: ex1, + } = a; + let Self { + instance: instance2, + extended: ex2, + } = b; + // Combining first the existing commitments (i.e. not the one added by + // quadraticization) + // They are supposed to be blinded already + let instance = I::combine(instance1, instance2, challenge); + // For each commitment, compute + // Comm(W) + c * Comm(W') + let extended = ex1 + .into_iter() + .zip(ex2) + .map(|(a, b)| &a + &b.scale(challenge)) + .collect(); + Self { instance, extended } + } +} + +impl> Instance for ExtendedInstance { + /// Return the elements to be absorbed by the sponge + /// + /// The commitments to the additional columns created by quadriticization + /// are appended to the existing commitments of the initial instance + /// to be absorbed. The scalars are unchanged. 
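+    ///
+    /// Schematically:
+    /// ```text
+    /// scalars = instance.to_absorb().0
+    /// points  = instance.to_absorb().1 | extended
+    /// ```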
+ fn to_absorb(&self) -> (Vec, Vec) { + let mut elements = self.instance.to_absorb(); + let extended_commitments = self.extended.iter().map(|commit| { + assert_eq!(commit.len(), 1); + commit.get_first_chunk() + }); + elements.1.extend(extended_commitments); + elements + } + + fn get_alphas(&self) -> &Alphas { + self.instance.get_alphas() + } + + /// Returns the blinder value. It is the same as the one of the original + fn get_blinder(&self) -> G::ScalarField { + self.instance.get_blinder() + } +} + +// -- "Relaxed"/"Homogenized" structures + +/// A relaxed instance is an instance that has been relaxed by the folding scheme. +/// It contains the original instance, extended with the columns added by +/// quadriticization, the scalar `u` and a commitment to the +/// slack/error term. +/// See page 15 of [Nova](https://eprint.iacr.org/2021/370.pdf). +// FIXME: We should forbid cloning, for memory footprint. +#[derive(PartialEq, Eq, Clone)] +pub struct RelaxedInstance> { + /// The original instance, extended with the columns added by + /// quadriticization + pub extended_instance: ExtendedInstance, + /// The scalar `u` that is used to homogenize the polynomials + pub u: G::ScalarField, + /// The commitment to the error term, introduced when homogenizing the + /// polynomials + pub error_commitment: PolyComm, + /// Blinder used for the commitments to the cross terms + pub blinder: G::ScalarField, +} + +impl> RelaxedInstance { + /// Returns the elements to be absorbed by the sponge + /// + /// The scalar elements of the are appended with the scalar `u` and the + /// commitments are appended by the commitment to the error term. + pub fn to_absorb(&self) -> (Vec, Vec) { + let mut elements = self.extended_instance.to_absorb(); + elements.0.push(self.u); + assert_eq!(self.error_commitment.len(), 1); + elements.1.push(self.error_commitment.get_first_chunk()); + elements + } + + /// Provides access to commitments to the extra columns added by + /// quadraticization + pub fn get_extended_column_commitment(&self, i: usize) -> Option<&PolyComm> { + self.extended_instance.extended.get(i) + } + + /// Combining the commitments of each instance and adding the cross terms + /// into the error term. + /// This corresponds to the computation `E <- E1 - c T1 - c^2 T2 + c^3 E2`. + /// As we do support folding of degree 3, we have two cross terms `T1` and + /// `T2`. + /// For more information, see the [top-level + /// documentation](crate::expressions). + pub(super) fn combine_and_sub_cross_terms( + a: Self, + b: Self, + challenge: ::ScalarField, + cross_terms: &[PolyComm; 2], + ) -> Self { + // Compute E1 + c^3 E2 and all other folding of commitments. The + // resulting error commitment is stored in res.commitment. + let mut res = Self::combine(a, b, challenge); + let [t0, t1] = cross_terms; + // Eq 4, page 15 of the Nova paper + // Computing (E1 + c^3 E2) - c T1 - c^2 T2 + res.error_commitment = + &res.error_commitment - (&(&t0.scale(challenge) + &t1.scale(challenge.square()))); + res + } +} + +/// A relaxed instance can be folded. +impl> Foldable for RelaxedInstance { + /// Combine two relaxed instances into a new relaxed instance. 
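+    ///
+    /// For a challenge `c`, the result is assembled as:
+    /// ```text
+    /// I <- I1 + c I2
+    /// u <- u1 + c u2
+    /// E <- E1 + c^3 E2   // the cross terms c T1 + c^2 T2 are subtracted separately
+    /// ```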
+    fn combine(a: Self, b: Self, challenge: G::ScalarField) -> Self {
+        // We do support degree 3 folding, therefore, we must compute:
+        // E <- E1 - (c T1 + c^2 T2) + c^3 E2
+        // (page 15, eq 3 of the Nova paper)
+        // The terms T1 and T2 are the cross terms
+        let challenge_square = challenge * challenge;
+        let challenge_cube = challenge_square * challenge;
+        let RelaxedInstance {
+            extended_instance: extended_instance_1,
+            u: u1,
+            error_commitment: e1,
+            blinder: blinder1,
+        } = a;
+        let RelaxedInstance {
+            extended_instance: extended_instance_2,
+            u: u2,
+            error_commitment: e2,
+            blinder: blinder2,
+        } = b;
+        // We simply fold the blinders. The cross terms are committed with a
+        // blinder of one, hence the bare `challenge` and `challenge_square`
+        // terms:
+        // r_E <- r_E1 + c * 1 + c^2 * 1 + c^3 r_E2
+        let blinder = blinder1 + challenge + challenge_square + challenge_cube * blinder2;
+        let extended_instance = <ExtendedInstance<G, I> as Foldable<G::ScalarField>>::combine(
+            extended_instance_1,
+            extended_instance_2,
+            challenge,
+        );
+        // Combining the challenges
+        // eq 3, page 15 of the Nova paper
+        let u = u1 + u2 * challenge;
+        // We do have 2 cross terms as we have degree 3 folding
+        // e1 + c^3 e2
+        let error_commitment = &e1 + &e2.scale(challenge_cube);
+        RelaxedInstance {
+            // I <- I1 + c I2
+            extended_instance,
+            // u <- u1 + c u2
+            u,
+            // E <- E1 - (c T1 + c^2 T2) + c^3 E2
+            error_commitment,
+            blinder,
+        }
+    }
+}
+
+// -- Relaxed witnesses
+#[derive(Clone, Debug)]
+pub struct RelaxedWitness<G: CommitmentCurve, W: Witness<G>> {
+    /// The original witness, extended with the columns added by
+    /// quadraticization.
+    pub extended_witness: ExtendedWitness<G, W>,
+    /// The error vector, introduced when homogenizing the polynomials.
+    /// For degree 3 folding, it is `E1 - c T1 - c^2 T2 + c^3 E2`
+    pub error_vec: Evals<G::ScalarField>,
+}
+
+impl<G: CommitmentCurve, W: Witness<G>> RelaxedWitness<G, W> {
+    /// Combining the existing error terms with the cross-terms T1 and T2 given
+    /// as parameters.
+    ///
+    /// The result is:
+    /// ```text
+    /// existing error terms    cross terms
+    /// /------------\          /--------------\
+    ///  E1 + c^3 E2          -  (c T1 + c^2 T2)
+    /// ```
+    /// We do have two cross terms as we work with homogeneous polynomials of
+    /// degree 3. The value is saved into the field `error_vec` of the relaxed
+    /// witness.
+    /// This corresponds to step 4, page 15 of the Nova paper, but with two
+    /// cross terms (T1 and T2), see the [top-level
+    /// documentation](crate::expressions).
+    pub(super) fn combine_and_sub_cross_terms(
+        a: Self,
+        b: Self,
+        challenge: G::ScalarField,
+        cross_terms: [Vec<G::ScalarField>; 2],
+    ) -> Self {
+        // Computing E1 + c^3 E2
+        let mut res = Self::combine(a, b, challenge);
+
+        // Now subtracting the cross terms
+        let [e0, e1] = cross_terms;
+
+        for (res, (e0, e1)) in res
+            .error_vec
+            .evals
+            .iter_mut()
+            .zip(e0.into_iter().zip(e1.into_iter()))
+        {
+            // FIXME: for optimisation, use inplace operators. Allocating can be
+            // costly
+            // Horner form of e0 * c + e1 * c^2
+            *res -= ((e1 * challenge) + e0) * challenge;
+        }
+        res
+    }
+
+    /// Provides access to the extra columns added by quadraticization
+    pub fn get_extended_column(&self, i: &usize) -> Option<&Evals<G::ScalarField>> {
+        self.extended_witness.extended.get(i)
+    }
+}
+
+/// A relaxed/homogenized witness can be folded.
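+/// For a challenge `c`, the extended witnesses are folded column by column,
+/// and the error vectors are combined elementwise as `E <- E1 + c^3 E2`,
+/// mirroring the commitment-side combination in [RelaxedInstance].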
+impl> Foldable for RelaxedWitness { + fn combine(a: Self, b: Self, challenge: ::ScalarField) -> Self { + let RelaxedWitness { + extended_witness: a, + error_vec: mut e1, + } = a; + let RelaxedWitness { + extended_witness: b, + error_vec: e2, + } = b; + // We combine E1 and E2 into E1 + c^3 E2 as we do have two cross-terms + // with degree 3 folding + let challenge_cube = (challenge * challenge) * challenge; + let extended_witness = >::combine(a, b, challenge); + for (a, b) in e1.evals.iter_mut().zip(e2.evals.into_iter()) { + *a += b * challenge_cube; + } + let error_vec = e1; + RelaxedWitness { + extended_witness, + error_vec, + } + } +} + +// -- Relaxable instance +pub trait RelaxableInstance> { + fn relax(self) -> RelaxedInstance; +} + +impl> RelaxableInstance for I { + /// This method takes an Instance and a commitment to zero and extends the + /// instance, returning a relaxed instance which is composed by the extended + /// instance, the scalar one, and the error commitment which is set to the + /// commitment to zero. + fn relax(self) -> RelaxedInstance { + let extended_instance = ExtendedInstance::extend(self); + let blinder = extended_instance.instance.get_blinder(); + let u = G::ScalarField::one(); + let error_commitment = PolyComm::new(vec![G::zero()]); + RelaxedInstance { + extended_instance, + u, + error_commitment, + blinder, + } + } +} + +/// A relaxed instance is trivially relaxable. +impl> RelaxableInstance for RelaxedInstance { + fn relax(self) -> RelaxedInstance { + self + } +} + +/// Trait to make a witness relaxable/homogenizable +pub trait RelaxableWitness> { + fn relax(self, zero_poly: &Evals) -> RelaxedWitness; +} + +impl> RelaxableWitness for W { + /// This method takes a witness and a vector of evaluations to the zero + /// polynomial, returning a relaxed witness which is composed by the + /// extended witness and the error vector that is set to the zero + /// polynomial. + fn relax(self, zero_poly: &Evals) -> RelaxedWitness { + let extended_witness = ExtendedWitness::extend(self); + RelaxedWitness { + extended_witness, + error_vec: zero_poly.clone(), + } + } +} + +impl> RelaxableWitness for RelaxedWitness { + fn relax(self, _zero_poly: &Evals) -> RelaxedWitness { + self + } +} + +pub trait RelaxablePair, W: Witness> { + fn relax( + self, + zero_poly: &Evals, + ) -> (RelaxedInstance, RelaxedWitness); +} + +impl RelaxablePair for (I, W) +where + G: CommitmentCurve, + I: Instance + RelaxableInstance, + W: Witness + RelaxableWitness, +{ + fn relax( + self, + zero_poly: &Evals, + ) -> (RelaxedInstance, RelaxedWitness) { + let (instance, witness) = self; + ( + RelaxableInstance::relax(instance), + RelaxableWitness::relax(witness, zero_poly), + ) + } +} + +impl RelaxablePair for (RelaxedInstance, RelaxedWitness) +where + G: CommitmentCurve, + I: Instance + RelaxableInstance, + W: Witness + RelaxableWitness, +{ + fn relax( + self, + _zero_poly: &Evals, + ) -> (RelaxedInstance, RelaxedWitness) { + self + } +} diff --git a/folding/src/lib.rs b/folding/src/lib.rs new file mode 100644 index 0000000000..1cd94a940e --- /dev/null +++ b/folding/src/lib.rs @@ -0,0 +1,468 @@ +//! This library implements basic components to fold computations expressed as +//! multivariate polynomials of any degree. It is based on the "folding scheme" +//! described in the [Nova](https://eprint.iacr.org/2021/370.pdf) paper. +//! It implements different components to achieve it: +//! - [quadraticization]: a submodule to reduce multivariate polynomials +//! to degree `2`. +//! 
- [decomposable_folding]: a submodule to "parallelize" folded +//! computations. +//! +//! Examples can be found in the directory `examples`. +//! +//! The folding library is meant to be used in harmony with the library `ivc`. +//! To use the library, the user has to define first a "folding configuration" +//! described in the trait [FoldingConfig]. +//! After that, the user can provide folding compatible expressions and build a +//! folding scheme [FoldingScheme]. The process is described in the module +//! [expressions]. +// TODO: the documentation above might need more descriptions. + +use ark_ec::AffineRepr; +use ark_ff::{Field, One, Zero}; +use ark_poly::{EvaluationDomain, Evaluations, Radix2EvaluationDomain}; +use error_term::{compute_error, ExtendedEnv}; +use expressions::{folding_expression, FoldingColumnTrait, IntegratedFoldingExpr}; +use instance_witness::{Foldable, RelaxableInstance, RelaxablePair}; +use kimchi::circuits::gate::CurrOrNext; +use mina_poseidon::FqSponge; +use poly_commitment::{commitment::CommitmentCurve, PolyComm, SRS}; +use quadraticization::ExtendedWitnessGenerator; +use std::{ + fmt::Debug, + hash::Hash, + iter::successors, + rc::Rc, + sync::atomic::{AtomicUsize, Ordering}, +}; + +// Make available outside the crate to avoid code duplication +pub use error_term::Side; +pub use expressions::{ExpExtension, FoldingCompatibleExpr}; +pub use instance_witness::{Instance, RelaxedInstance, RelaxedWitness, Witness}; + +pub mod columns; +pub mod decomposable_folding; + +mod error_term; + +pub mod eval_leaf; +pub mod expressions; +pub mod instance_witness; +pub mod quadraticization; +pub mod standard_config; + +/// Define the different structures required for the examples (both internal and +/// external) +pub mod checker; + +// Simple type alias as ScalarField/BaseField is often used. Reduce type +// complexity for clippy. +// Should be moved into FoldingConfig, but associated type defaults are unstable +// at the moment. +type ScalarField = <::Curve as AffineRepr>::ScalarField; +type BaseField = <::Curve as AffineRepr>::BaseField; + +// 'static seems to be used for expressions. Can we get rid of it? +pub trait FoldingConfig: Debug + 'static { + type Column: FoldingColumnTrait + Debug + Eq + Hash; + + // in case of using docomposable folding, if not it can be just () + type Selector: Clone + Debug + Eq + Hash + Copy + Ord + PartialOrd; + + /// The type of an abstract challenge that can be found in the expressions + /// provided as constraints. + type Challenge: Clone + Copy + Debug + Eq + Hash; + + /// The target curve used by the polynomial commitment + type Curve: CommitmentCurve; + + /// The SRS used by the polynomial commitment. The SRS is used to commit to + /// the additional columns that are added by the quadraticization. + type Srs: SRS; + + /// For Plonk, it will be the commitments to the polynomials and the challenges + type Instance: Instance + Clone; + + /// For PlonK, it will be the polynomials in evaluation form that we commit + /// to, i.e. the columns. + /// In the generic prover/verifier, it would be `kimchi_msm::witness::Witness`. + type Witness: Witness + Clone; + + type Structure: Clone; + + type Env: FoldingEnv< + ::ScalarField, + Self::Instance, + Self::Witness, + Self::Column, + Self::Challenge, + Self::Selector, + Structure = Self::Structure, + >; +} + +/// Describe a folding environment. 
+/// The type parameters are: +/// - `F`: The field of the circuit/computation +/// - `I`: The instance type, i.e the public inputs +/// - `W`: The type of the witness, i.e. the private inputs +/// - `Col`: The type of the column +/// - `Chal`: The type of the challenge +/// - `Selector`: The type of the selector +pub trait FoldingEnv { + /// Structure which could be storing useful information like selectors, etc. + type Structure; + + /// Creates a new environment storing the structure, instances and + /// witnesses. + fn new(structure: &Self::Structure, instances: [&I; 2], witnesses: [&W; 2]) -> Self; + + /// Obtains a given challenge from the expanded instance for one side. + /// The challenges are stored inside the instances structs. + fn challenge(&self, challenge: Chal, side: Side) -> F; + + /// Returns the evaluations of a given column witness at omega or zeta*omega. + fn col(&self, col: Col, curr_or_next: CurrOrNext, side: Side) -> &[F]; + + /// similar to [Self::col], but folding may ask for a dynamic selector directly + /// instead of just column that happens to be a selector + fn selector(&self, s: &Selector, side: Side) -> &[F]; +} + +type Evals = Evaluations>; + +pub struct FoldingScheme<'a, CF: FoldingConfig> { + pub expression: IntegratedFoldingExpr, + pub srs: &'a CF::Srs, + pub domain: Radix2EvaluationDomain>, + pub zero_vec: Evals>, + pub structure: CF::Structure, + pub extended_witness_generator: ExtendedWitnessGenerator, + quadraticization_columns: usize, +} + +impl<'a, CF: FoldingConfig> FoldingScheme<'a, CF> { + pub fn new( + constraints: Vec>, + srs: &'a CF::Srs, + domain: Radix2EvaluationDomain>, + structure: &CF::Structure, + ) -> (Self, FoldingCompatibleExpr) { + let (expression, extended_witness_generator, quadraticization_columns) = + folding_expression(constraints); + let zero = >::zero(); + let evals = std::iter::repeat(zero).take(domain.size()).collect(); + let zero_vec = Evaluations::from_vec_and_domain(evals, domain); + let final_expression = expression.clone().final_expression(); + let scheme = Self { + expression, + srs, + domain, + zero_vec, + structure: structure.clone(), + extended_witness_generator, + quadraticization_columns, + }; + (scheme, final_expression) + } + + /// Return the number of additional columns added by quadraticization + pub fn get_number_of_additional_columns(&self) -> usize { + self.quadraticization_columns + } + + /// This is the main entry point to fold two instances and their witnesses. + /// The process is as follows: + /// - Both pairs are relaxed. + /// - Both witnesses and instances are extended, i.e. all polynomials are + /// reduced to degree 2 and additional constraints are added to the + /// expression. + /// - While computing the commitments to the additional columns, the + /// commitments are added into a list to absorb them into the sponge later. + /// - The error terms are computed and committed. + /// - The sponge absorbs the commitments and challenges. 
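+    ///
+    /// A sketch of the intended call (assuming a configured scheme and a
+    /// sponge; the names are illustrative):
+    /// ```ignore
+    /// let output = scheme.fold_instance_witness_pair(
+    ///     (left_instance, left_witness),
+    ///     (right_instance, right_witness),
+    ///     &mut fq_sponge,
+    /// );
+    /// let (folded_instance, folded_witness) = output.pair();
+    /// ```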
+ #[allow(clippy::type_complexity)] + pub fn fold_instance_witness_pair( + &self, + a: A, + b: B, + fq_sponge: &mut Sponge, + ) -> FoldingOutput + where + A: RelaxablePair, + B: RelaxablePair, + Sponge: FqSponge, CF::Curve, ScalarField>, + { + let a = a.relax(&self.zero_vec); + let b = b.relax(&self.zero_vec); + + let u = (a.0.u, b.0.u); + + let (left_instance, left_witness) = a; + let (right_instance, right_witness) = b; + let env = ExtendedEnv::new( + &self.structure, + [left_instance, right_instance], + [left_witness, right_witness], + self.domain, + None, + ); + // Computing the additional columns, resulting of the quadritization + // process. + // Side-effect: commitments are added in both relaxed (extended) instance. + let env: ExtendedEnv = + env.compute_extension(&self.extended_witness_generator, self.srs); + + // Computing the error terms + let error: [Vec>; 2] = compute_error(&self.expression, &env, u); + let error_evals = error.map(|e| Evaluations::from_vec_and_domain(e, self.domain)); + + // Committing to the cross terms + // Default blinder for commiting to the cross terms + let blinders = PolyComm::new(vec![ScalarField::::one()]); + let error_commitments = error_evals + .iter() + .map(|e| { + self.srs + .commit_evaluations_custom(self.domain, e, &blinders) + .unwrap() + .commitment + }) + .collect::>(); + let error_commitments: [PolyComm; 2] = error_commitments.try_into().unwrap(); + + let error: [Vec<_>; 2] = error_evals.map(|e| e.evals); + + // sanity check to verify that we only have one commitment in polycomm + // (i.e. domain = poly size) + assert_eq!(error_commitments[0].len(), 1); + assert_eq!(error_commitments[1].len(), 1); + + let t_0 = &error_commitments[0].get_first_chunk(); + let t_1 = &error_commitments[1].get_first_chunk(); + + // Absorbing the commitments into the sponge + let to_absorb = env.to_absorb(t_0, t_1); + + fq_sponge.absorb_fr(&to_absorb.0); + fq_sponge.absorb_g(&to_absorb.1); + + let challenge = fq_sponge.challenge(); + + let ( + [relaxed_extended_left_instance, relaxed_extended_right_instance], + [relaxed_extended_left_witness, relaxed_extended_right_witness], + ) = env.unwrap(); + + let folded_instance = RelaxedInstance::combine_and_sub_cross_terms( + // FIXME: remove clone + relaxed_extended_left_instance.clone(), + relaxed_extended_right_instance.clone(), + challenge, + &error_commitments, + ); + + let folded_witness = RelaxedWitness::combine_and_sub_cross_terms( + relaxed_extended_left_witness, + relaxed_extended_right_witness, + challenge, + error, + ); + FoldingOutput { + folded_instance, + folded_witness, + t_0: error_commitments[0].clone(), + t_1: error_commitments[1].clone(), + relaxed_extended_left_instance, + relaxed_extended_right_instance, + to_absorb, + } + } + + /// Fold two relaxable instances into a relaxed instance. + /// It is parametrized by two different types `A` and `B` that represent + /// "relaxable" instances to be able to fold a normal and "already relaxed" + /// instance. + pub fn fold_instance_pair( + &self, + a: A, + b: B, + error_commitments: [PolyComm; 2], + fq_sponge: &mut Sponge, + ) -> RelaxedInstance + where + A: RelaxableInstance, + B: RelaxableInstance, + Sponge: FqSponge, CF::Curve, ScalarField>, + { + let a: RelaxedInstance = a.relax(); + let b: RelaxedInstance = b.relax(); + + // sanity check to verify that we only have one commitment in polycomm + // (i.e. 
domain = poly size) + assert_eq!(error_commitments[0].len(), 1); + assert_eq!(error_commitments[1].len(), 1); + + let to_absorb = { + let mut left = a.to_absorb(); + let right = b.to_absorb(); + left.0.extend(right.0); + left.1.extend(right.1); + left.1.extend([ + error_commitments[0].get_first_chunk(), + error_commitments[1].get_first_chunk(), + ]); + left + }; + + fq_sponge.absorb_fr(&to_absorb.0); + fq_sponge.absorb_g(&to_absorb.1); + + let challenge = fq_sponge.challenge(); + + RelaxedInstance::combine_and_sub_cross_terms(a, b, challenge, &error_commitments) + } + + #[allow(clippy::type_complexity)] + /// Verifier of the folding scheme; returns a new folded instance, + /// which can be then compared with the one claimed to be the real + /// one. + pub fn verify_fold( + &self, + left_instance: RelaxedInstance, + right_instance: RelaxedInstance, + t_0: PolyComm, + t_1: PolyComm, + fq_sponge: &mut Sponge, + ) -> RelaxedInstance + where + Sponge: FqSponge, CF::Curve, ScalarField>, + { + let to_absorb = { + let mut left = left_instance.to_absorb(); + let right = right_instance.to_absorb(); + left.0.extend(right.0); + left.1.extend(right.1); + left.1 + .extend([t_0.get_first_chunk(), t_1.get_first_chunk()]); + left + }; + + fq_sponge.absorb_fr(&to_absorb.0); + fq_sponge.absorb_g(&to_absorb.1); + + let challenge = fq_sponge.challenge(); + + RelaxedInstance::combine_and_sub_cross_terms( + // FIXME: remove clone + left_instance.clone(), + right_instance.clone(), + challenge, + &[t_0, t_1], + ) + } +} + +/// Output of the folding prover +pub struct FoldingOutput { + /// The folded instance, containing, in particular, the result `C_l + r C_r` + pub folded_instance: RelaxedInstance, + /// Folded witness, containing, in particular, the result of the evaluations + /// `W_l + r W_r` + pub folded_witness: RelaxedWitness, + /// The error terms of degree 1, see the top-level documentation of + /// [crate::expressions] + pub t_0: PolyComm, + /// The error terms of degree 2, see the top-level documentation of + /// [crate::expressions] + pub t_1: PolyComm, + /// The left relaxed instance, including the potential additional columns + /// added by quadritization + pub relaxed_extended_left_instance: RelaxedInstance, + /// The right relaxed instance, including the potential additional columns + /// added by quadritization + pub relaxed_extended_right_instance: RelaxedInstance, + /// Elements to absorbed in IVC, in the same order as done in folding + pub to_absorb: (Vec>, Vec), +} + +impl FoldingOutput { + #[allow(clippy::type_complexity)] + pub fn pair( + self, + ) -> ( + RelaxedInstance, + RelaxedWitness, + ) { + (self.folded_instance, self.folded_witness) + } +} + +/// Combinators that will be used to fold the constraints, +/// called the "alphas". +/// The alphas are exceptional, their number cannot be known ahead of time as it +/// will be defined by folding. +/// The values will be computed as powers in new instances, but after folding +/// each alpha will be a linear combination of other alphas, instand of a power +/// of other element. This type represents that, allowing to also recognize +/// which case is present. 
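+///
+/// For example, a fresh instance starts in the `Powers` case, where `get(i)`
+/// returns `alpha^i`; after a fold, the result is the `Combinations` case,
+/// holding the explicit linear combinations:
+/// ```ignore
+/// let alphas = Alphas::new(alpha);
+/// assert_eq!(alphas.get(2), Some(alpha * alpha));
+/// ```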
+#[derive(Debug, Clone)] +pub enum Alphas { + Powers(F, Rc), + Combinations(Vec), +} + +impl PartialEq for Alphas { + fn eq(&self, other: &Self) -> bool { + // Maybe there's a more efficient way + self.clone().powers() == other.clone().powers() + } +} + +impl Eq for Alphas {} + +impl Foldable for Alphas { + fn combine(a: Self, b: Self, challenge: F) -> Self { + let a = a.powers(); + let b = b.powers(); + assert_eq!(a.len(), b.len()); + let comb = a + .into_iter() + .zip(b) + .map(|(a, b)| a + b * challenge) + .collect(); + Self::Combinations(comb) + } +} + +impl Alphas { + pub fn new(alpha: F) -> Self { + Self::Powers(alpha, Rc::new(AtomicUsize::from(0))) + } + + pub fn new_sized(alpha: F, count: usize) -> Self { + Self::Powers(alpha, Rc::new(AtomicUsize::from(count))) + } + + pub fn get(&self, i: usize) -> Option { + match self { + Alphas::Powers(alpha, count) => { + let _ = count.fetch_max(i + 1, Ordering::Relaxed); + let i = [i as u64]; + Some(alpha.pow(i)) + } + Alphas::Combinations(alphas) => alphas.get(i).cloned(), + } + } + + pub fn powers(self) -> Vec { + match self { + Alphas::Powers(alpha, count) => { + let n = count.load(Ordering::Relaxed); + let alphas = successors(Some(F::one()), |last| Some(*last * alpha)); + alphas.take(n).collect() + } + Alphas::Combinations(c) => c, + } + } +} diff --git a/folding/src/quadraticization.rs b/folding/src/quadraticization.rs new file mode 100644 index 0000000000..7d8b6da12e --- /dev/null +++ b/folding/src/quadraticization.rs @@ -0,0 +1,215 @@ +//! A library to reduce constraints into degree 2. +// TODO: more documentation + +use crate::{ + columns::ExtendedFoldingColumn, + error_term::{eval_sided, ExtendedEnv, Side}, + expressions::{Degree, FoldingExp}, + FoldingConfig, +}; +use std::collections::{BTreeMap, HashMap, VecDeque}; + +pub struct Quadraticized { + pub original_constraints: Vec>, + pub extra_constraints: Vec>, + pub extended_witness_generator: ExtendedWitnessGenerator, +} + +/// Returns the constraints converted into degree 2 or less and the extra +/// constraints added in the process +pub fn quadraticize( + constraints: Vec>, +) -> (Quadraticized, usize) { + let mut recorder = ExpRecorder::new(); + let original_constraints = constraints + .into_iter() + .map(|exp| lower_degree_to_2(exp, &mut recorder)) + .collect(); + let added_columns = recorder.next; + let (extra_constraints, exprs) = recorder.into_constraints(); + ( + Quadraticized { + original_constraints, + extra_constraints, + extended_witness_generator: ExtendedWitnessGenerator { exprs }, + }, + added_columns, + ) +} + +/// Records expressions that have been extracted into an extra column +struct ExpRecorder { + recorded_exprs: HashMap, usize>, + pub(crate) next: usize, +} + +impl ExpRecorder { + fn new() -> Self { + Self { + recorded_exprs: Default::default(), + next: 0, + } + } + + fn get_id(&mut self, exp: FoldingExp) -> usize { + *self.recorded_exprs.entry(exp).or_insert_with(|| { + let id = self.next; + self.next += 1; + id + }) + } + + #[allow(clippy::type_complexity)] + fn into_constraints(self) -> (Vec>, VecDeque<(usize, FoldingExp)>) { + let ExpRecorder { recorded_exprs, .. 
} = self; + let mut witness_generator = VecDeque::with_capacity(recorded_exprs.len()); + let mut new_constraints = BTreeMap::new(); + for (exp, id) in recorded_exprs.into_iter() { + let left = FoldingExp::Atom(ExtendedFoldingColumn::WitnessExtended(id)); + let constraint = FoldingExp::Sub(Box::new(left), Box::new(exp.clone())); + new_constraints.insert(id, constraint); + witness_generator.push_front((id, exp)); + } + (new_constraints.into_values().collect(), witness_generator) + } +} + +impl FoldingExp { + fn degree(&self) -> usize { + match self { + e @ FoldingExp::Atom(_) => match e.folding_degree() { + Degree::Zero => 0, + Degree::One => 1, + Degree::Two => 2, + }, + FoldingExp::Double(exp) => exp.degree(), + FoldingExp::Square(exp) => exp.degree() * 2, + FoldingExp::Add(e1, e2) | FoldingExp::Sub(e1, e2) => { + std::cmp::max(e1.degree(), e2.degree()) + } + FoldingExp::Mul(e1, e2) => e1.degree() + e2.degree(), + FoldingExp::Pow(e, i) => e.degree() * (*i as usize), + } + } +} + +fn lower_degree_to_1( + exp: FoldingExp, + rec: &mut ExpRecorder, +) -> FoldingExp { + let degree = exp.degree(); + match degree { + 0 => exp, + 1 => exp, + _ => match exp { + FoldingExp::Add(e1, e2) => FoldingExp::Add( + Box::new(lower_degree_to_1(*e1, rec)), + Box::new(lower_degree_to_1(*e2, rec)), + ), + FoldingExp::Sub(e1, e2) => FoldingExp::Sub( + Box::new(lower_degree_to_1(*e1, rec)), + Box::new(lower_degree_to_1(*e2, rec)), + ), + e @ FoldingExp::Square(_) | e @ FoldingExp::Mul(_, _) => { + let exp = lower_degree_to_2(e, rec); + let id = rec.get_id(exp); + FoldingExp::Atom(ExtendedFoldingColumn::WitnessExtended(id)) + } + FoldingExp::Double(exp) => FoldingExp::Double(Box::new(lower_degree_to_1(*exp, rec))), + _ => todo!(), + }, + } +} + +fn lower_degree_to_2( + exp: FoldingExp, + rec: &mut ExpRecorder, +) -> FoldingExp { + use FoldingExp::*; + let degree = exp.degree(); + if degree <= 2 { + return exp; + } + + match exp { + FoldingExp::Atom(_) => panic!("a column shouldn't be above degree 1"), + FoldingExp::Double(exp) => Double(Box::new(lower_degree_to_2(*exp, rec))), + FoldingExp::Square(exp) => Square(Box::new(lower_degree_to_1(*exp, rec))), + FoldingExp::Add(e1, e2) => { + let e1 = lower_degree_to_2(*e1, rec); + let e2 = lower_degree_to_2(*e2, rec); + Add(Box::new(e1), Box::new(e2)) + } + FoldingExp::Sub(e1, e2) => { + let e1 = lower_degree_to_2(*e1, rec); + let e2 = lower_degree_to_2(*e2, rec); + Sub(Box::new(e1), Box::new(e2)) + } + FoldingExp::Mul(e1, e2) => { + let d1 = e1.degree(); + let d2 = e2.degree(); + assert_eq!(degree, d1 + d2); + let (e1, e2) = (*e1, *e2); + let (e1, e2) = match (d1, d2) { + (0, _) => (e1, lower_degree_to_2(e2, rec)), + (_, 0) => (lower_degree_to_2(e1, rec), e2), + (_, _) => (lower_degree_to_1(e1, rec), lower_degree_to_1(e2, rec)), + }; + Mul(Box::new(e1), Box::new(e2)) + } + exp @ FoldingExp::Pow(_, 0) => exp, + FoldingExp::Pow(e, 1) => lower_degree_to_2(*e, rec), + FoldingExp::Pow(e, 2) => FoldingExp::Pow(Box::new(lower_degree_to_1(*e, rec)), 2), + FoldingExp::Pow(e, i) => { + let e = lower_degree_to_1(*e, rec); + let mut acc = e.clone(); + for _ in 1..i - 1 { + acc = lower_degree_to_1(FoldingExp::Mul(Box::new(e.clone()), Box::new(acc)), rec); + } + FoldingExp::Mul(Box::new(e.clone()), Box::new(acc)) + } + } +} + +pub struct ExtendedWitnessGenerator { + exprs: VecDeque<(usize, FoldingExp)>, +} + +impl ExtendedWitnessGenerator { + pub(crate) fn compute_extended_witness( + &self, + mut env: ExtendedEnv, + side: Side, + ) -> ExtendedEnv { + if env.needs_extension(side) { + 
let mut pending = self.exprs.clone(); + + while let Some((i, exp)) = pending.pop_front() { + if check_evaluable(&exp, &env, side) { + let evals = eval_sided(&exp, &env, side).unwrap(); + env.add_witness_evals(i, evals, side); + } else { + pending.push_back((i, exp)) + } + } + } + + env + } +} + +/// Checks if the expression can be evaluated in the current environment +fn check_evaluable( + exp: &FoldingExp, + env: &ExtendedEnv, + side: Side, +) -> bool { + match exp { + FoldingExp::Atom(col) => env.col_try(col, side), + FoldingExp::Double(e) | FoldingExp::Square(e) => check_evaluable(e, env, side), + FoldingExp::Add(e1, e2) | FoldingExp::Sub(e1, e2) | FoldingExp::Mul(e1, e2) => { + check_evaluable(e1, env, side) && check_evaluable(e2, env, side) + } + FoldingExp::Pow(e, _) => check_evaluable(e, env, side), + } +} diff --git a/folding/src/standard_config.rs b/folding/src/standard_config.rs new file mode 100644 index 0000000000..d8d9f92490 --- /dev/null +++ b/folding/src/standard_config.rs @@ -0,0 +1,450 @@ +//! This module offers a standard implementation of [FoldingConfig] supporting +//! many use cases +use crate::{ + expressions::FoldingColumnTrait, instance_witness::Witness, FoldingConfig, FoldingEnv, + Instance, Side, +}; +use derivative::Derivative; +use kimchi::{circuits::gate::CurrOrNext, curve::KimchiCurve}; +use memoization::ColumnMemoizer; +use poly_commitment::{self, commitment::CommitmentCurve}; +use std::{fmt::Debug, hash::Hash, marker::PhantomData, ops::Index}; + +#[derive(Clone, Default)] +/// Default type for when you don't need structure +pub struct EmptyStructure(PhantomData); + +impl Index for EmptyStructure { + type Output = Vec; + + fn index(&self, _index: Col) -> &Self::Output { + panic!("shouldn't reach this point, as this type only works with witness-only constraint systems"); + } +} + +/// A standard folding config that supports: +/// `G`: any curve +/// `Col`: any column implementing [FoldingColumnTrait] +/// `Chall`: any challenge +/// `Sel`: any dynamic selector +/// `Str`: structures that can be indexed by `Col`, thus implementing `Index` +/// `I`: instances (implementing [Instance]) that can be indexed by `Chall` +/// `W`: witnesses (implementing [Witness]) that can be indexed by `Col` and `Sel` +/// ```ignore +/// use ark_poly::{EvaluationDomain, Radix2EvaluationDomain}; +/// use mina_poseidon::FqSponge; +/// use folding::{examples::{BaseSponge, Curve, Fp}, FoldingScheme}; +/// +/// // instanciating the config with our types and the defaults for selectors and structure +/// type MyConfig = StandardConfig, MyWitness>; +/// let constraints = vec![constraint()]; +/// let domain = Radix2EvaluationDomain::::new(2).unwrap(); +/// let srs = poly_commitment::srs::SRS::::create(2); +/// srs.get_lagrange_basis(domain); +/// // this is the default structure, which does nothing or panics if +/// // indexed (as it shouldn't be indexed) +/// let structure = EmptyStructure::default(); +/// +/// // here we can use the config +/// let (scheme, _) = +/// FoldingScheme::::new(constraints, &srs, domain, &structure); +/// +/// let [left, right] = pairs; +/// let left = (left.0, left.1); +/// let right = (right.0, right.1); +/// +/// let mut fq_sponge = BaseSponge::new(Curve::other_curve_sponge_params()); +/// let _output = scheme.fold_instance_witness_pair(left, right, &mut fq_sponge); +/// ``` +#[derive(Derivative)] +#[derivative(Hash, PartialEq, Eq, Debug)] +#[allow(clippy::type_complexity)] +pub struct StandardConfig>( + PhantomData<(G, Col, Chall, Sel, Str, I, W, Srs)>, +); + 
+// Implementing FoldingConfig
+impl<G, Col, Chall, Sel, Str, I, W, Srs> FoldingConfig
+    for StandardConfig<G, Col, Chall, I, W, Srs, Sel, Str>
+where
+    Self: 'static,
+    G: CommitmentCurve,
+    I: Instance<G> + Index<Chall, Output = G::ScalarField> + Clone,
+    W: Witness<G> + Clone,
+    W: Index<Col, Output = Vec<G::ScalarField>> + Index<Sel, Output = [G::ScalarField]>,
+    Srs: poly_commitment::SRS<G>,
+    Col: Hash + Eq + Debug + Clone + FoldingColumnTrait,
+    Sel: Ord + Copy + Hash + Debug,
+    Chall: Hash + Eq + Debug + Copy,
+    Str: Clone + Index<Col, Output = Vec<G::ScalarField>>,
+{
+    type Column = Col;
+
+    type Selector = Sel;
+
+    type Challenge = Chall;
+
+    type Curve = G;
+
+    type Srs = Srs;
+
+    type Instance = I;
+
+    type Witness = W;
+
+    type Structure = Str;
+
+    type Env = Env<G, Col, Chall, Sel, Str, I, W>;
+}
+/// A generic Index-based environment
+pub struct Env<G, Col, Chall, Sel, Str, I, W>
+where
+    G: CommitmentCurve,
+    I: Instance<G> + Index<Chall, Output = G::ScalarField> + Clone,
+    W: Witness<G> + Clone,
+    W: Index<Col, Output = Vec<G::ScalarField>> + Index<Sel, Output = [G::ScalarField]>,
+    Col: Hash + Eq,
+{
+    instances: [I; 2],
+    witnesses: [W; 2],
+    next_evals: ColumnMemoizer<Col, G::ScalarField, 10>,
+    structure: Str,
+    // not used, but needed as generics for the bounds
+    _phantom: PhantomData<(G, Col, Chall, Sel, Str)>,
+}
+
+// Implementing FoldingEnv
+impl<G, Col, Chall, Sel, Str, I, W> FoldingEnv<G::ScalarField, I, W, Col, Chall, Sel>
+    for Env<G, Col, Chall, Sel, Str, I, W>
+where
+    G: CommitmentCurve,
+    I: Instance<G> + Index<Chall, Output = G::ScalarField> + Clone,
+    W: Witness<G> + Clone,
+    W: Index<Col, Output = Vec<G::ScalarField>> + Index<Sel, Output = [G::ScalarField]>,
+    Col: FoldingColumnTrait + Eq + Hash,
+    Sel: Copy,
+    Str: Clone + Index<Col, Output = Vec<G::ScalarField>>,
+{
+    type Structure = Str;
+
+    fn new(structure: &Self::Structure, instances: [&I; 2], witnesses: [&W; 2]) -> Self {
+        // cloning for now; ideally this should work with references, but that
+        // requires deeper refactoring of folding
+        let instances = instances.map(Clone::clone);
+        let witnesses = witnesses.map(Clone::clone);
+        let structure = structure.clone();
+        Self {
+            instances,
+            witnesses,
+            structure,
+            next_evals: ColumnMemoizer::new(),
+            _phantom: PhantomData,
+        }
+    }
+
+    fn challenge(&self, challenge: Chall, side: Side) -> G::ScalarField {
+        let instance = match side {
+            Side::Left => &self.instances[0],
+            Side::Right => &self.instances[1],
+        };
+        // handled through Index in I
+        instance[challenge]
+    }
+
+    fn col(&self, col: Col, curr_or_next: CurrOrNext, side: Side) -> &[G::ScalarField] {
+        let witness = match side {
+            Side::Left => &self.witnesses[0],
+            Side::Right => &self.witnesses[1],
+        };
+        // This should hold as long as the Index implementations are consistent
+        // with the FoldingColumnTrait implementation:
+        // either search in the witness for witness columns, or in the structure otherwise.
+        if col.is_witness() {
+            match curr_or_next {
+                CurrOrNext::Curr => &witness[col],
+                CurrOrNext::Next => {
+                    let f = || {
+                        // Simple but not the best; ideally there would be a single vector
+                        // where you push its first element and offer either evals[0..] or
+                        // evals[1..].
+                        // That would be relatively easy to implement in a custom
+                        // implementation with just a small change to this trait, but in
+                        // this generic implementation it is harder to implement.
+                        // The cost is mostly the cost of a clone.
+                        let evals = &witness[col];
+                        let mut next = Vec::with_capacity(evals.len());
+                        next.extend(evals[1..].iter());
+                        next.push(evals[0]);
+                        next
+                    };
+                    self.next_evals.get_or_insert(col, f)
+                }
+            }
+        } else {
+            &self.structure[col]
+        }
+    }
+
+    fn selector(&self, s: &Sel, side: Side) -> &[G::ScalarField] {
+        // similar to the witness case of col, as expected
+        let witness = match side {
+            Side::Left => &self.witnesses[0],
+            Side::Right => &self.witnesses[1],
+        };
+        &witness[*s]
+    }
+}
+
+/// Contains a data structure useful to support the [CurrOrNext::Next] case
+/// in [FoldingEnv::col]
+mod memoization {
+    use ark_ff::Field;
+    use std::{
+        cell::{OnceCell, RefCell},
+        collections::HashMap,
+        hash::Hash,
+        sync::atomic::{AtomicUsize, Ordering},
+    };
+
+    /// A segment with up to N stored columns, and the potential
+    /// next segment; similar to a linked list of N-length arrays
+    pub struct ColumnMemoizerSegment<F: Field, const N: usize> {
+        cols: [OnceCell<Vec<F>>; N],
+        next: OnceCell<Box<Self>>,
+    }
+
+    impl<F: Field, const N: usize> ColumnMemoizerSegment<F, N> {
+        pub fn new() -> Self {
+            let cols = [(); N].map(|_| OnceCell::new());
+            let next = OnceCell::new();
+            Self { cols, next }
+        }
+        // This will find the column if i < N, and get a reference to it,
+        // initializing it with `f` if needed.
+        // If i >= N it will continue recursing to the next segment,
+        // initializing it if needed.
+        pub fn get_or_insert<I>(&self, i: usize, f: I) -> &Vec<F>
+        where
+            I: FnOnce() -> Vec<F>,
+        {
+            match i {
+                i if i < N => {
+                    let col = &self.cols[i];
+                    col.get_or_init(f)
+                }
+                i => {
+                    let i = i - N;
+                    let new = || Box::new(Self::new());
+                    let next = self.next.get_or_init(new);
+                    next.get_or_insert(i, f)
+                }
+            }
+        }
+    }
+    /// A hashmap-like data structure supporting get-or-insert with
+    /// an immutable reference, and returning an immutable reference
+    /// without a guard
+    pub struct ColumnMemoizer<C: Hash + Eq, F: Field, const N: usize> {
+        first_segment: ColumnMemoizerSegment<F, N>,
+        next: AtomicUsize,
+        ids: RefCell<HashMap<C, usize>>,
+    }
+
+    impl<C: Hash + Eq, F: Field, const N: usize> ColumnMemoizer<C, F, N> {
+        pub fn new() -> Self {
+            let first_segment = ColumnMemoizerSegment::new();
+            let next = AtomicUsize::from(0);
+            let ids = RefCell::new(HashMap::new());
+            Self {
+                first_segment,
+                next,
+                ids,
+            }
+        }
+        pub fn get_or_insert<I>(&self, col: C, f: I) -> &Vec<F>
+        where
+            I: FnOnce() -> Vec<F>,
+        {
+            // this will find or assign an id for the column, and then
+            // search the segments using the id
+            let mut ids = self.ids.borrow_mut();
+            let new_id = || self.next.fetch_add(1, Ordering::Relaxed);
+            let id = ids.entry(col).or_insert_with(new_id);
+            self.first_segment.get_or_insert(*id, f)
+        }
+    }
+}
+
+#[cfg(feature = "bn254")]
+mod example {
+    use crate::{
+        examples::{BaseSponge, Curve, Fp},
+        expressions::{FoldingColumnTrait, FoldingCompatibleExprInner},
+        instance_witness::Foldable,
+        standard_config::{EmptyStructure, StandardConfig},
+        FoldingCompatibleExpr, Instance, Witness,
+    };
+    use ark_ec::ProjectiveCurve;
+    use kimchi::{
+        circuits::{expr::Variable, gate::CurrOrNext},
+        curve::KimchiCurve,
+    };
+    use std::ops::Index;
+
+    // we create some example types
+
+    // an instance
+    #[derive(Clone, Debug)]
+    struct MyInstance<G: KimchiCurve> {
+        commitments: [G; 3],
+        beta: G::ScalarField,
+        gamma: G::ScalarField,
+    }
+
+    // implementing Foldable
+    impl<G: KimchiCurve> Foldable<G::ScalarField> for MyInstance<G> {
+        fn combine(mut a: Self, b: Self, challenge: G::ScalarField) -> Self {
+            for (a, b) in a.commitments.iter_mut().zip(b.commitments) {
+                *a = *a + b.mul(challenge).into_affine();
+            }
+            a
+        }
+    }
+    // and Instance
+    impl<G: KimchiCurve> Instance<G> for MyInstance<G> {
+        fn to_absorb(&self) -> (Vec<G::ScalarField>, Vec<G>) {
+            (vec![], self.commitments.to_vec())
+        }
+
+        fn get_alphas(&self) -> &crate::Alphas<G::ScalarField> {
+            todo!()
+        }
+    }
+    // a witness
+    #[derive(Clone, Debug)]
+    struct MyWitness<G: KimchiCurve> {
+        columns: [Vec<G::ScalarField>; 3],
+    }
+
+    // implementing Foldable
+    impl<G: KimchiCurve> Foldable<G::ScalarField> for MyWitness<G> {
+        fn combine(mut a: Self, b: Self, challenge: G::ScalarField) -> Self {
+            for (a, b) in a.columns.iter_mut().zip(b.columns) {
+                for (a, b) in a.iter_mut().zip(b) {
+                    *a += b * challenge;
+                }
+            }
+            a
+        }
+    }
+    // and Witness
+    impl<G: KimchiCurve> Witness<G> for MyWitness<G> {}
+
+    // a type for the columns
+    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+    enum MyCol {
+        A,
+        B,
+        C,
+    }
+
+    // implementing FoldingColumnTrait, trivial in this witness-only case
+    impl FoldingColumnTrait for MyCol {
+        fn is_witness(&self) -> bool {
+            true
+        }
+    }
+    // a challenge
+    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+    enum MyChallenge {
+        Beta,
+        Gamma,
+    }
+
+    // Now, to use this config with our types, we need to implement Index
+    // (if not already implemented) to resolve access to values in the
+    // instance, witness, and structure if present.
+
+    // for witness columns
+    impl<G: KimchiCurve> Index<MyCol> for MyWitness<G> {
+        type Output = Vec<G::ScalarField>;
+
+        fn index(&self, index: MyCol) -> &Self::Output {
+            let index = match index {
+                MyCol::A => 0,
+                MyCol::B => 1,
+                MyCol::C => 2,
+            };
+            &self.columns[index]
+        }
+    }
+    // for selectors, () in this case as we have none
+    impl<G: KimchiCurve> Index<()> for MyWitness<G> {
+        type Output = [G::ScalarField];
+
+        fn index(&self, _index: ()) -> &Self::Output {
+            unreachable!()
+        }
+    }
+    // for challenges, which should live in the instance
+    impl<G: KimchiCurve> Index<MyChallenge> for MyInstance<G> {
+        type Output = G::ScalarField;
+
+        fn index(&self, index: MyChallenge) -> &Self::Output {
+            match index {
+                MyChallenge::Beta => &self.beta,
+                MyChallenge::Gamma => &self.gamma,
+            }
+        }
+    }
+
+    // now we can create an instance of StandardConfig, where selectors and
+    // structures have defaults for cases like this where we don't need them
+    type MyConfig<G> =
+        StandardConfig<G, MyCol, MyChallenge, MyInstance<G>, MyWitness<G>, poly_commitment::ipa::SRS<G>>;
+
+    // creating some example constraint
+    fn constraint() -> FoldingCompatibleExpr<MyConfig<Curve>> {
+        let column = |col| {
+            FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable {
+                col,
+                row: CurrOrNext::Curr,
+            }))
+        };
+        let chall =
+            |chall| FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Challenge(chall));
+        let a = column(MyCol::A);
+        let b = column(MyCol::B);
+        let c = column(MyCol::C);
+        let beta = chall(MyChallenge::Beta);
+        let gamma = chall(MyChallenge::Gamma);
+        (a + b - c) * (beta + gamma)
+    }
+
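+    // A quick sanity check on the constraint above (illustrative): the factor
+    // (a + b - c) vanishes on any row where a + b = c, e.g. a = 1, b = 2,
+    // c = 3, so the whole product is zero there whatever the challenges are.
+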
+    // here is an example of how the config would be used, which is similar to
+    // any other custom config
+    #[allow(dead_code)]
+    fn fold(pairs: [(MyInstance<Curve>, MyWitness<Curve>); 2]) {
+        use crate::FoldingScheme;
+        use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
+        use mina_poseidon::FqSponge;
+
+        let constraints = vec![constraint()];
+        let domain = Radix2EvaluationDomain::<Fp>::new(2).unwrap();
+        let srs = poly_commitment::ipa::SRS::<Curve>::create(2);
+        srs.get_lagrange_basis(domain);
+        // this is the default structure, which does nothing or panics if
+        // indexed (as it shouldn't be indexed)
+        let structure = EmptyStructure::default();
+
+        // here we can use the config
+        let (scheme, _) =
+            FoldingScheme::<MyConfig<Curve>>::new(constraints, &srs, domain, &structure);
+
+        let [left, right] = pairs;
+        let left = (left.0, left.1);
+        let right = (right.0, right.1);
+
+        let mut fq_sponge = BaseSponge::new(Curve::other_curve_sponge_params());
+        let _output = scheme.fold_instance_witness_pair(left, right, &mut fq_sponge);
+    }
+}
diff --git a/folding/tests/test_decomposable_folding.rs b/folding/tests/test_decomposable_folding.rs
new file mode 100644
index 0000000000..d140cc6481
--- /dev/null
+++ b/folding/tests/test_decomposable_folding.rs
@@ -0,0 +1,476 @@
+use ark_ec::AffineRepr;
+use ark_ff::{One, UniformRand};
+use ark_poly::{Evaluations, Radix2EvaluationDomain};
+use folding::{
+    checker::{Checker, ExtendedProvider},
+    expressions::{FoldingColumnTrait, FoldingCompatibleExprInner},
+    instance_witness::Foldable,
+    Alphas, FoldingCompatibleExpr, FoldingConfig, FoldingEnv, Instance, Side, Witness,
+};
+use itertools::Itertools;
+use kimchi::circuits::{expr::Variable, gate::CurrOrNext};
+use poly_commitment::{ipa::SRS, SRS as _};
+use rand::thread_rng;
+use std::{collections::BTreeMap, ops::Index};
+
+use mina_poseidon::{constants::PlonkSpongeConstantsKimchi, sponge::DefaultFqSponge};
+
+// Trick to print debug messages while testing, as we are in the test config env
+use ark_poly::{EvaluationDomain, Radix2EvaluationDomain as D};
+use folding::{decomposable_folding::DecomposableFoldingScheme, FoldingOutput};
+use kimchi::curve::KimchiCurve;
+use mina_poseidon::FqSponge;
+use std::println as debug;
+
+type Fp = ark_bn254::Fr;
+type Curve = ark_bn254::G1Affine;
+type SpongeParams = PlonkSpongeConstantsKimchi;
+type BaseSponge = DefaultFqSponge<ark_bn254::g1::Parameters, SpongeParams>;
+
+// the type representing our columns; in this case we have 3 witness columns
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
+pub enum TestColumn {
+    A,
+    B,
+    C,
+}
+
+// the type for the dynamic selectors, which are essentially witness columns, but
+// get special treatment to enable optimizations
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
+pub enum DynamicSelector {
+    SelecAdd,
+    SelecSub,
+}
+
+impl FoldingColumnTrait for TestColumn {
+    // in this case we have only witness columns; the other example shows
+    // non-witness columns
+    fn is_witness(&self) -> bool {
+        match self {
+            TestColumn::A | TestColumn::B | TestColumn::C => true,
+        }
+    }
+}
+
+/// The instance is the commitments to the polynomials and the challenges
+#[derive(Debug, Clone)]
+pub struct TestInstance {
+    // 3 from the normal witness + 2 from the dynamic selectors
+    commitments: [Curve; 5],
+    // for illustration only; no constraint in this example uses challenges
+    challenges: [Fp; 3],
+    // also challenges, but segregated as folding gives them special treatment
+    alphas: Alphas<Fp>,
+    // Used for the blinding factor in the commitment
+    blinder: Fp,
+}
+
+impl Foldable<Fp> for TestInstance {
+    fn combine(a: Self, b: Self, challenge: Fp) -> Self {
+        TestInstance {
+            commitments: std::array::from_fn(|i| {
+                (a.commitments[i] + b.commitments[i] * challenge).into()
+            }),
+            challenges: std::array::from_fn(|i| a.challenges[i] + challenge * b.challenges[i]),
+            alphas: Alphas::combine(a.alphas, b.alphas, challenge),
+            blinder: a.blinder + challenge * b.blinder,
+        }
+    }
+}
+
+impl Instance<Curve> for TestInstance {
+    fn to_absorb(&self) -> (Vec<Fp>, Vec<Curve>) {
+        // FIXME?
+        (vec![], vec![])
+    }
+
+    fn get_alphas(&self) -> &Alphas<Fp> {
+        &self.alphas
+    }
+
+    fn get_blinder(&self) -> Fp {
+        self.blinder
+    }
+}
+
+/// Our witness is going to be the polynomials that we will commit to.
+/// Vec<Fp> will be the evaluations of each x_1, x_2 and x_3 over the domain.
+/// This witness includes not only the 3 normal witness columns, but also the
+/// 2 dynamic selector columns that are essentially witness
+#[derive(Clone)]
+pub struct TestWitness([Evaluations<Fp, Radix2EvaluationDomain<Fp>>; 5]);
+
+impl Foldable<Fp> for TestWitness {
+    fn combine(mut a: Self, b: Self, challenge: Fp) -> Self {
+        for (a, b) in a.0.iter_mut().zip(b.0) {
+            for (a, b) in a.evals.iter_mut().zip(b.evals) {
+                *a += challenge * b;
+            }
+        }
+        a
+    }
+}
+
+impl Witness<Curve> for TestWitness {}
+
+// our environment: the way in which we provide access to the actual values in the
+// witness and instances. When folding evaluates expressions and reaches leaves
+// (Atom), it will call methods from here to resolve the types we have in the
+// config, like the columns, into the actual values.
+pub struct TestFoldingEnv {
+    instances: [TestInstance; 2],
+    // Corresponds to the omega evaluations, for both sides
+    curr_witnesses: [TestWitness; 2],
+    // Corresponds to the zeta*omega evaluations, for both sides
+    // This is curr_witness but left shifted by 1
+    next_witnesses: [TestWitness; 2],
+}
+
+// implementing an environment trait compatible with our config
+impl FoldingEnv<Fp, TestInstance, TestWitness, TestColumn, TestChallenge, DynamicSelector>
+    for TestFoldingEnv
+{
+    type Structure = ();
+
+    fn new(
+        _structure: &Self::Structure,
+        instances: [&TestInstance; 2],
+        witnesses: [&TestWitness; 2],
+    ) -> Self {
+        // Here it is mostly storing the pairs into self, and also computing other
+        // things we may need later, like the shifted versions. Note there are more
+        // efficient ways of handling the rotated witnesses, which are just for
+        // example, as no constraint uses them anyway.
+        let curr_witnesses = [witnesses[0].clone(), witnesses[1].clone()];
+        let mut next_witnesses = curr_witnesses.clone();
+        for side in next_witnesses.iter_mut() {
+            for col in side.0.iter_mut() {
+                // TODO: check this; while not relevant for this example, I think
+                // it should be right rotation
+                col.evals.rotate_left(1);
+            }
+        }
+        TestFoldingEnv {
+            instances: [instances[0].clone(), instances[1].clone()],
+            curr_witnesses,
+            next_witnesses,
+        }
+    }
+
+    // provide access to columns; here side refers to one of the two pairs you
+    // got in new()
+    fn col(&self, col: TestColumn, curr_or_next: CurrOrNext, side: Side) -> &[Fp] {
+        let wit = match curr_or_next {
+            CurrOrNext::Curr => &self.curr_witnesses[side as usize],
+            CurrOrNext::Next => &self.next_witnesses[side as usize],
+        };
+        match col {
+            TestColumn::A => &wit.0[0].evals,
+            TestColumn::B => &wit.0[1].evals,
+            TestColumn::C => &wit.0[2].evals,
+        }
+    }
+
+    // same as col but for challenges; challenges are not constants
+    fn challenge(&self, challenge: TestChallenge, side: Side) -> Fp {
+        match challenge {
+            TestChallenge::Beta => self.instances[side as usize].challenges[0],
+            TestChallenge::Gamma => self.instances[side as usize].challenges[1],
+            TestChallenge::JointCombiner => self.instances[side as usize].challenges[2],
+        }
+    }
+
+    // This is exclusively for dynamic selectors aiming to make use of optimizations,
+    // as classic static selectors will be handled as normal structure columns in col().
+    // The implementation of this is the same as col(); it is just separated as they
+    // have different types to resolve.
+    fn selector(&self, s: &DynamicSelector, side: Side) -> &[Fp] {
+        let wit = &self.curr_witnesses[side as usize];
+        match s {
+            DynamicSelector::SelecAdd => &wit.0[3].evals,
+            DynamicSelector::SelecSub => &wit.0[4].evals,
+        }
+    }
+}
+
+// this creates 2 single-constraint gates, each with a selector:
+// an addition gate, and a subtraction gate
+fn constraints() -> BTreeMap<DynamicSelector, Vec<FoldingCompatibleExpr<TestFoldingConfig>>> {
+    let get_col = |col| {
+        FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable {
+            col,
+            row: CurrOrNext::Curr,
+        }))
+    };
+    let a = Box::new(get_col(TestColumn::A));
+    let b = Box::new(get_col(TestColumn::B));
+    let c = Box::new(get_col(TestColumn::C));
+
+    let add = FoldingCompatibleExpr::Add(a.clone(), b.clone());
+    let add = FoldingCompatibleExpr::Sub(add.into(), c.clone());
+    let add = FoldingCompatibleExpr::Sub(
+        add.into(),
+        Box::new(FoldingCompatibleExpr::Atom(
+            FoldingCompatibleExprInner::Constant(Fp::one()),
+        )),
+    );
+    // a + b - c - 1 = 0
+
+    let sub = FoldingCompatibleExpr::Sub(a.clone(), b.clone());
+    let sub = FoldingCompatibleExpr::Sub(sub.into(), c.clone());
+
+    [
+        (DynamicSelector::SelecAdd, vec![add]),
+        (DynamicSelector::SelecSub, vec![sub]),
+    ]
+    .into_iter()
+    .collect()
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub struct TestFoldingConfig;
+
+// Does not contain alpha because it should be added to the expressions by folding
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+pub enum TestChallenge {
+    Beta,
+    Gamma,
+    JointCombiner,
+}
+
+impl FoldingConfig for TestFoldingConfig {
+    type Structure = ();
+    type Column = TestColumn;
+    type Selector = DynamicSelector;
+    type Challenge = TestChallenge;
+    type Curve = Curve;
+    type Srs = SRS<Curve>;
+    type Instance = TestInstance;
+    type Witness = TestWitness;
+    type Env = TestFoldingEnv;
+}
+
+// creates an instance from its witness
+fn instance_from_witness(
+    witness: &TestWitness,
+    srs: &<TestFoldingConfig as FoldingConfig>::Srs,
+    domain: Radix2EvaluationDomain<Fp>,
+) -> TestInstance {
+    let commitments = witness
+        .0
+        .iter()
+        .map(|w| srs.commit_evaluations_non_hiding(domain, w))
+        .map(|c| c.get_first_chunk())
+        .collect_vec();
+    let commitments: [_; 5] = commitments.try_into().unwrap();
+
+    // here we should absorb the commitments and similar things to later compute
+    // challenges, but for this example I just use random values
+    let mut rng = thread_rng();
+    let mut challenge = || Fp::rand(&mut rng);
+    let challenges = [(); 3].map(|_| challenge());
+    let alpha = challenge();
+    let alphas = Alphas::new(alpha);
+    let blinder = Fp::one();
+    TestInstance {
+        commitments,
+        challenges,
+        alphas,
+        blinder,
+    }
+}
+
+impl Checker<TestFoldingConfig> for ExtendedProvider<TestFoldingConfig> {}
+
+impl Index<TestChallenge> for TestInstance {
+    type Output = Fp;
+
+    fn index(&self, index: TestChallenge) -> &Self::Output {
+        match index {
+            TestChallenge::Beta => &self.challenges[0],
+            TestChallenge::Gamma => &self.challenges[1],
+            TestChallenge::JointCombiner => &self.challenges[2],
+        }
+    }
+}
+
+impl Index<TestColumn> for TestWitness {
+    type Output = Evaluations<Fp, Radix2EvaluationDomain<Fp>>;
+
+    fn index(&self, index: TestColumn) -> &Self::Output {
+        match index {
+            TestColumn::A => &self.0[0],
+            TestColumn::B => &self.0[1],
+            TestColumn::C => &self.0[2],
+        }
+    }
+}
+
+impl Index<DynamicSelector> for TestWitness {
+    type Output = Evaluations<Fp, Radix2EvaluationDomain<Fp>>;
+
+    fn index(&self, index: DynamicSelector) -> &Self::Output {
+        match index {
+            DynamicSelector::SelecAdd => &self.0[3],
+            DynamicSelector::SelecSub => &self.0[4],
+        }
+    }
+}
+
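+// Column layout reminder for the helpers below: [a, b, c, q_add, q_sub].
+// E.g. with a = [4, 2] and b = [2, 1] (inputs1 in the test), the add gate
+// needs c = [4 + 2 - 1, 2 + 1 - 1] = [5, 2] so that a + b - c - 1 = 0 holds
+// on every row, with selectors q_add = [1, 1] and q_sub = [0, 0].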
+// two functions to create the entire witness from just the a and b columns
+fn add_witness(a: [u32; 2], b: [u32; 2]) -> [[u32; 2]; 5] {
+    let [a1, a2] = a;
+    let [b1, b2] = b;
+    let c = [a1 + b1 - 1, a2 + b2 - 1];
+    [a, b, c, [1, 1], [0, 0]]
+}
+
+fn sub_witness(a: [u32; 2], b: [u32; 2]) -> [[u32; 2]; 5] {
+    let [a1, a2] = a;
+    let [b1, b2] = b;
+    let c = [a1 - b1, a2 - b2];
+    [a, b, c, [0, 0], [1, 1]]
+}
+fn int_to_witness(x: [[u32; 2]; 5], domain: Radix2EvaluationDomain<Fp>) -> TestWitness {
+    TestWitness(x.map(|row| Evaluations::from_vec_and_domain(row.map(Fp::from).to_vec(), domain)))
+}
+
+// In this test we will create 2 add witnesses and fold them together, create 2
+// sub witnesses and fold them together, and then further fold the 2 resulting
+// pairs into one mixed add-sub witness.
+// Instances are also folded, but they are not that relevant in the examples,
+// as we don't make a proof for them and instead directly check the witness.
+#[test]
+fn test_decomposable_folding() {
+    let constraints = constraints();
+    let domain = D::<Fp>::new(2).unwrap();
+    let srs = SRS::<Curve>::create(2);
+    srs.get_lagrange_basis(domain);
+
+    let mut fq_sponge = BaseSponge::new(Curve::other_curve_sponge_params());
+
+    // initialize the scheme, also getting the final single expression for
+    // the entire constraint system
+    let (scheme, final_constraint) = DecomposableFoldingScheme::<TestFoldingConfig>::new(
+        constraints.clone(),
+        vec![],
+        &srs,
+        domain,
+        &(),
+    );
+
+    // some inputs to be used by both add and sub
+    let inputs1 = [[4u32, 2u32], [2u32, 1u32]];
+    let inputs2 = [[5u32, 6u32], [4u32, 3u32]];
+
+    // creates an instance-witness pair
+    let make_pair = |wit: TestWitness| {
+        let ins = instance_from_witness(&wit, &srs, domain);
+        (wit, ins)
+    };
+
+    // fold adds
+    debug!("fold add");
+    let left = {
+        let [a, b] = inputs1;
+        let wit1 = add_witness(a, b);
+        let (witness1, instance1) = make_pair(int_to_witness(wit1, domain));
+
+        let [a, b] = inputs2;
+        let wit2 = add_witness(a, b);
+        let (witness2, instance2) = make_pair(int_to_witness(wit2, domain));
+
+        let left = (instance1, witness1);
+        let right = (instance2, witness2);
+        // here we provide normal instance-witness pairs, which will be
+        // automatically relaxed
+        let folded = scheme.fold_instance_witness_pair(
+            left,
+            right,
+            Some(DynamicSelector::SelecAdd),
+            &mut fq_sponge,
+        );
+        let FoldingOutput {
+            folded_instance,
+            folded_witness,
+            t_0: _,
+            t_1: _,
+            relaxed_extended_left_instance: _,
+            relaxed_extended_right_instance: _,
+            to_absorb: _,
+        } = folded;
+        let checker = ExtendedProvider::new(folded_instance, folded_witness);
+        debug!("exp: \n {:#?}", final_constraint.to_string());
+        checker.check(&final_constraint, domain);
+        let ExtendedProvider {
+            instance, witness, ..
+        } = checker;
+        (instance, witness)
+    };
+    // fold subs
+    debug!("fold subs");
+    let right = {
+        let [a, b] = inputs1;
+        let wit1 = sub_witness(a, b);
+        let (witness1, instance1) = make_pair(int_to_witness(wit1, domain));
+
+        let [a, b] = inputs2;
+        let wit2 = sub_witness(a, b);
+        let (witness2, instance2) = make_pair(int_to_witness(wit2, domain));
+
+        let left = (instance1, witness1);
+        let right = (instance2, witness2);
+        let folded = scheme.fold_instance_witness_pair(
+            left,
+            right,
+            Some(DynamicSelector::SelecSub),
+            &mut fq_sponge,
+        );
+        let FoldingOutput {
+            folded_instance,
+            folded_witness,
+            t_0: _,
+            t_1: _,
+            relaxed_extended_left_instance: _,
+            relaxed_extended_right_instance: _,
+            to_absorb: _,
+        } = folded;
+
+        let checker = ExtendedProvider::new(folded_instance, folded_witness);
+        debug!("exp: \n {:#?}", final_constraint.to_string());
+
+        checker.check(&final_constraint, domain);
+        let ExtendedProvider {
+            instance, witness, ..
+        } = checker;
+        (instance, witness)
+    };
+    // fold mixed
+    debug!("fold mixed");
+    {
+        // here we use already-relaxed pairs, which have a trivial x -> x implementation
+        let folded = scheme.fold_instance_witness_pair(left, right, None, &mut fq_sponge);
+        let FoldingOutput {
+            folded_instance,
+            folded_witness,
+            t_0,
+            t_1,
+            relaxed_extended_left_instance: _,
+            relaxed_extended_right_instance: _,
+            to_absorb: _,
+        } = folded;
+
+        // Verifying that the error terms are not points at infinity.
+        // It doesn't test that the computation happens correctly, but it at
+        // least shows that there is some non-trivial computation.
+        assert_eq!(t_0.len(), 1);
+        assert_eq!(t_1.len(), 1);
+        assert!(!t_0.get_first_chunk().is_zero());
+        assert!(!t_1.get_first_chunk().is_zero());
+
+        let checker = ExtendedProvider::new(folded_instance, folded_witness);
+        debug!("exp: \n {:#?}", final_constraint.to_string());
+
+        checker.check(&final_constraint, domain);
+    };
+}
diff --git a/folding/tests/test_folding_with_quadriticization.rs b/folding/tests/test_folding_with_quadriticization.rs
new file mode 100644
index 0000000000..2d23f76510
--- /dev/null
+++ b/folding/tests/test_folding_with_quadriticization.rs
@@ -0,0 +1,513 @@
+// this example is a copy of the decomposable folding one, but with a degree-3 gate
+// that triggers quadraticization
+use ark_ec::AffineRepr;
+use ark_ff::{One, UniformRand};
+use ark_poly::{Evaluations, Radix2EvaluationDomain};
+use folding::{
+    checker::{Checker, ExtendedProvider},
+    expressions::{FoldingColumnTrait, FoldingCompatibleExprInner},
+    instance_witness::Foldable,
+    Alphas, FoldingCompatibleExpr, FoldingConfig, FoldingEnv, Instance, Side, Witness,
+};
+use itertools::Itertools;
+use kimchi::circuits::{expr::Variable, gate::CurrOrNext};
+use mina_poseidon::{constants::PlonkSpongeConstantsKimchi, sponge::DefaultFqSponge};
+use poly_commitment::{ipa::SRS, SRS as _};
+use rand::thread_rng;
+use std::{collections::BTreeMap, ops::Index};
+
+use ark_poly::{EvaluationDomain, Radix2EvaluationDomain as D};
+use folding::{decomposable_folding::DecomposableFoldingScheme, FoldingOutput};
+use kimchi::curve::KimchiCurve;
+use mina_poseidon::FqSponge;
+use std::println as debug;
+
+type Fp = ark_bn254::Fr;
+type Curve = ark_bn254::G1Affine;
+type SpongeParams = PlonkSpongeConstantsKimchi;
+type BaseSponge = DefaultFqSponge<ark_bn254::g1::Parameters, SpongeParams>;
+
+// the type representing our columns; in this case we have 3 witness columns
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
+pub enum TestColumn {
+    A,
+    B,
+    C,
+}
+
+// the type for the dynamic selectors, which are essentially
witness columns, but +// get special treatment to enable optimizations +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)] +pub enum DynamicSelector { + SelecAdd, + SelecMul, +} + +impl FoldingColumnTrait for TestColumn { + //in this case we have only witness, the other example shows non-witness columns + fn is_witness(&self) -> bool { + match self { + TestColumn::A | TestColumn::B | TestColumn::C => true, + } + } +} + +/// The instance is the commitments to the polynomials and the challenges +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct TestInstance { + // 3 from the normal witness + 2 from the dynamic selectors + commitments: [Curve; 5], + // for ilustration only, no constraint in this example uses challenges + challenges: [Fp; 3], + // also challenges, but segregated as folding gives them special treatment + alphas: Alphas, + // To blind the commitments, + blinder: Fp, +} + +impl Foldable for TestInstance { + fn combine(a: Self, b: Self, challenge: Fp) -> Self { + TestInstance { + commitments: std::array::from_fn(|i| { + (a.commitments[i] + b.commitments[i] * challenge).into() + }), + challenges: std::array::from_fn(|i| a.challenges[i] + challenge * b.challenges[i]), + alphas: Alphas::combine(a.alphas, b.alphas, challenge), + blinder: a.blinder + challenge * b.blinder, + } + } +} + +impl Instance for TestInstance { + fn to_absorb(&self) -> (Vec, Vec) { + // FIXME? + (vec![], vec![]) + } + + fn get_alphas(&self) -> &Alphas { + &self.alphas + } + + fn get_blinder(&self) -> Fp { + self.blinder + } +} + +// our environment, the way in which we provide access to the actual values in the +// witness and instances, when folding evaluates expressions and reaches leaves (Atom) +// it will call methods from here to resolve the types we have in the config like the +// columns into the actual values. +pub struct TestFoldingEnv { + instances: [TestInstance; 2], + // Corresponds to the omega evaluations, for both sides + curr_witnesses: [TestWitness; 2], + // Corresponds to the zeta*omega evaluations, for both sides + // This is curr_witness but left shifted by 1 + next_witnesses: [TestWitness; 2], +} + +/// Our witness is going to be the polynomials that we will commit too. +/// Vec will be the evaluations of each x_1, x_2 and x_3 over the domain. 
+/// This witness includes not only the 3 normal witness columns, but also the +/// 2 dynamic selector columns that are esentially witness +#[derive(Clone)] +pub struct TestWitness([Evaluations>; 5]); + +impl Foldable for TestWitness { + fn combine(mut a: Self, b: Self, challenge: Fp) -> Self { + for (a, b) in a.0.iter_mut().zip(b.0) { + for (a, b) in a.evals.iter_mut().zip(b.evals) { + *a += challenge * b; + } + } + a + } +} + +impl Witness for TestWitness {} + +// implementing the an environment trait compatible with our config +impl FoldingEnv + for TestFoldingEnv +{ + type Structure = (); + + fn new( + _structure: &Self::Structure, + instances: [&TestInstance; 2], + witnesses: [&TestWitness; 2], + ) -> Self { + // here it is mostly storing the pairs into self, and also computing + // other things we may need later like the shifted versions, note there + // are more efficient ways of handling the rotated witnesses, which are + // just for example as no contraint uses them anyway + let curr_witnesses = [witnesses[0].clone(), witnesses[1].clone()]; + let mut next_witnesses = curr_witnesses.clone(); + for side in next_witnesses.iter_mut() { + for col in side.0.iter_mut() { + // TODO: check this, while not relevant for this example I think + // it should be right rotation + col.evals.rotate_left(1); + } + } + TestFoldingEnv { + instances: [instances[0].clone(), instances[1].clone()], + curr_witnesses, + next_witnesses, + } + } + + // provide access to columns, here side refers to one of the two pairs you + // got in new() + fn col(&self, col: TestColumn, curr_or_next: CurrOrNext, side: Side) -> &[Fp] { + let wit = match curr_or_next { + CurrOrNext::Curr => &self.curr_witnesses[side as usize], + CurrOrNext::Next => &self.next_witnesses[side as usize], + }; + match col { + TestColumn::A => &wit.0[0].evals, + TestColumn::B => &wit.0[1].evals, + TestColumn::C => &wit.0[2].evals, + } + } + + // same as column but for challenges, challenges are not constants + fn challenge(&self, challenge: TestChallenge, side: Side) -> Fp { + match challenge { + TestChallenge::Beta => self.instances[side as usize].challenges[0], + TestChallenge::Gamma => self.instances[side as usize].challenges[1], + TestChallenge::JointCombiner => self.instances[side as usize].challenges[2], + } + } + + // this is exclusively for dynamic selectors aiming to make use of optimization + // as clasic static selectors will be handle as normal structure columns in col() + // the implementation of this if the same as col(), it is just separated as they + // have different types to resolve + fn selector(&self, s: &DynamicSelector, side: Side) -> &[Fp] { + let wit = &self.curr_witnesses[side as usize]; + match s { + DynamicSelector::SelecAdd => &wit.0[3].evals, + DynamicSelector::SelecMul => &wit.0[4].evals, + } + } +} + +// this creates 2 single-constraint gates, each with a selector, +// an addition gate, and a multiplication gate +fn constraints() -> BTreeMap>> { + let get_col = |col| { + FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable { + col, + row: CurrOrNext::Curr, + })) + }; + let a = Box::new(get_col(TestColumn::A)); + let b = Box::new(get_col(TestColumn::B)); + let c = Box::new(get_col(TestColumn::C)); + + // Compute a + b - c + let add = FoldingCompatibleExpr::Add(a.clone(), b.clone()); + let add = FoldingCompatibleExpr::Sub(Box::new(add), c.clone()); + + // Compute a * b - c + let mul = FoldingCompatibleExpr::Mul(a.clone(), b.clone()); + let mul = FoldingCompatibleExpr::Sub(Box::new(mul), c.clone()); + 
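+    // Note: folding multiplies each constraint by its (dynamic) selector, so
+    // q_mul * (a * b - c) is degree 3 in witness columns; this is the gate
+    // that triggers quadraticization (see the file-level comment above).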
+ // Compute q_add (a + b - c) + q_mul (a * b - c) + [ + (DynamicSelector::SelecAdd, vec![add]), + (DynamicSelector::SelecMul, vec![mul]), + ] + .into_iter() + .collect() +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct TestFoldingConfig; + +// Flag used as the challenges are never built. +// FIXME: should we use unit? +// Does not contain alpha because it should be added to the expressions by folding +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum TestChallenge { + Beta, + Gamma, + JointCombiner, +} + +impl FoldingConfig for TestFoldingConfig { + type Structure = (); + type Column = TestColumn; + type Selector = DynamicSelector; + type Challenge = TestChallenge; + type Curve = Curve; + type Srs = SRS; + type Instance = TestInstance; + type Witness = TestWitness; + type Env = TestFoldingEnv; +} + +//creates an instance from its witness +fn instance_from_witness( + witness: &TestWitness, + srs: &::Srs, + domain: Radix2EvaluationDomain, +) -> TestInstance { + let commitments = witness + .0 + .iter() + .map(|w| srs.commit_evaluations_non_hiding(domain, w)) + .map(|c| c.get_first_chunk()) + .collect_vec(); + let commitments: [_; 5] = commitments.try_into().unwrap(); + + // here we should absorve the commitments and similar things to later compute challenges + // but for this example I just use random values + let mut rng = thread_rng(); + let mut challenge = || Fp::rand(&mut rng); + let challenges = [(); 3].map(|_| challenge()); + let alpha = challenge(); + let alphas = Alphas::new(alpha); + let blinder = Fp::one(); + TestInstance { + commitments, + challenges, + alphas, + blinder, + } +} + +impl Checker for ExtendedProvider {} + +impl Index for TestInstance { + type Output = Fp; + + fn index(&self, index: TestChallenge) -> &Self::Output { + match index { + TestChallenge::Beta => &self.challenges[0], + TestChallenge::Gamma => &self.challenges[1], + TestChallenge::JointCombiner => &self.challenges[2], + } + } +} + +impl Index for TestWitness { + type Output = Evaluations>; + + fn index(&self, index: TestColumn) -> &Self::Output { + match index { + TestColumn::A => &self.0[0], + TestColumn::B => &self.0[1], + TestColumn::C => &self.0[2], + } + } +} + +impl Index for TestWitness { + type Output = Evaluations>; + + fn index(&self, index: DynamicSelector) -> &Self::Output { + match index { + DynamicSelector::SelecAdd => &self.0[3], + DynamicSelector::SelecMul => &self.0[4], + } + } +} + +// Trick to print debug message while testing, as we in the test config env +// two functions to create the entire witness from just the a and b columns +fn add_witness(a: [u32; 2], b: [u32; 2]) -> [[u32; 2]; 5] { + let [a1, a2] = a; + let [b1, b2] = b; + let c = [a1 + b1, a2 + b2]; + [a, b, c, [1, 1], [0, 0]] +} +fn mul_witness(a: [u32; 2], b: [u32; 2]) -> [[u32; 2]; 5] { + let [a1, a2] = a; + let [b1, b2] = b; + let c = [a1 * b1, a2 * b2]; + [a, b, c, [0, 0], [1, 1]] +} +fn int_to_witness(x: [[u32; 2]; 5], domain: Radix2EvaluationDomain) -> TestWitness { + TestWitness(x.map(|row| Evaluations::from_vec_and_domain(row.map(Fp::from).to_vec(), domain))) +} + +// in this test we will create 2 add witnesses, fold them together, create 2 +// mul witnesses, fold them together, and then further fold the 2 resulting +// pairs into one mixed add-mul witness +// instances are also folded, but not that relevant in the examples as we +// don't make a proof for them and instead directly check the witness +#[test] +fn test_quadriticization() { + let constraints = constraints(); + let domain = 
D::::new(2).unwrap(); + let srs = SRS::::create(2); + srs.get_lagrange_basis(domain); + + let mut fq_sponge = BaseSponge::new(Curve::other_curve_sponge_params()); + + // initiallize the scheme, also getting the final single expression for + // the entire constraint system + let (scheme, final_constraint) = DecomposableFoldingScheme::::new( + constraints.clone(), + vec![], + &srs, + domain, + &(), + ); + + // some inputs to be used by both add and mul + let inputs1 = [[4u32, 2u32], [2u32, 1u32]]; + let inputs2 = [[5u32, 6u32], [4u32, 3u32]]; + + // creates an instance witness pair + let make_pair = |wit: TestWitness| { + let ins = instance_from_witness(&wit, &srs, domain); + (wit, ins) + }; + + debug!("exp: \n {:#?}", final_constraint.to_string()); + + // fold adds + debug!("fold add"); + + let left = { + let [a, b] = inputs1; + let wit1 = add_witness(a, b); + let (witness1, instance1) = make_pair(int_to_witness(wit1, domain)); + + let [a, b] = inputs2; + let wit2 = add_witness(a, b); + let (witness2, instance2) = make_pair(int_to_witness(wit2, domain)); + + let left = (instance1, witness1); + let right = (instance2, witness2); + // here we provide normal instance-witness pairs, which will be + // automatically relaxed + let folded = scheme.fold_instance_witness_pair( + left, + right, + Some(DynamicSelector::SelecAdd), + &mut fq_sponge, + ); + let FoldingOutput { + folded_instance, + folded_witness, + .. + } = folded; + + { + let folded_instance_explicit = { + let mut fq_sponge_inst = BaseSponge::new(Curve::other_curve_sponge_params()); + scheme.fold_instance_pair( + folded.relaxed_extended_left_instance, + folded.relaxed_extended_right_instance, + [folded.t_0.clone(), folded.t_1.clone()], + &mut fq_sponge_inst, + ) + }; + + assert!(folded_instance == folded_instance_explicit); + } + + let checker = ExtendedProvider::new(folded_instance, folded_witness); + checker.check(&final_constraint, domain); + let ExtendedProvider { + instance, witness, .. + } = checker; + (instance, witness) + }; + + debug!("fold muls"); + + let right = { + let [a, b] = inputs1; + let wit1 = mul_witness(a, b); + let (witness1, instance1) = make_pair(int_to_witness(wit1, domain)); + + let [a, b] = inputs2; + let wit2 = mul_witness(a, b); + let (witness2, instance2) = make_pair(int_to_witness(wit2, domain)); + + let mut fq_sponge_before_fold = fq_sponge.clone(); + + let left = (instance1, witness1); + let right = (instance2, witness2); + let folded = scheme.fold_instance_witness_pair( + left, + right, + Some(DynamicSelector::SelecMul), + &mut fq_sponge, + ); + let FoldingOutput { + folded_instance, + folded_witness, + .. + } = folded; + + { + let folded_instance_explicit = { + scheme.fold_instance_pair( + folded.relaxed_extended_left_instance, + folded.relaxed_extended_right_instance, + [folded.t_0.clone(), folded.t_1.clone()], + &mut fq_sponge_before_fold, + ) + }; + + assert!(folded_instance == folded_instance_explicit); + } + + let checker = ExtendedProvider::new(folded_instance, folded_witness); + + checker.check(&final_constraint, domain); + let ExtendedProvider { + instance, witness, .. 
+ } = checker; + (instance, witness) + }; + + debug!("fold mixed"); + + { + let mut fq_sponge_before_fold = fq_sponge.clone(); + + // here we use already relaxed pairs, which have a trival x -> x implementation + let folded = scheme.fold_instance_witness_pair(left, right, None, &mut fq_sponge); + let FoldingOutput { + folded_instance, + folded_witness, + t_0, + t_1, + relaxed_extended_left_instance, + relaxed_extended_right_instance, + to_absorb: _, + } = folded; + + { + let folded_instance_explicit = { + scheme.fold_instance_pair( + relaxed_extended_left_instance, + relaxed_extended_right_instance, + [t_0.clone(), t_1.clone()], + &mut fq_sponge_before_fold, + ) + }; + + assert!(folded_instance == folded_instance_explicit); + } + + // Verifying that error terms are not points at infinity + // It doesn't test that the computation happens correctly, but at least + // show that there is some non trivial computation. + assert_eq!(t_0.len(), 1); + assert_eq!(t_1.len(), 1); + assert!(!t_0.get_first_chunk().is_zero()); + assert!(!t_1.get_first_chunk().is_zero()); + + let checker = ExtendedProvider::new(folded_instance, folded_witness); + + checker.check(&final_constraint, domain); + }; +} diff --git a/folding/tests/test_quadraticization.rs b/folding/tests/test_quadraticization.rs new file mode 100644 index 0000000000..84f0df360e --- /dev/null +++ b/folding/tests/test_quadraticization.rs @@ -0,0 +1,318 @@ +//! quadraticization test, check that diferent cases result in the expected number of columns +//! being added +use ark_poly::Radix2EvaluationDomain; +use folding::{ + expressions::{FoldingColumnTrait, FoldingCompatibleExprInner}, + instance_witness::Foldable, + Alphas, FoldingCompatibleExpr, FoldingConfig, FoldingEnv, FoldingScheme, Instance, Side, + Witness, +}; +use kimchi::circuits::{expr::Variable, gate::CurrOrNext}; +use poly_commitment::{ipa::SRS, SRS as _}; +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, +}; + +pub type Fp = ark_bn254::Fr; +pub type Curve = ark_bn254::G1Affine; + +#[derive(Hash, Clone, Debug, PartialEq, Eq)] +struct TestConfig; + +#[derive(Debug, Clone)] +struct TestInstance; + +impl Foldable for TestInstance { + fn combine(_a: Self, _b: Self, _challenge: Fp) -> Self { + unimplemented!() + } +} + +impl Instance for TestInstance { + fn get_alphas(&self) -> &Alphas { + todo!() + } + + fn to_absorb(&self) -> (Vec, Vec) { + todo!() + } + + fn get_blinder(&self) -> Fp { + todo!() + } +} + +#[derive(Clone)] +struct TestWitness; + +impl Foldable for TestWitness { + fn combine(_a: Self, _b: Self, _challenge: Fp) -> Self { + unimplemented!() + } +} + +impl Witness for TestWitness {} + +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +enum Col { + A, + B, +} + +impl FoldingColumnTrait for Col { + fn is_witness(&self) -> bool { + true + } +} + +impl FoldingConfig for TestConfig { + type Column = Col; + + type Selector = (); + + type Challenge = (); + + type Curve = Curve; + + type Srs = SRS; + + type Instance = TestInstance; + + type Witness = TestWitness; + + type Structure = (); + + type Env = Env; +} + +struct Env; + +impl FoldingEnv for Env { + type Structure = (); + + fn new( + _structure: &Self::Structure, + _instances: [&TestInstance; 2], + _witnesses: [&TestWitness; 2], + ) -> Self { + todo!() + } + + fn col(&self, _col: Col, _curr_or_next: CurrOrNext, _side: Side) -> &[Fp] { + todo!() + } + + fn challenge(&self, _challenge: (), _side: Side) -> Fp { + todo!() + } + + fn selector(&self, _s: &(), _side: Side) -> &[Fp] { + todo!() + } +} + +fn 
degree_1_constraint(col: Col) -> FoldingCompatibleExpr { + let col = Variable { + col, + row: CurrOrNext::Curr, + }; + FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(col)) +} + +fn degree_n_constraint(n: usize, col: Col) -> FoldingCompatibleExpr { + match n { + 0 => { + panic!("Degree 0 is not supposed to be used by the test suite.") + } + // base case + 1 => degree_1_constraint(col), + _ => { + let one = degree_1_constraint(col); + let n_minus_one = degree_n_constraint(n - 1, col); + FoldingCompatibleExpr::Mul(Box::new(one), Box::new(n_minus_one)) + } + } +} +// create 2 constraints of degree a and b +fn constraints(a: usize, b: usize) -> Vec> { + vec![ + degree_n_constraint(a, Col::A), + degree_n_constraint(b, Col::B), + ] +} +// creates a scheme with the constraints and returns the number of columns added +fn test_with_constraints(constraints: Vec>) -> usize { + use ark_poly::EvaluationDomain; + + let domain = Radix2EvaluationDomain::::new(2).unwrap(); + let srs = poly_commitment::ipa::SRS::::create(2); + srs.get_lagrange_basis(domain); + + let (scheme, _) = FoldingScheme::::new(constraints, &srs, domain, &()); + // println!("exp:\n {}", exp.to_string()); + scheme.get_number_of_additional_columns() +} + +// 1 constraint of degree 1 +#[test] +fn quadraticization_test_1() { + let mut constraints = constraints(1, 1); + constraints.truncate(1); + assert_eq!(test_with_constraints(constraints), 0); +} + +// 1 constraint of degree 2 +#[test] +fn quadraticization_test_2() { + let mut constraints = constraints(2, 1); + constraints.truncate(1); + assert_eq!(test_with_constraints(constraints), 0); +} + +// 1 constraint of degree 3 +#[test] +fn quadraticization_test_3() { + let mut constraints = constraints(3, 1); + constraints.truncate(1); + assert_eq!(test_with_constraints(constraints), 1); +} + +// 1 constraint of degree 4 to 8 (as we usually support up to 8). +#[test] +fn quadraticization_test_4() { + let cols = [2, 3, 4, 5, 6]; + for i in 4..=8 { + let mut constraints = constraints(i, 1); + constraints.truncate(1); + assert_eq!(test_with_constraints(constraints), cols[i - 4]); + } +} + +// 2 constraints of degree 1 +#[test] +fn quadraticization_test_5() { + let constraints = constraints(1, 1); + assert_eq!(test_with_constraints(constraints), 0); +} + +// 2 constraints, one of degree 1 and one of degree 2 +#[test] +fn quadraticization_test_6() { + let constraints = constraints(1, 2); + assert_eq!(test_with_constraints(constraints), 0); +} + +// 2 constraints: one of degree 1 and one of degree 3 +#[test] +fn quadraticization_test_7() { + let constraints = constraints(1, 3); + assert_eq!(test_with_constraints(constraints), 1); +} + +// 2 constraints, each with degree higher than 2. 
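+// (cf. quadraticization_test_4: a constraint of degree d needs d - 2 extra
+// columns, so here the degree-4 and degree-3 constraints add 2 + 1 = 3.)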
+#[test] +fn quadraticization_test_8() { + let constraints = constraints(4, 3); + assert_eq!(test_with_constraints(constraints), 3); +} + +// shared subexpression +#[test] +fn quadraticization_test_9() { + // here I duplicate the first constraint + let mut constraints = constraints(3, 1); + constraints.truncate(1); + constraints.push(constraints[0].clone()); + assert_eq!(test_with_constraints(constraints), 1); +} + +#[test] +fn test_equality_folding_compatible_expressions() { + let x: FoldingCompatibleExpr = + FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: Col::A, + row: CurrOrNext::Curr, + })); + let y = FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: Col::A, + row: CurrOrNext::Curr, + })); + + let z = FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: Col::B, + row: CurrOrNext::Curr, + })); + + assert_eq!(x, y); + assert_eq!(x, x); + assert_ne!(x, z); +} + +#[test] +fn test_discriminant_folding_expressions() { + let x: FoldingCompatibleExpr = + FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: Col::A, + row: CurrOrNext::Curr, + })); + let y: FoldingCompatibleExpr = + FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: Col::A, + row: CurrOrNext::Curr, + })); + + let z: FoldingCompatibleExpr = + FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: Col::B, + row: CurrOrNext::Curr, + })); + + let x = x.simplify(); + let y = y.simplify(); + let z = z.simplify(); + + let x_hasher = { + let mut x_hasher = DefaultHasher::new(); + x.hash(&mut x_hasher); + x_hasher.finish() + }; + + // Checking hashing the same element twice give the same result + { + let x_prime_hasher = { + let mut x_hasher = DefaultHasher::new(); + x.hash(&mut x_hasher); + x_hasher.finish() + }; + assert_eq!(x_hasher, x_prime_hasher); + } + + let y_hasher = { + let mut y_hasher = DefaultHasher::new(); + y.hash(&mut y_hasher); + y_hasher.finish() + }; + + let z_hasher = { + let mut z_hasher = DefaultHasher::new(); + z.hash(&mut z_hasher); + z_hasher.finish() + }; + + { + let (y_hasher, z_hasher) = { + let mut hasher = DefaultHasher::new(); + y.hash(&mut hasher); + let y_hasher = hasher.finish(); + + z.hash(&mut hasher); + (y_hasher, hasher.finish()) + }; + assert_ne!(y_hasher, z_hasher); + } + + assert_eq!(x_hasher, x_hasher); + assert_eq!(x_hasher, y_hasher); + assert_ne!(x_hasher, z_hasher); +} diff --git a/folding/tests/test_vanilla_folding.rs b/folding/tests/test_vanilla_folding.rs new file mode 100644 index 0000000000..5a137bafb1 --- /dev/null +++ b/folding/tests/test_vanilla_folding.rs @@ -0,0 +1,510 @@ +/// This file shows how to use the folding trait with a simple configuration of +/// 3 columns, and two selectors. See [tests::test_folding_instance] at the end +/// for a test. 
+/// The test requires the features `bn254`, therefore use the following command +/// to execute it: +/// ```text +/// cargo nextest run test_folding_instance --release --all-features +/// ``` +use ark_ec::AffineRepr; +use ark_ff::{One, UniformRand, Zero}; +use ark_poly::{EvaluationDomain, Evaluations, Radix2EvaluationDomain}; +use checker::{ExtendedProvider, Provider}; +use folding::{ + checker::{Checker, Column, Provide}, + expressions::FoldingCompatibleExprInner, + instance_witness::Foldable, + Alphas, ExpExtension, FoldingCompatibleExpr, FoldingConfig, FoldingEnv, FoldingOutput, + FoldingScheme, Instance, RelaxedInstance, RelaxedWitness, Side, Witness, +}; +use itertools::Itertools; +use kimchi::{ + circuits::{expr::Variable, gate::CurrOrNext}, + curve::KimchiCurve, +}; +use mina_poseidon::{constants::PlonkSpongeConstantsKimchi, sponge::DefaultFqSponge, FqSponge}; +use poly_commitment::{ipa::SRS, SRS as _}; +use rand::thread_rng; +use std::println as debug; + +type Fp = ark_bn254::Fr; +type Curve = ark_bn254::G1Affine; +type SpongeParams = PlonkSpongeConstantsKimchi; +type BaseSponge = DefaultFqSponge; + +/// The instance is the commitments to the polynomials and the challenges +/// There are 3 commitments and challanges because there are 3 columns, A, B and +/// C. +#[derive(PartialEq, Eq, Debug, Clone)] +struct TestInstance { + commitments: [Curve; 3], + challenges: [Fp; 3], + alphas: Alphas, + blinder: Fp, +} + +impl Foldable for TestInstance { + fn combine(a: Self, b: Self, challenge: Fp) -> Self { + TestInstance { + commitments: std::array::from_fn(|i| { + (a.commitments[i] + b.commitments[i] * challenge).into() + }), + challenges: std::array::from_fn(|i| a.challenges[i] + challenge * b.challenges[i]), + alphas: Alphas::combine(a.alphas, b.alphas, challenge), + blinder: a.blinder + challenge * b.blinder, + } + } +} + +impl Instance for TestInstance { + fn to_absorb(&self) -> (Vec, Vec) { + let mut fields = Vec::with_capacity(3 + 2); + fields.extend(self.challenges); + fields.extend(self.alphas.clone().powers()); + assert_eq!(fields.len(), 5); + let points = self.commitments.to_vec(); + (fields, points) + } + + fn get_alphas(&self) -> &Alphas { + &self.alphas + } + + fn get_blinder(&self) -> Fp { + self.blinder + } +} + +/// Our witness is going to be the polynomials that we will commit too. +/// Vec will be the evaluations of each x_1, x_2 and x_3 over the domain. 
+#[derive(Clone)]
+struct TestWitness([Evaluations<Fp, Radix2EvaluationDomain<Fp>>; 3]);
+
+impl Foldable<Fp> for TestWitness {
+    fn combine(mut a: Self, b: Self, challenge: Fp) -> Self {
+        for (a, b) in a.0.iter_mut().zip(b.0) {
+            for (a, b) in a.evals.iter_mut().zip(b.evals) {
+                *a += challenge * b;
+            }
+        }
+        a
+    }
+}
+
+impl Witness<Curve> for TestWitness {}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+struct TestStructure {
+    s_add: Vec<Fp>,
+    s_mul: Vec<Fp>,
+    constants: Vec<Fp>,
+}
+
+struct TestFoldingEnv {
+    structure: TestStructure,
+    instances: [TestInstance; 2],
+    // Corresponds to the omega evaluations, for both sides
+    curr_witnesses: [TestWitness; 2],
+    // Corresponds to the zeta*omega evaluations, for both sides
+    // This is curr_witness but left-shifted by 1
+    next_witnesses: [TestWitness; 2],
+}
+
+impl FoldingEnv<Fp, TestInstance, TestWitness, Column, TestChallenge, ()> for TestFoldingEnv {
+    type Structure = TestStructure;
+
+    fn new(
+        structure: &Self::Structure,
+        instances: [&TestInstance; 2],
+        witnesses: [&TestWitness; 2],
+    ) -> Self {
+        let curr_witnesses = [witnesses[0].clone(), witnesses[1].clone()];
+        let mut next_witnesses = curr_witnesses.clone();
+        for side in next_witnesses.iter_mut() {
+            for col in side.0.iter_mut() {
+                // TODO: check this; while not relevant for this example, I
+                // think it should be a right rotation
+                col.evals.rotate_left(1);
+            }
+        }
+        TestFoldingEnv {
+            structure: structure.clone(),
+            instances: [instances[0].clone(), instances[1].clone()],
+            curr_witnesses,
+            next_witnesses,
+        }
+    }
+
+    fn col(&self, col: Column, curr_or_next: CurrOrNext, side: Side) -> &[Fp] {
+        let wit = match curr_or_next {
+            CurrOrNext::Curr => &self.curr_witnesses[side as usize],
+            CurrOrNext::Next => &self.next_witnesses[side as usize],
+        };
+        match col {
+            Column::X(0) => &wit.0[0].evals,
+            Column::X(1) => &wit.0[1].evals,
+            Column::X(2) => &wit.0[2].evals,
+            Column::Selector(0) => &self.structure.s_add,
+            Column::Selector(1) => &self.structure.s_mul,
+            // Only 3 columns and 2 selectors
+            Column::X(_) => unreachable!(),
+            Column::Selector(_) => unreachable!(),
+        }
+    }
+
+    fn challenge(&self, challenge: TestChallenge, side: Side) -> Fp {
+        match challenge {
+            TestChallenge::Beta => self.instances[side as usize].challenges[0],
+            TestChallenge::Gamma => self.instances[side as usize].challenges[1],
+            TestChallenge::JointCombiner => self.instances[side as usize].challenges[2],
+        }
+    }
+
+    fn selector(&self, _s: &(), _side: Side) -> &[Fp] {
+        unreachable!()
+    }
+}
+
+fn constraints() -> Vec<FoldingCompatibleExpr<TestFoldingConfig>> {
+    let get_col = |col| {
+        FoldingCompatibleExpr::Atom(FoldingCompatibleExprInner::Cell(Variable {
+            col,
+            row: CurrOrNext::Curr,
+        }))
+    };
+    let a = Box::new(get_col(Column::X(0)));
+    let b = Box::new(get_col(Column::X(1)));
+    let c = Box::new(get_col(Column::X(2)));
+    let s_add = Box::new(get_col(Column::Selector(0)));
+    let s_mul = Box::new(get_col(Column::Selector(1)));
+
+    let add = FoldingCompatibleExpr::Add(a.clone(), b.clone());
+    let add = FoldingCompatibleExpr::Sub(add.into(), c.clone());
+    let add = FoldingCompatibleExpr::Mul(add.into(), s_add.clone());
+
+    let mul = FoldingCompatibleExpr::Mul(a.clone(), b.clone());
+    let mul = FoldingCompatibleExpr::Sub(mul.into(), c.clone());
+    let mul = FoldingCompatibleExpr::Mul(mul.into(), s_mul.clone());
+
+    vec![add, mul]
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+struct TestFoldingConfig;
+
+// Does not contain alpha because this one should be provided by folding itself
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+// Flag used as the challenges are never built.
+// FIXME: should we use unit?
+#[allow(dead_code)]
+enum TestChallenge {
+    Beta,
+    Gamma,
+    JointCombiner,
+}
+
+impl FoldingConfig for TestFoldingConfig {
+    type Structure = TestStructure;
+    type Column = Column;
+    type Selector = ();
+    type Challenge = TestChallenge;
+    type Curve = Curve;
+    type Srs = SRS<Curve>;
+    type Instance = TestInstance;
+    type Witness = TestWitness;
+    type Env = TestFoldingEnv;
+}
+
+fn instance_from_witness(
+    witness: &TestWitness,
+    srs: &<TestFoldingConfig as FoldingConfig>::Srs,
+    domain: Radix2EvaluationDomain<Fp>,
+) -> TestInstance {
+    let commitments = witness
+        .0
+        .iter()
+        .map(|w| srs.commit_evaluations_non_hiding(domain, w))
+        .map(|c| c.get_first_chunk())
+        .collect_vec();
+    let commitments: [_; 3] = commitments.try_into().unwrap();
+
+    // Here we should absorb the commitments and similar things to later
+    // compute the challenges, but for this example we just use random values.
+    let mut rng = thread_rng();
+    let mut challenge = || Fp::rand(&mut rng);
+    let challenges = [(); 3].map(|_| challenge());
+    let alpha = challenge();
+    let alphas = Alphas::new(alpha);
+    // We assume the blinder is always one.
+    let blinder = Fp::one();
+    TestInstance {
+        commitments,
+        challenges,
+        alphas,
+        blinder,
+    }
+}
+
+fn circuit() -> [Vec<Fp>; 2] {
+    [vec![Fp::one(), Fp::zero()], vec![Fp::zero(), Fp::one()]]
+}
+
+/// A kind of pseudo-prover: it computes the expressions over the witness and
+/// checks, row by row, that the result is zero.
+mod checker {
+    use super::*;
+
+    pub struct Provider {
+        structure: TestStructure,
+        instance: TestInstance,
+        witness: TestWitness,
+    }
+
+    impl Provider {
+        pub(super) fn new(
+            structure: TestStructure,
+            instance: TestInstance,
+            witness: TestWitness,
+        ) -> Self {
+            Self {
+                structure,
+                instance,
+                witness,
+            }
+        }
+    }
+
+    impl Provide<TestFoldingConfig> for Provider {
+        fn resolve(
+            &self,
+            inner: FoldingCompatibleExprInner<TestFoldingConfig>,
+            domain: Radix2EvaluationDomain<Fp>,
+        ) -> Vec<Fp> {
+            let domain_size = domain.size as usize;
+            match inner {
+                FoldingCompatibleExprInner::Constant(c) => {
+                    vec![c; domain_size]
+                }
+                FoldingCompatibleExprInner::Challenge(chall) => {
+                    let chals = self.instance.challenges;
+                    let v = match chall {
+                        TestChallenge::Beta => chals[0],
+                        TestChallenge::Gamma => chals[1],
+                        TestChallenge::JointCombiner => chals[2],
+                    };
+                    vec![v; domain_size]
+                }
+                FoldingCompatibleExprInner::Cell(var) => {
+                    let Variable { col, row } = var;
+                    let col = match col {
+                        Column::X(0) => &self.witness.0[0].evals,
+                        Column::X(1) => &self.witness.0[1].evals,
+                        Column::X(2) => &self.witness.0[2].evals,
+                        Column::Selector(0) => &self.structure.s_add,
+                        Column::Selector(1) => &self.structure.s_mul,
+                        // Only 3 columns and 2 selectors
+                        Column::X(_) => unreachable!(),
+                        Column::Selector(_) => unreachable!(),
+                    };
+
+                    let mut col = col.clone();
+                    // TODO: check this; while not relevant in this case, I
+                    // think it should be a right rotation
+                    if let CurrOrNext::Next = row {
+                        col.rotate_left(1);
+                    }
+                    col
+                }
+                FoldingCompatibleExprInner::Extensions(_) => {
+                    panic!("not handled")
+                }
+            }
+        }
+    }
+
+    pub struct ExtendedProvider {
+        inner_provider: Provider,
+        instance: RelaxedInstance<<TestFoldingConfig as FoldingConfig>::Curve, TestInstance>,
+        witness: RelaxedWitness<<TestFoldingConfig as FoldingConfig>::Curve, TestWitness>,
+    }
+
+    impl ExtendedProvider {
+        pub(super) fn new(
+            structure: TestStructure,
+            instance: RelaxedInstance<<TestFoldingConfig as FoldingConfig>::Curve, TestInstance>,
+            witness: RelaxedWitness<<TestFoldingConfig as FoldingConfig>::Curve, TestWitness>,
+        ) -> Self {
+            let inner_provider = {
+                let instance = instance.extended_instance.instance.clone();
+                let witness = witness.extended_witness.witness.clone();
+                Provider::new(structure, instance, witness)
+            };
+            Self {
+                inner_provider,
+                instance,
+                witness,
+            }
+        }
+    }
+
+    impl Provide<TestFoldingConfig> for ExtendedProvider {
+        fn resolve(
+            &self,
+            inner: FoldingCompatibleExprInner<TestFoldingConfig>,
+            domain: Radix2EvaluationDomain<Fp>,
+        ) -> Vec<Fp> {
+            let domain_size = domain.size as usize;
+            match inner {
+                FoldingCompatibleExprInner::Extensions(ext) => match ext {
+                    ExpExtension::U => {
+                        let u = self.instance.u;
+                        vec![u; domain_size]
+                    }
+                    ExpExtension::Error => self.witness.error_vec.evals.clone(),
+                    ExpExtension::ExtendedWitness(i) => self
+                        .witness
+                        .extended_witness
+                        .extended
+                        .get(&i)
+                        .unwrap()
+                        .evals
+                        .clone(),
+                    ExpExtension::Alpha(i) => {
+                        let alpha = self
+                            .instance
+                            .extended_instance
+                            .instance
+                            .alphas
+                            .get(i)
+                            .unwrap();
+                        vec![alpha; domain_size]
+                    }
+                    ExpExtension::Selector(_) => panic!("unused"),
+                },
+                e => self.inner_provider.resolve(e, domain),
+            }
+        }
+    }
+
+    impl Checker<TestFoldingConfig> for Provider {}
+    impl Checker<TestFoldingConfig> for ExtendedProvider {}
+}
+
+// This checks a single folding; it would be good to expand it in the future
+// to do several foldings, as a few things are trivial in the first fold.
+#[test]
+fn test_folding_instance() {
+    let constraints = constraints();
+    let domain = Radix2EvaluationDomain::<Fp>::new(2).unwrap();
+    let srs = poly_commitment::ipa::SRS::<Curve>::create(2);
+    srs.get_lagrange_basis(domain);
+
+    let mut fq_sponge = BaseSponge::new(Curve::other_curve_sponge_params());
+
+    let [s_add, s_mul] = circuit();
+    let structure = TestStructure {
+        s_add,
+        s_mul,
+        constants: vec![],
+    };
+
+    let (scheme, final_constraint) =
+        FoldingScheme::<TestFoldingConfig>::new(constraints.clone(), &srs, domain, &structure);
+
+    // We have a 2-row circuit with an addition gate in the first row and a
+    // multiplication gate in the second.
+
+    // Left: 1 + 2 - 3 = 0
+    let left_witness = [
+        vec![Fp::from(1u32), Fp::from(2u32)],
+        vec![Fp::from(2u32), Fp::from(3u32)],
+        vec![Fp::from(3u32), Fp::from(6u32)],
+    ];
+    let left_witness: TestWitness =
+        TestWitness(left_witness.map(|evals| Evaluations::from_vec_and_domain(evals, domain)));
+    // Right: 4 + 5 - 9 = 0
+    let right_witness = [
+        vec![Fp::from(4u32), Fp::from(3u32)],
+        vec![Fp::from(5u32), Fp::from(6u32)],
+        vec![Fp::from(9u32), Fp::from(18u32)],
+    ];
+    let right_witness: TestWitness =
+        TestWitness(right_witness.map(|evals| Evaluations::from_vec_and_domain(evals, domain)));
+
+    // instances
+    let left_instance = instance_from_witness(&left_witness, &srs, domain);
+    let right_instance = instance_from_witness(&right_witness, &srs, domain);
+
+    // check left
+    {
+        debug!("check left");
+        let checker = Provider::new(
+            structure.clone(),
+            left_instance.clone(),
+            left_witness.clone(),
+        );
+        constraints
+            .iter()
+            .for_each(|constraint| checker.check(constraint, domain));
+    }
+    // check right
+    {
+        debug!("check right");
+        let checker = Provider::new(
+            structure.clone(),
+            right_instance.clone(),
+            right_witness.clone(),
+        );
+        constraints
+            .iter()
+            .for_each(|constraint| checker.check(constraint, domain));
+    }
+
+    // pairs
+    let left = (left_instance.clone(), left_witness);
+    let right = (right_instance.clone(), right_witness);
+
+    let folded = scheme.fold_instance_witness_pair(left, right, &mut fq_sponge);
+    let FoldingOutput {
+        folded_instance,
+        folded_witness,
+        t_0,
+        t_1,
+        relaxed_extended_left_instance,
+        relaxed_extended_right_instance,
+        to_absorb,
+    } = folded;
+
+    {
+        let folded_instance_explicit = {
+            let mut fq_sponge_inst = BaseSponge::new(Curve::other_curve_sponge_params());
+            scheme.fold_instance_pair(
+                relaxed_extended_left_instance,
relaxed_extended_right_instance, + [t_0.clone(), t_1.clone()], + &mut fq_sponge_inst, + ) + }; + + assert!(folded_instance == folded_instance_explicit); + } + + // Verifying that error terms are not points at infinity + // It doesn't test that the computation happens correctly, but at least + // show that there is some non trivial computation. + assert_eq!(t_0.len(), 1); + assert_eq!(t_1.len(), 1); + assert!(!t_0.get_first_chunk().is_zero()); + assert!(!t_1.get_first_chunk().is_zero()); + + // checking that we have the expected number of elements to absorb + // 3+2 from each instance + 1 from u, times 2 instances + assert_eq!(to_absorb.0.len(), (3 + 2 + 1) * 2); + // 3 from each instance + 1 from E, times 2 instances + t_0 + t_1 + assert_eq!(to_absorb.1.len(), (3 + 1) * 2 + 2); + { + let checker = ExtendedProvider::new(structure, folded_instance, folded_witness); + debug!("exp: \n {:#?}", final_constraint); + debug!("check folded"); + checker.check(&final_constraint, domain); + } +} diff --git a/hasher/src/lib.rs b/hasher/src/lib.rs index cc2b0998f8..0f3a817394 100644 --- a/hasher/src/lib.rs +++ b/hasher/src/lib.rs @@ -7,9 +7,6 @@ pub use mina_curves::pasta::Fp; pub use poseidon::{PoseidonHasherKimchi, PoseidonHasherLegacy}; pub use roinput::ROInput; -#[cfg(test)] -mod tests; - use ark_ff::PrimeField; use o1_utils::FieldHelpers; diff --git a/hasher/src/tests/hasher.rs b/hasher/src/tests/hasher.rs deleted file mode 100644 index b470d30567..0000000000 --- a/hasher/src/tests/hasher.rs +++ /dev/null @@ -1,74 +0,0 @@ -use crate::{create_kimchi, create_legacy, Fp, Hashable, Hasher, ROInput}; -use o1_utils::FieldHelpers; -use serde::Deserialize; -use std::{fs::File, path::PathBuf}; - -// -// Helpers for test vectors -// - -#[derive(Debug, Deserialize)] -struct TestVectors { - test_vectors: Vec, -} - -#[derive(Clone, Debug, Deserialize)] -struct TestVector { - input: Vec, - output: String, -} - -impl Hashable for TestVector { - type D = (); - - fn to_roinput(&self) -> ROInput { - let mut roi = ROInput::new(); - // For hashing we only care about the input part - for input in &self.input { - roi = - roi.append_field(Fp::from_hex(input).expect("failed to deserialize field element")); - } - roi - } - - fn domain_string(_: Self::D) -> Option { - None - } -} - -fn test_vectors(test_vector_file: &str, hasher: &mut dyn Hasher) { - // read test vectors from given file - let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - path.push("../poseidon/src/tests/test_vectors"); - path.push(test_vector_file); - - let file = File::open(&path).expect("couldn't open test vector file"); - let test_vectors: TestVectors = - serde_json::from_reader(file).expect("couldn't deserialize test vector file"); - - // execute test vectors - for test_vector in test_vectors.test_vectors { - let expected_output = - Fp::from_hex(&test_vector.output).expect("failed to deserialize field element"); - - // hash & check against expect output - let output = hasher.hash(&test_vector); - assert_eq!(output, expected_output); - } -} - -// -// Tests -// - -#[test] -fn hasher_test_vectors_legacy() { - let mut hasher = create_legacy::(()); - test_vectors("legacy.json", &mut hasher); -} - -#[test] -fn hasher_test_vectors_kimchi() { - let mut hasher = create_kimchi::(()); - test_vectors("kimchi.json", &mut hasher); -} diff --git a/hasher/src/tests/mod.rs b/hasher/src/tests/mod.rs deleted file mode 100644 index 99a8232ba6..0000000000 --- a/hasher/src/tests/mod.rs +++ /dev/null @@ -1,60 +0,0 @@ -use crate::{create_legacy, Hashable, Hasher, 
ROInput}; - -mod hasher; - -#[test] -fn interfaces() { - #[derive(Clone)] - struct Foo { - x: u32, - y: u64, - } - - impl Hashable for Foo { - type D = u64; - - fn to_roinput(&self) -> ROInput { - ROInput::new().append_u32(self.x).append_u64(self.y) - } - - fn domain_string(id: u64) -> Option { - format!("Foo {id}").into() - } - } - - // Usage 1: incremental interface - let mut hasher = create_legacy::(0); - hasher.update(&Foo { x: 3, y: 1 }); - let x1 = hasher.digest(); // Resets to previous init state (0) - hasher.update(&Foo { x: 82, y: 834 }); - hasher.update(&Foo { x: 1235, y: 93 }); - hasher.digest(); // Resets to previous init state (0) - hasher.init(1); - hasher.update(&Foo { x: 82, y: 834 }); - let x2 = hasher.digest(); // Resets to previous init state (1) - - // Usage 2: builder interface with one-shot pattern - let mut hasher = create_legacy::(0); - let y1 = hasher.update(&Foo { x: 3, y: 1 }).digest(); // Resets to previous init state (0) - hasher.update(&Foo { x: 31, y: 21 }).digest(); - - // Usage 3: builder interface with one-shot pattern also setting init state - let mut hasher = create_legacy::(0); - let y2 = hasher.init(0).update(&Foo { x: 3, y: 1 }).digest(); // Resets to previous init state (1) - let y3 = hasher.init(1).update(&Foo { x: 82, y: 834 }).digest(); // Resets to previous init state (2) - - // Usage 4: one-shot interfaces - let mut hasher = create_legacy::(0); - let y4 = hasher.hash(&Foo { x: 3, y: 1 }); - let y5 = hasher.init_and_hash(1, &Foo { x: 82, y: 834 }); - - assert_eq!(x1, y1); - assert_eq!(x1, y2); - assert_eq!(x2, y3); - assert_eq!(x1, y4); - assert_eq!(x2, y5); - assert_ne!(x1, y5); - assert_ne!(x2, y4); - assert_ne!(x1, y3); - assert_ne!(x2, y2); -} diff --git a/hasher/tests/hasher.rs b/hasher/tests/hasher.rs new file mode 100644 index 0000000000..d90809bf85 --- /dev/null +++ b/hasher/tests/hasher.rs @@ -0,0 +1,131 @@ +use mina_hasher::{create_kimchi, create_legacy, Fp, Hashable, Hasher, ROInput}; +use o1_utils::FieldHelpers; +use serde::Deserialize; +use std::{fs::File, path::PathBuf}; + +// +// Helpers for test vectors +// + +#[derive(Debug, Deserialize)] +struct TestVectors { + test_vectors: Vec, +} + +#[derive(Clone, Debug, Deserialize)] +struct TestVector { + input: Vec, + output: String, +} + +impl Hashable for TestVector { + type D = (); + + fn to_roinput(&self) -> ROInput { + let mut roi = ROInput::new(); + // For hashing we only care about the input part + for input in &self.input { + roi = + roi.append_field(Fp::from_hex(input).expect("failed to deserialize field element")); + } + roi + } + + fn domain_string(_: Self::D) -> Option { + None + } +} + +fn test_vectors(test_vector_file: &str, hasher: &mut dyn Hasher) { + // read test vectors from given file + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.push("../poseidon/tests/test_vectors"); + path.push(test_vector_file); + + let file = File::open(&path).expect("couldn't open test vector file"); + let test_vectors: TestVectors = + serde_json::from_reader(file).expect("couldn't deserialize test vector file"); + + // execute test vectors + for test_vector in test_vectors.test_vectors { + let expected_output = + Fp::from_hex(&test_vector.output).expect("failed to deserialize field element"); + + // hash & check against expect output + let output = hasher.hash(&test_vector); + assert_eq!(output, expected_output); + } +} + +// +// Tests +// + +#[test] +fn hasher_test_vectors_legacy() { + let mut hasher = create_legacy::(()); + test_vectors("legacy.json", &mut hasher); +} + 
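+// For reference, a test-vector file is assumed to have the following shape,
+// inferred from the `TestVectors`/`TestVector` serde derives above (all
+// field elements are hex-encoded strings):
+//
+// {
+//   "test_vectors": [
+//     { "input": ["<hex field element>", ...], "output": "<hex field element>" }
+//   ]
+// }
+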
+#[test] +fn hasher_test_vectors_kimchi() { + let mut hasher = create_kimchi::(()); + test_vectors("kimchi.json", &mut hasher); +} + +#[test] +fn interfaces() { + #[derive(Clone)] + struct Foo { + x: u32, + y: u64, + } + + impl Hashable for Foo { + type D = u64; + + fn to_roinput(&self) -> ROInput { + ROInput::new().append_u32(self.x).append_u64(self.y) + } + + fn domain_string(id: u64) -> Option { + format!("Foo {id}").into() + } + } + + // Usage 1: incremental interface + let mut hasher = create_legacy::(0); + hasher.update(&Foo { x: 3, y: 1 }); + let x1 = hasher.digest(); // Resets to previous init state (0) + hasher.update(&Foo { x: 82, y: 834 }); + hasher.update(&Foo { x: 1235, y: 93 }); + hasher.digest(); // Resets to previous init state (0) + hasher.init(1); + hasher.update(&Foo { x: 82, y: 834 }); + let x2 = hasher.digest(); // Resets to previous init state (1) + + // Usage 2: builder interface with one-shot pattern + let mut hasher = create_legacy::(0); + let y1 = hasher.update(&Foo { x: 3, y: 1 }).digest(); // Resets to previous init state (0) + hasher.update(&Foo { x: 31, y: 21 }).digest(); + + // Usage 3: builder interface with one-shot pattern also setting init state + let mut hasher = create_legacy::(0); + let y2 = hasher.init(0).update(&Foo { x: 3, y: 1 }).digest(); // Resets to previous init state (1) + let y3 = hasher.init(1).update(&Foo { x: 82, y: 834 }).digest(); // Resets to previous init state (2) + + // Usage 4: one-shot interfaces + let mut hasher = create_legacy::(0); + let y4 = hasher.hash(&Foo { x: 3, y: 1 }); + let y5 = hasher.init_and_hash(1, &Foo { x: 82, y: 834 }); + + assert_eq!(x1, y1); + assert_eq!(x1, y2); + assert_eq!(x2, y3); + assert_eq!(x1, y4); + assert_eq!(x2, y5); + assert_ne!(x1, y5); + assert_ne!(x2, y4); + assert_ne!(x1, y3); + assert_ne!(x2, y2); +} diff --git a/internal-tracing/src/lib.rs b/internal-tracing/src/lib.rs index 7cf3b8a438..f274a50237 100644 --- a/internal-tracing/src/lib.rs +++ b/internal-tracing/src/lib.rs @@ -1,9 +1,5 @@ use std::time::SystemTime; -mod internal_tracing { - pub use crate::*; -} - #[cfg(feature = "enabled")] pub use serde_json::{json, to_writer as json_to_writer, Value as JsonValue}; diff --git a/ivc/Cargo.toml b/ivc/Cargo.toml new file mode 100644 index 0000000000..f6310ab9bd --- /dev/null +++ b/ivc/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "ivc" +version = "0.1.0" +edition = "2021" + +[dependencies] +# We should make this optional in the future. +# We want to make the crate generic for any curve that has nearly the same +# properties than bn254 +ark-bn254.workspace = true +ark-ec.workspace = true +ark-ff.workspace = true +ark-poly.workspace = true +folding.workspace = true +kimchi.workspace = true +kimchi-msm.workspace = true +itertools.workspace = true +mina-poseidon.workspace = true +num-bigint.workspace = true +num-integer.workspace = true +o1-utils.workspace = true +once_cell.workspace = true +poly-commitment.workspace = true +rand.workspace = true +rayon.workspace = true +strum.workspace = true +strum_macros.workspace = true +thiserror.workspace = true \ No newline at end of file diff --git a/ivc/README.md b/ivc/README.md new file mode 100644 index 0000000000..a3659eaecf --- /dev/null +++ b/ivc/README.md @@ -0,0 +1,13 @@ +## IVC circuit for folding over the same curve + +This crate implements the circuit that can be used with the folding crate to +achieve incrementally verifiable computation. +The particularity is that it doesn't require a cycle of curve. +It is aimed to be used with the MSM circuit. 
+ +### Structure + +Integration tests with different circuits are available in the subdirectory +`tests`. +They can be executed by running `cargo nextest run --release` at the level of +this file. diff --git a/ivc/src/expr_eval.rs b/ivc/src/expr_eval.rs new file mode 100644 index 0000000000..7fffa1af36 --- /dev/null +++ b/ivc/src/expr_eval.rs @@ -0,0 +1,193 @@ +use std::ops::Index; + +use crate::plonkish_lang::{CombinableEvals, PlonkishChallenge, PlonkishWitnessGeneric}; +use ark_ec::AffineRepr; +use ark_ff::Field; +use ark_poly::{Evaluations, Radix2EvaluationDomain as R2D}; +use folding::{ + columns::ExtendedFoldingColumn, + eval_leaf::EvalLeaf, + expressions::{ExpExtension, FoldingCompatibleExprInner, FoldingExp}, + instance_witness::ExtendedWitness, + Alphas, FoldingCompatibleExpr, FoldingConfig, +}; +use kimchi::{ + self, + circuits::{expr::Variable, gate::CurrOrNext}, + curve::KimchiCurve, +}; +use kimchi_msm::columns::Column as GenericColumn; +use strum::EnumCount; + +#[derive(Clone)] +/// Generic structure containing column vectors. +pub struct GenericVecStructure(pub Vec>); + +impl Index for GenericVecStructure { + type Output = [G::ScalarField]; + + fn index(&self, index: GenericColumn) -> &Self::Output { + match index { + GenericColumn::FixedSelector(i) => &self.0[i], + _ => panic!("should not happen"), + } + } +} + +/// Minimal environment needed for evaluating constraints. +pub struct GenericEvalEnv< + Curve: KimchiCurve, + const N_COL: usize, + const N_FSEL: usize, + Eval: CombinableEvals, +> { + pub ext_witness: + ExtendedWitness>, + pub alphas: Alphas, + pub challenges: [Curve::ScalarField; PlonkishChallenge::COUNT], + pub error_vec: Eval, + /// The scalar `u` that is used to homogenize the polynomials + pub u: Curve::ScalarField, +} + +pub type SimpleEvalEnv = GenericEvalEnv< + Curve, + N_COL, + N_FSEL, + Evaluations<::ScalarField, R2D<::ScalarField>>, +>; + +impl< + Curve: KimchiCurve, + const N_COL: usize, + const N_FSEL: usize, + Evals: CombinableEvals, + > GenericEvalEnv +{ + fn challenge(&self, challenge: PlonkishChallenge) -> Curve::ScalarField { + match challenge { + PlonkishChallenge::Beta => self.challenges[0], + PlonkishChallenge::Gamma => self.challenges[1], + PlonkishChallenge::JointCombiner => self.challenges[2], + } + } + + pub fn process_extended_folding_column< + FC: FoldingConfig, + >( + &self, + col: &ExtendedFoldingColumn, + ) -> EvalLeaf { + use EvalLeaf::Col; + use ExtendedFoldingColumn::*; + match col { + Inner(Variable { col, row }) => { + let wit = match row { + CurrOrNext::Curr => &self.ext_witness.witness, + CurrOrNext::Next => panic!("not implemented"), + }; + // The following is possible because Index is implemented for our + // circuit witnesses + Col(&wit[*col]) + }, + WitnessExtended(i) => Col(&self.ext_witness.extended.get(i).unwrap().evals), + Error => panic!("shouldn't happen"), + Constant(c) => EvalLeaf::Const(*c), + Challenge(chall) => EvalLeaf::Const(self.challenge(*chall)), + Alpha(i) => { + let alpha = self.alphas.get(*i).expect("alpha not present"); + EvalLeaf::Const(alpha) + } + Selector(_s) => unimplemented!("Selector not implemented for FoldingEnvironment. 
No selectors are supposed to be used when it is Plonkish relations."), + } + } + + /// Evaluates the expression in the provided side + pub fn eval_naive_fexpr< + 'a, + FC: FoldingConfig, + >( + &'a self, + exp: &FoldingExp, + ) -> EvalLeaf<'a, Curve::ScalarField> { + use FoldingExp::*; + + match exp { + Atom(column) => self.process_extended_folding_column(column), + Double(e) => { + let col = self.eval_naive_fexpr(e); + col.map(Field::double, |f| { + Field::double_in_place(f); + }) + } + Square(e) => { + let col = self.eval_naive_fexpr(e); + col.map(Field::square, |f| { + Field::square_in_place(f); + }) + } + Add(e1, e2) => self.eval_naive_fexpr(e1) + self.eval_naive_fexpr(e2), + Sub(e1, e2) => self.eval_naive_fexpr(e1) - self.eval_naive_fexpr(e2), + Mul(e1, e2) => self.eval_naive_fexpr(e1) * self.eval_naive_fexpr(e2), + Pow(_e, _i) => panic!("We're not supposed to use this"), + } + } + + /// For FoldingCompatibleExp + pub fn eval_naive_fcompat< + 'a, + FC: FoldingConfig, + >( + &'a self, + exp: &FoldingCompatibleExpr, + ) -> EvalLeaf<'a, Curve::ScalarField> where { + use FoldingCompatibleExpr::*; + + match exp { + Atom(column) => { + use FoldingCompatibleExprInner::*; + match column { + Cell(Variable { col, row }) => { + let wit = match row { + CurrOrNext::Curr => &self.ext_witness.witness, + CurrOrNext::Next => panic!("not implemented"), + }; + // The following is possible because Index is implemented for our + // circuit witnesses + EvalLeaf::Col(&wit[*col]) + } + Challenge(chal) => EvalLeaf::Const(self.challenge(*chal)), + Constant(c) => EvalLeaf::Const(*c), + Extensions(ext) => { + use ExpExtension::*; + match ext { + U => EvalLeaf::Const(self.u), + Error => EvalLeaf::Col(self.error_vec.e_as_slice()), + ExtendedWitness(i) => { + EvalLeaf::Col(&self.ext_witness.extended.get(i).unwrap().evals) + } + Alpha(i) => EvalLeaf::Const(self.alphas.get(*i).unwrap()), + Selector(_sel) => panic!("No selectors supported yet"), + } + } + } + } + Double(e) => { + let col = self.eval_naive_fcompat(e); + col.map(Field::double, |f| { + Field::double_in_place(f); + }) + } + Square(e) => { + let col = self.eval_naive_fcompat(e); + col.map(Field::square, |f| { + Field::square_in_place(f); + }) + } + Add(e1, e2) => self.eval_naive_fcompat(e1) + self.eval_naive_fcompat(e2), + Sub(e1, e2) => self.eval_naive_fcompat(e1) - self.eval_naive_fcompat(e2), + Mul(e1, e2) => self.eval_naive_fcompat(e1) * self.eval_naive_fcompat(e2), + Pow(_e, _i) => panic!("We're not supposed to use this"), + } + } +} diff --git a/ivc/src/ivc/columns.rs b/ivc/src/ivc/columns.rs new file mode 100644 index 0000000000..4a3311a2ea --- /dev/null +++ b/ivc/src/ivc/columns.rs @@ -0,0 +1,572 @@ +//! IVC circuit layout - Top level documentation outdated +//! +//! The IVC circuit is tiled vertically. We assume we have as many +//! rows as we need: if we don't, we wrap around and continue. +//! +//! The biggest blocks are hashes and ECAdds, so other blocks may be wider. +//! +//! `N := N_IVC + N_APP` is the total number of columns in the circuit. +//! +//! Vertically stacked blocks are as follows: +//! +//!```text +//! +//! Inputs: +//! Each point is 2 base field coordinates in 17 15-bit limbs +//! recomposed as 8 75-bit limbs +//! recomposed as 4 150-bit limbs. +//! +//! 34 8 4 +//! Input1 R75 R150 +//! 1 |-----------------|-------|----| +//! | C_L1 | | | +//! | C_L2 | | | +//! | | | | +//! | ... | | | +//! N |-----------------| | | +//! | C_R1 | | | +//! | | | | +//! | ... | | | +//! 2N |-----------------| | | +//! | C_O1 | | | +//! 
| | | | +//! | ... | | | +//! 3N |-----------------|-------|----| +//! 0 ... 34*2 76 80 +//! +//! +//! Hashes +//! (one hash at a row, passing data to the next one) +//! (for i∈N, the input row #i containing 4 150-bit elements +//! is processed by hash rows 2*i and 2*i+1) +//! +//! 1 |------------------------------------------| +//! | | +//! | .| . here is h_l +//! 2N |------------------------------------------| must be equal to public input! +//! | | (H_i in nova) +//! | .| . here is h_r +//! 4N |------------------------------------------| +//! | | +//! | .| . here is h_o, equal to H_{i+1} in Nova +//! 6N |------------------------------------------| +//! | .| r = h_lr = h(h_l,h_r) +//! | .| ϕ = h_lro = h(r,h_o) +//! 6N+2 |------------------------------------------| +//! +//! TODO: we also need to squeeze challenges for +//! the right (strict) instance: β, γ, j (joint_combiner) +//! +//! TODO: we can hash (x0,x1+b*2^150) instead of (x0,x1,y0,y1). +//! +//! Scalars block. +//! +//! Most of the r^2 and r^3 cells are /unused/, and this design can be +//! much more optimal if r^2|r^3 elements come in a separate block. +//! But the overhead is not big and it's a very easy layout, so +//! keeping it for now. +//! +//! constϕ +//! constr +//! ϕ^i r^3·ϕ^i ϕ^i r*ϕ^i r^2·ϕ^i_k r^3·ϕ^i_k +//! r*ϕ^i in 17 limbs in 17 limbs ... ... +//! r^2·ϕ^i each each +//! 1 |---|---|-----|-------|----|----|------------|------------|------------|------------| +//! | ϕ r ϕ rϕ r^2ϕ r^3ϕ| | | | | +//! | ϕ r ϕ^2 rϕ^2 | | | | | +//! | ϕ r ϕ^3 rϕ^3 | | | | | +//! | | | | | | +//! | | | | | | +//! | | | | | | +//! i | | | | | | +//! | | | | | | +//! | | | | | | +//! | | | | | | +//! | ϕ^{N+1} | | | | | +//! N+1 |-------------------------------|------------|------------|------------|------------| +//! 1 2 3 4 5 6 ... 6+17 6+2*17 6+4*17 +//! +//! +//! We compute the following equations, where equations in "quotes" are +//! what we /want/ to prove, and non-quoted is an equavilant version +//! that we actually prove instead: +//! - "C_{O,i} = C_{L,i} + r·C_{R,i}": +//! - C_{O,i} = C_{L,i} + C_{R',i} +//! - "C_{R',i} = r·C_{R,i}" +//! - bucket[(ϕ^i)_k] -= C_{R',i} +//! - bucket[(r·ϕ^i)_k] += C_{R,i} +//! - "E_O = E_L + r·T_0 + r^2·T_1 + r^3·E_R": +//! - E_O = E_L + E_R' +//! - "E_R' = r·T_0 + r^2·T_1 + r^3·E_R" +//! - bucket[(ϕ^{n+1})_k] += E_R' +//! - bucket[(r·ϕ^{n+1})_k] += T_0 +//! - bucket[(r^2·ϕ^{n+1})_k] += T_1 +//! - bucket[(r^3·ϕ^{n+1})_k] += E_R +//! +//! Runtime access time is represented by ? because it's not known in advance. +//! +//! Output and input RAM invocations in the same row use the same coeff/memory index. +//! +//! TODO FIXME we need to write into /different/ buckets. +//! +//! FEC Additions, one per row, each one is ~230 columns: +//! +//! . input #2 (looked up) . Output +//! . Access input#2 . FEC ADD access output +//! input#1 . Coeff/mem ix Time Value . computation time value +//! 1 |-------------------------------------------------------------------------------|-----|------------| +//! | C_{R',1} | ϕ^1_0 | ? | bucket[ϕ^1_0] | | ? | newbucket | +//! | C_{R',2} | ϕ^2_0 | ? | bucket[ϕ^2_0] | | ? | newbucket | +//! | | ... | ... | | | ... | | +//! | C_{R',N} | ϕ^N_0 | ? | bucket[ϕ^N_0] | | ? | newbucket | +//! | C_{R',1} | ϕ^1_1 | ? | bucket[ϕ^0_1] | | ? | newbucket | +//! | | ... | ... | | | ... | | +//! | C_{R',i} | ϕ^i_k | ? | bucket[ϕ^i_k] | | ? | newbucket | +//! | | | | | | | | +//! | | ... | ... | | | ... | | +//! 
17*N |--------------------------------------------------------------------------------------------------| +//! | C_{R,i} | r·(ϕ^i_k) | ? | bucket[r·(ϕ^i_k)] | | ? | newbucket | +//! | | | | | | | | +//! | | ... | ... | | | ... | | +//! 34*N |--------------------------------------------------------------------------------------------------| +//! |-C_{R',i} | - | - | C_{L,i} | | - | C_{O,i} | +//! | | ... | ... | | | ... | | +//! 35*N |--------------------------------------------------------------------------------------------------| +//! | -E_R' | ϕ^{n+1}_k | ? | bucket[ϕ^{n+1}_k] | | | newbucket | +//! | | ... | ... | | | ... | | +//! | T_0 | r·(ϕ^{n+1}) | ? | bucket[r·(ϕ^{n+1})] | | | newbucket | +//! | | ... | ... | | | ... | | +//! | T_1 | r^2·(ϕ^{n+1}_k) | ? | bucket[r^2·(ϕ^{n+1}_k)] | | | newbucket | +//! | | ... | ... | | | ... | | +//! | E_R | r^3·(ϕ^{n+1}_k) | ? | bucket[r^3·(ϕ^{n+1}_k)] | | | newbucket | +//! | | ... | ... | | | ... | | +//! 35*N+ | E_L | - | - | E_R' | | - | E_O | +//! 4*17+ |--------------------------------------------------------------------------------------------------| +//! 1 +//! +//! TODO: add different challenges: β, γ, joint_combiner +//! +//! Challenges block. +//! +//! relaxed +//! strict +//! (relaxed in-place) +//! r α_{L,i} α_{R}^i α_{O,i} +//! 1 |--|--------|-----------|-----------------------| +//! | | | α_R = h_R | | +//! | | | | | +//! | | | | | +//! | | | α_R^i | α_{L,i} + r·α_{R,i}^i | +//! | | | | | +//! | | | | | +//! | | | | | +//! | | | | | +//! | | | | | +//! | | | | | +//! #chal |--|--------|-----------|-----------------------| +//! +//! #chal is the number of constraints. Our optimistic expectation is +//! that it is around const*N for const < 3. +//! +//! +//! "u" block. In the general form we want to prove +//! u_O = u_L + r·u_R, but u_R = 0, so we prove +//! u_O = u_L + r. +//! +//! r u_L u_O = u_L + r +//! |--|--------|--------------------| +//! |--|--------|--------------------| +//! +//! +//! 2^15 |---- --------------------------------------| +//!``` +//! +//! +//! Assume that IVC circuit takess CELL cells, e.g. CELL = 10000*N. +//! Then we can calculate N_IVC as dependency of N_APP in this way: +//! ```text +//! N = N_APP + (CELL/2^15)*N +//! (1 - CELL/2^15)*N = N_APP +//! N = (1/(1 - CELL/2^15)) * N_APP = (2^15 / (2^15 - CELL)) * N_APP +//! N_IVC = (1/(1 - CELL/2^15) - 1) * N_APP = (2^15 / (2^15 - CELL) - 1) * N_APP +//! ``` +//! +//! --- (slightly) OUTDATED BELOW --- +//! +//! Counting cells: +//! - Inputs: 2 * 17 * 3N = 102N +//! - Inputs repacked 75: 2 * 4 * 3N = 24N +//! - Inputs repacked 150: 2 * 2 * 3N = 12N +//! - Hashes: 2 * 165 * 3N = 990N (max 4 * 165 * 3N if we add 165 constants to every call) +//! - scalars: 4 N + 17 * 3 * N = 55 N +//! - ECADDs: 230 * 35 * N = 8050N +//! Total (CELL): ~9233*N +//! +//! ...which is less than 32k*N +//! +//! In our particular case, CELL = 9233, so +//! N_IVC = 0.39 N_APP +//! +//! Which means for e.g. 50 columns we'll need extra 20 of IVC. 
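+//!
+//! Summing the per-block heights defined in `block_height` below gives the
+//! total circuit height (illustrative arithmetic, matching the block sizes
+//! in this file):
+//! ```text
+//! 3N + (6N + 2) + (N + 1) + (35N + 5) + N_CHALS + 1
+//!   = 45N + N_CHALS + 9 rows, with N = N_COL_TOTAL
+//! ```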
+
+use super::N_LIMBS_XLARGE;
+use crate::poseidon_8_56_5_3_2::{
+    bn254::{
+        Column as IVCPoseidonColumn, NB_FULL_ROUND as IVC_POSEIDON_NB_FULL_ROUND,
+        NB_PARTIAL_ROUND as IVC_POSEIDON_NB_PARTIAL_ROUND,
+        NB_ROUND_CONSTANTS as IVC_POSEIDON_NB_ROUND_CONSTANTS,
+        STATE_SIZE as IVC_POSEIDON_STATE_SIZE,
+    },
+    columns::PoseidonColumn,
+};
+use kimchi_msm::{
+    circuit_design::composition::MPrism,
+    columns::{Column, ColumnIndexer},
+    fec::columns::{FECColumn, FECColumnInput, FECColumnInter, FECColumnOutput},
+    serialization::interpreter::{N_LIMBS_LARGE, N_LIMBS_SMALL},
+};
+
+/// Number of blocks in the circuit.
+pub const N_BLOCKS: usize = 6;
+
+/// Defines the height of each block in the IVC circuit.
+/// As described at the top level of this file, the circuit is tiled
+/// vertically, and the height of each block is defined by the actual
+/// application that needs to be run.
+pub fn block_height<const N_COL_TOTAL: usize, const N_CHALS: usize>(block_num: usize) -> usize {
+    match block_num {
+        // The first block is used for the decomposition of the inputs.
+        // As we have 3 inputs (left, right and output), we need 3 times the
+        // total number of columns.
+        0 => 3 * N_COL_TOTAL,
+        // The second block is used for hashing the different values.
+        // We have 2 foreign field elements to hash per commitment.
+        // At the moment, we split each foreign field element into 150-bit
+        // chunks, therefore we need to hash (2 * 2 * 3) (= 12) scalar field
+        // elements per column. As our Poseidon state can absorb 2 elements at
+        // a time, we need to compute 6 hashes per column.
+        // FIXME: why an additional 2?
+        // FIXME: use the encoding of the bit of y instead. It would reduce to 3
+        // hashes instead of 6.
+        // For an EC point (x, y), we can rely on only x and the sign bit of y.
+        // We decompose x into `x_{0..149}` and `x_{150..255}` (only 255 as
+        // the foreign field is 255 bits), and we add the sign bit of `y` into
+        // the value `x_{150..255}`. From there, we can hash the two
+        // values, which fit in our state of size 3.
+        1 => 6 * N_COL_TOTAL + 2,
+        // The third block is used for the randomisation of the MSM.
+        2 => N_COL_TOTAL + 1,
+        // The fourth block is used for the foreign field ECC addition.
+        3 => 35 * N_COL_TOTAL + 5,
+        // The fifth block is used for the challenges.
+        // The total number of challenges is given as a type parameter.
+        4 => N_CHALS,
+        // The sixth block is used for the homogenised value `u`.
+        // We only need one row for this one, as there is a single value to
+        // randomise.
+        5 => 1,
+        _ => panic!("block_height: no block number {block_num:?}"),
+    }
+}
+
+/// Returns the total height of the IVC circuit.
+pub fn total_height<const N_COL_TOTAL: usize, const N_CHALS: usize>() -> usize {
+    (0..N_BLOCKS)
+        .map(block_height::<N_COL_TOTAL, N_CHALS>)
+        .sum()
+}
+
+/// Number of fixed selectors for the IVC circuit.
+pub const N_FSEL_IVC: usize = IVC_POSEIDON_NB_ROUND_CONSTANTS + N_BLOCKS;
+
+// NB: We can reuse hash constants.
+// TODO: Can we pass just one coordinate and sign (x, sign) instead of (x,y) for hashing?
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum IVCColumn {
+    /// A single column containing the folding iteration number,
+    /// starting with 0 for the base case, and non-zero positive for
+    /// the inductive case.
+    FoldIteration,
+
+    /// Selector for blocks. The inner usize is ∈ [0, #blocks).
+    /// The selectors are fixed values for each instance of the relation.
+    /// The different blocks are described below or at the top level of this
+    /// file.
+ BlockSel(usize), + + /// 2*17 15-bit limbs (two base field points) + Block1Input(usize), + /// 2*4 75-bit limbs + Block1InputRepacked75(usize), + /// 2*2 150-bit limbs + Block1InputRepacked150(usize), + + /// 1 hash per row + // FIXME/TODO: we might want to have 2 additional columns for the actual + // input we would like to hash, and must constrain the input state of the + // hash to be the sum of the previous state. + // i.e. for each hash, we have 2 + 3 + 2 columns + // - x, y the inputs we want to absorb + // - s0, s1, s2 the initial state of the hash + // - s0 + x, s1 + y the final state of the hash, with s2 being unused. + // - s0 + x, s1 + y and s2 are considered as the column + // PoseidonColumn::Input of the Poseidon gadget, and a constrain must be + // added between s0 and s0 + x (resp. s1 and s1 + y). + Block2Hash( + PoseidonColumn< + IVC_POSEIDON_STATE_SIZE, + IVC_POSEIDON_NB_FULL_ROUND, + IVC_POSEIDON_NB_PARTIAL_ROUND, + >, + ), + + /// Constant phi + Block3ConstPhi, + /// Constant r + Block3ConstR, + /// Scalar coeff #1, powers of Phi, phi^i + Block3PhiPow, + /// Scalar coeff #2, r * phi^i + Block3PhiPowR, + /// Scalar coeff #2, r^2 * phi^i + Block3PhiPowR2, + /// Scalar coeff #2, r^3 * phi^i + Block3PhiPowR3, + /// 17 15-bit limbs + Block3PhiPowLimbs(usize), + /// 17 15-bit limbs + Block3PhiPowRLimbs(usize), + /// 17 15-bit limbs + Block3PhiPowR2Limbs(usize), + /// 17 15-bit limbs + Block3PhiPowR3Limbs(usize), + + /// 2*4 75-bit limbs + Block4Input1(usize), + /// Coeffifient which is indexing a bucket. Used for both lookups in this row. + Block4Coeff, + /// RAM lookup access time for input 1. + Block4Input2AccessTime, + /// 2*4 75-bit limbs + Block4Input2(usize), + /// EC ADD intermediate wires + Block4ECAddInter(FECColumnInter), + /// 2*17 15-bit limbs + Block4OutputRaw(FECColumnOutput), + // TODO: this might be just Input2AccessTime + 1? + /// RAM lookup access time for output. + Block4OutputAccessTime, + /// 2*4 75-bit limbs + Block4OutputRepacked(usize), + + /// Constant h_r + Block5ConstHr, + /// Constant r + Block5ConstR, + /// α_{L,i} + Block5ChalLeft, + /// α_R^i, where α_R = h_R + Block5ChalRight, + /// α_{O,i} = α_{L,i} + r·α_R^i + Block5ChalOutput, + + /// Constant r + Block6ConstR, + /// u_L + Block6ULeft, + /// u_O = u_L + r + Block6UOutput, +} + +impl ColumnIndexer for IVCColumn { + /// Number of columns used by the IVC circuit + /// It contains at least the columns used for Poseidon. + /// It does not include the additional columns that might be required + /// to reduce to degree 2. + /// FIXME: This can be improved by changing a bit the layer. + /// The reduction to degree 2 should happen in the gadgets to avoid adding + /// extra columns and leave sparse rows. + // We consider IVCPoseidonColumn::N_COL but it should be the maximum of + // the different gadgets/blocks. + // We also add 1 for the FoldIteration column. + const N_COL: usize = IVCPoseidonColumn::N_COL + 1 + N_BLOCKS; + + fn to_column(self) -> Column { + match self { + // We keep a column that will be used for the folding iteration. + // Question: do we need it for all the rows or does it appear only + // for one hash? 
+ IVCColumn::FoldIteration => Column::Relation(0), + IVCColumn::BlockSel(i) => { + assert!(i < N_BLOCKS); + Column::FixedSelector(i) + } + + IVCColumn::Block1Input(i) => { + assert!(i < 2 * N_LIMBS_SMALL); + Column::Relation(i).add_rel_offset(1) + } + IVCColumn::Block1InputRepacked75(i) => { + assert!(i < 2 * N_LIMBS_LARGE); + Column::Relation(2 * N_LIMBS_SMALL + i).add_rel_offset(1) + } + IVCColumn::Block1InputRepacked150(i) => { + assert!(i < 2 * N_LIMBS_XLARGE); + Column::Relation(2 * N_LIMBS_SMALL + 2 * N_LIMBS_LARGE + i).add_rel_offset(1) + } + + // The second block of columns will be used for hashing the + // different values, like the commitments and the challenges. + IVCColumn::Block2Hash(poseidon_col) => match poseidon_col { + // We map the round constants into the appropriate fixed + // selector. + // We suppose that the round constants are added after the + // public values describing the block selection. + PoseidonColumn::RoundConstant(round, state_size) => { + let idx = round * IVC_POSEIDON_STATE_SIZE + state_size; + Column::FixedSelector(idx + N_BLOCKS) + } + // We add a padding for the fold iteration. + // Even though, we might not need it. + other => other.to_column().add_rel_offset(1), + }, + + // The third block is used to compute the randomisation of the MSM + IVCColumn::Block3ConstPhi => Column::Relation(0).add_rel_offset(1), + IVCColumn::Block3ConstR => Column::Relation(1).add_rel_offset(1), + IVCColumn::Block3PhiPow => Column::Relation(2).add_rel_offset(1), + IVCColumn::Block3PhiPowR => Column::Relation(3).add_rel_offset(1), + IVCColumn::Block3PhiPowR2 => Column::Relation(4).add_rel_offset(1), + IVCColumn::Block3PhiPowR3 => Column::Relation(5).add_rel_offset(1), + IVCColumn::Block3PhiPowLimbs(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(6 + i).add_rel_offset(1) + } + IVCColumn::Block3PhiPowRLimbs(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(6 + N_LIMBS_SMALL + i).add_rel_offset(1) + } + IVCColumn::Block3PhiPowR2Limbs(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(6 + 2 * N_LIMBS_SMALL + i).add_rel_offset(1) + } + IVCColumn::Block3PhiPowR3Limbs(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(6 + 3 * N_LIMBS_SMALL + i).add_rel_offset(1) + } + + // The fourth block is used for the foreign field ECC addition + IVCColumn::Block4Input1(i) => { + assert!(i < 2 * N_LIMBS_LARGE); + Column::Relation(i).add_rel_offset(1) + } + IVCColumn::Block4Coeff => Column::Relation(8).add_rel_offset(1), + IVCColumn::Block4Input2AccessTime => Column::Relation(9).add_rel_offset(1), + IVCColumn::Block4Input2(i) => { + assert!(i < 2 * N_LIMBS_LARGE); + Column::Relation(10 + i).add_rel_offset(1) + } + IVCColumn::Block4ECAddInter(fec_inter) => { + fec_inter.to_column().add_rel_offset(18).add_rel_offset(1) + } + IVCColumn::Block4OutputRaw(fec_output) => fec_output + .to_column() + .add_rel_offset(18 + FECColumnInter::N_COL) + .add_rel_offset(1), + IVCColumn::Block4OutputAccessTime => { + Column::Relation(18 + FECColumnInter::N_COL + FECColumnOutput::N_COL) + .add_rel_offset(1) + } + IVCColumn::Block4OutputRepacked(i) => { + assert!(i < 2 * N_LIMBS_LARGE); + Column::Relation(18 + FECColumnInter::N_COL + FECColumnOutput::N_COL + 1 + i) + .add_rel_offset(1) + } + + // The fifth block is used for the challenges. Note that the + // challenges contains all the alphas (used to bundle the + // constraints) and also the challenges used by the PIOP. 
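+            // Concretely, this block materialises, row by row, the relation
+            //   α_{O,i} = α_{L,i} + r·α_R^i
+            // that `constrain_challenges` (in constraints.rs) enforces.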
+ IVCColumn::Block5ConstHr => Column::Relation(0).add_rel_offset(1), + IVCColumn::Block5ConstR => Column::Relation(1).add_rel_offset(1), + IVCColumn::Block5ChalLeft => Column::Relation(2).add_rel_offset(1), + IVCColumn::Block5ChalRight => Column::Relation(3).add_rel_offset(1), + IVCColumn::Block5ChalOutput => Column::Relation(4).add_rel_offset(1), + + // The sixth block is used for the homogeneised value `u`. + IVCColumn::Block6ConstR => Column::Relation(0).add_rel_offset(1), + IVCColumn::Block6ULeft => Column::Relation(1).add_rel_offset(1), + IVCColumn::Block6UOutput => Column::Relation(2).add_rel_offset(1), + } + } +} + +/// Lens used to convert between IVC columns and Poseidon columns. +/// We use one row for each hash, whatever the type of inputs the hash has. +pub struct IVCHashLens {} + +impl MPrism for IVCHashLens { + type Source = IVCColumn; + type Target = IVCPoseidonColumn; + + // [traverse] must map the columns IVCColumn to the corresponding columns of + // the Poseidon circuit. + // We do suppose that the columns are ordered in the same way in the IVC and + // the Poseidon circuit, and that a whole row is used by the Poseidon gadget + fn traverse(&self, source: Self::Source) -> Option { + match source { + IVCColumn::Block2Hash(col) => Some(col), + _ => None, + } + } + + // [re_get] must map back the columns of the Poseidon gadget to the columns + // of the IVC circuit. + // As said for [traverse], a whole row, Block2Hash, is used by the Poseidon + // gadget. + // We map the Poseidon columns in the same order than described in + // [crate::poseidon_8_56_5_3_2::columns::PoseidonColumn], i.e. Input, + // Partialround and FullRound. + fn re_get(&self, target: Self::Target) -> Self::Source { + IVCColumn::Block2Hash(target) + } +} + +pub struct IVCFECLens {} + +impl MPrism for IVCFECLens { + type Source = IVCColumn; + type Target = FECColumn; + + fn traverse(&self, source: Self::Source) -> Option { + match source { + IVCColumn::Block4Input1(i) => { + if i < 4 { + Some(FECColumn::Input(FECColumnInput::XP(i))) + } else { + Some(FECColumn::Input(FECColumnInput::YP(i))) + } + } + IVCColumn::Block4Input2(i) => { + if i < 4 { + Some(FECColumn::Input(FECColumnInput::XQ(i))) + } else { + Some(FECColumn::Input(FECColumnInput::YQ(i))) + } + } + IVCColumn::Block4OutputRaw(output) => Some(FECColumn::Output(output)), + IVCColumn::Block4ECAddInter(inter) => Some(FECColumn::Inter(inter)), + _ => None, + } + } + + fn re_get(&self, target: Self::Target) -> Self::Source { + match target { + FECColumn::Input(FECColumnInput::XP(i)) => IVCColumn::Block4Input1(i), + FECColumn::Input(FECColumnInput::YP(i)) => IVCColumn::Block4Input1(4 + i), + FECColumn::Input(FECColumnInput::XQ(i)) => IVCColumn::Block4Input2(i), + FECColumn::Input(FECColumnInput::YQ(i)) => IVCColumn::Block4Input2(4 + i), + FECColumn::Output(output) => IVCColumn::Block4OutputRaw(output), + FECColumn::Inter(inter) => IVCColumn::Block4ECAddInter(inter), + } + } +} diff --git a/ivc/src/ivc/constraints.rs b/ivc/src/ivc/constraints.rs new file mode 100644 index 0000000000..99ce5a2527 --- /dev/null +++ b/ivc/src/ivc/constraints.rs @@ -0,0 +1,305 @@ +use super::{ + columns::{IVCColumn, IVCHashLens, N_BLOCKS}, + helpers::{combine_large_to_xlarge, combine_small_to_full}, + lookups::{IVCFECLookupLens, IVCLookupTable}, + N_LIMBS_XLARGE, +}; + +use crate::poseidon_8_56_5_3_2::bn254::PoseidonBN254Parameters; + +use crate::{ivc::columns::IVCFECLens, poseidon_8_56_5_3_2}; +use ark_ff::PrimeField; +use kimchi_msm::{ + circuit_design::{ + 
capabilities::read_column_array, + composition::{SubEnvColumn, SubEnvLookup}, + ColAccessCap, HybridCopyCap, LookupCap, + }, + fec::{columns::FECColumnOutput, interpreter::constrain_ec_addition}, + serialization::{ + interpreter::{combine_small_to_large, N_LIMBS_LARGE, N_LIMBS_SMALL}, + lookups as serlookup, + }, + Fp, +}; +use std::marker::PhantomData; + +fn range_check_scalar_limbs( + env: &mut Env, + input_limbs_small: &[Env::Variable; N_LIMBS_SMALL], +) where + F: PrimeField, + Ff: PrimeField, + Env: ColAccessCap + LookupCap>, +{ + for (i, x) in input_limbs_small.iter().enumerate() { + if i % N_LIMBS_SMALL == N_LIMBS_SMALL - 1 { + // If it's the highest limb, we need to check that it's representing a field element. + env.lookup( + IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheckFfHighest( + PhantomData, + )), + vec![x.clone()], + ); + } else { + // TODO Add this lookup. + // env.lookup(IVCLookupTable::RangeCheckFHighest, x); + } + } +} + +fn range_check_small_limbs( + env: &mut Env, + input_limbs_small: &[Env::Variable; N_LIMBS_SMALL], +) where + F: PrimeField, + Ff: PrimeField, + Env: ColAccessCap + LookupCap>, +{ + for (i, x) in input_limbs_small.iter().enumerate() { + if i % N_LIMBS_SMALL == N_LIMBS_SMALL - 1 { + // If it's the highest limb, we need to check that it's representing a field element. + env.lookup( + IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheckFfHighest( + PhantomData, + )), + vec![x.clone()], + ); + } else { + env.lookup( + IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheck15), + vec![x.clone()], + ); + } + } +} + +/// Constraints for the inputs block. +pub fn constrain_inputs(env: &mut Env) +where + F: PrimeField, + Ff: PrimeField, + Env: ColAccessCap + LookupCap>, +{ + let input_limbs_small_x: [_; N_LIMBS_SMALL] = read_column_array(env, IVCColumn::Block1Input); + let input_limbs_small_y: [_; N_LIMBS_SMALL] = + read_column_array(env, |x| IVCColumn::Block1Input(N_LIMBS_SMALL + x)); + // Range checks on 15 bits + { + range_check_small_limbs::(env, &input_limbs_small_x); + range_check_small_limbs::(env, &input_limbs_small_y); + } + + let input_limbs_large_x: [_; N_LIMBS_LARGE] = + read_column_array(env, IVCColumn::Block1InputRepacked75); + let input_limbs_large_y: [_; N_LIMBS_LARGE] = + read_column_array(env, |x| IVCColumn::Block1InputRepacked75(N_LIMBS_LARGE + x)); + + // Repacking to 75 bits + { + let input_limbs_large_x_expected = + combine_small_to_large::<_, _, Env>(input_limbs_small_x.clone()); + let input_limbs_large_y_expected = + combine_small_to_large::<_, _, Env>(input_limbs_small_y.clone()); + input_limbs_large_x_expected + .into_iter() + .zip(input_limbs_large_x.clone()) + .for_each(|(e1, e2)| env.assert_zero(e1 - e2)); + input_limbs_large_y_expected + .into_iter() + .zip(input_limbs_large_y.clone()) + .for_each(|(e1, e2)| env.assert_zero(e1 - e2)); + } + + let input_limbs_xlarge_x: [_; N_LIMBS_XLARGE] = + read_column_array(env, IVCColumn::Block1InputRepacked150); + let input_limbs_xlarge_y: [_; N_LIMBS_XLARGE] = read_column_array(env, |x| { + IVCColumn::Block1InputRepacked150(N_LIMBS_XLARGE + x) + }); + + // Repacking to 150 bits + { + let input_limbs_xlarge_x_expected = + combine_large_to_xlarge::<_, _, Env>(input_limbs_large_x.clone()); + let input_limbs_xlarge_y_expected = + combine_large_to_xlarge::<_, _, Env>(input_limbs_large_y.clone()); + input_limbs_xlarge_x_expected + .into_iter() + .zip(input_limbs_xlarge_x.clone()) + .for_each(|(e1, e2)| env.assert_zero(e1 - e2)); + 
input_limbs_xlarge_y_expected + .into_iter() + .zip(input_limbs_xlarge_y.clone()) + .for_each(|(e1, e2)| env.assert_zero(e1 - e2)); + } +} + +pub fn constrain_u(env: &mut Env) +where + F: PrimeField, + Env: ColAccessCap, +{ + // TODO constrain that r is read from the "hashes" block. + // TODO constrain that the inputs are corresponding to public input (?). + + let r = env.read_column(IVCColumn::Block6ConstR); + let u_l = env.read_column(IVCColumn::Block6ULeft); + let u_o = env.read_column(IVCColumn::Block6UOutput); + env.assert_zero(u_o - u_l - r); +} + +pub fn constrain_challenges(env: &mut Env) +where + F: PrimeField, + Env: ColAccessCap, +{ + let _h_r = env.read_column(IVCColumn::Block5ConstHr); + + let r = env.read_column(IVCColumn::Block5ConstR); + let alpha_l = env.read_column(IVCColumn::Block5ChalLeft); + let alpha_r = env.read_column(IVCColumn::Block5ChalRight); + let alpha_o = env.read_column(IVCColumn::Block5ChalOutput); + env.assert_zero(alpha_o - alpha_l - r * alpha_r); + + // TODO constrain that α_l are public inputs + // TODO constrain that α_{r,i} = α_{r,i-1} * h_R + // TODO constrain that α_{r,1} = h_R (from hash table) +} + +pub fn constrain_scalars(env: &mut Env) +where + F: PrimeField, + Ff: PrimeField, + Env: ColAccessCap + LookupCap>, +{ + let _phi = env.read_column(IVCColumn::Block3ConstPhi); + let r = env.read_column(IVCColumn::Block3ConstR); + let phi_i = env.read_column(IVCColumn::Block3PhiPow); + let phi_i_r = env.read_column(IVCColumn::Block3PhiPowR); + let phi_pow_limbs: [_; N_LIMBS_SMALL] = read_column_array(env, IVCColumn::Block3PhiPowLimbs); + let phi_pow_r_limbs: [_; N_LIMBS_SMALL] = read_column_array(env, IVCColumn::Block3PhiPowRLimbs); + + let phi_pow_expected = combine_small_to_full::<_, _, Env>(phi_pow_limbs.clone()); + let phi_pow_r_expected = combine_small_to_full::<_, _, Env>(phi_pow_r_limbs.clone()); + + { + range_check_scalar_limbs::(env, &phi_pow_limbs); + range_check_scalar_limbs::(env, &phi_pow_r_limbs); + } + + // TODO Add expression asserting data with the next row. E.g. 
+    // let phi_i_next = env.read_column(IVCColumn::Block3ConstR)
+    // env.assert_zero(phi_i_next - phi_i * phi)
+    env.assert_zero(phi_i_r.clone() - phi_i.clone() * r.clone());
+    env.assert_zero(phi_pow_expected - phi_i);
+    env.assert_zero(phi_pow_r_expected - phi_i_r);
+}
+
+pub fn constrain_ecadds<F, Ff, Env>(env: &mut Env)
+where
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, IVCColumn> + LookupCap<F, IVCColumn, IVCLookupTable<Ff>>,
+{
+    constrain_ec_addition::<F, Ff, _>(&mut SubEnvLookup::new(
+        &mut SubEnvColumn::new(env, IVCFECLens {}),
+        IVCFECLookupLens(PhantomData),
+    ));
+
+    // Repacking to 75 bits
+
+    let output_limbs_small_x: [_; N_LIMBS_SMALL] =
+        read_column_array(env, |i| IVCColumn::Block4OutputRaw(FECColumnOutput::XR(i)));
+    let output_limbs_small_y: [_; N_LIMBS_SMALL] =
+        read_column_array(env, |i| IVCColumn::Block4OutputRaw(FECColumnOutput::YR(i)));
+
+    let output_limbs_large_x: [_; N_LIMBS_LARGE] =
+        read_column_array(env, IVCColumn::Block4OutputRepacked);
+    let output_limbs_large_y: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| IVCColumn::Block4OutputRepacked(N_LIMBS_LARGE + i));
+
+    {
+        let output_limbs_large_x_expected =
+            combine_small_to_large::<_, _, Env>(output_limbs_small_x);
+        let output_limbs_large_y_expected =
+            combine_small_to_large::<_, _, Env>(output_limbs_small_y);
+        output_limbs_large_x_expected
+            .into_iter()
+            .zip(output_limbs_large_x.clone())
+            .for_each(|(e1, e2)| env.assert_zero(e1 - e2));
+        output_limbs_large_y_expected
+            .into_iter()
+            .zip(output_limbs_large_y.clone())
+            .for_each(|(e1, e2)| env.assert_zero(e1 - e2));
+    }
+}
+
+// We might not need to constrain selectors to be 0 or 1 if selectors
+// are public values that can be verified directly by the verifier.
+// However, we might need these constraints in folding, where the public
+// input needs to be checked.
+/// This function generates the constraints for the block selectors.
+pub fn constrain_selectors<F, Env>(env: &mut Env)
+where
+    F: PrimeField,
+    Env: ColAccessCap<F, IVCColumn>,
+{
+    for i in 0..N_BLOCKS {
+        // Each selector must have value either 0 or 1.
+        let sel = env.read_column(IVCColumn::BlockSel(i));
+        env.assert_zero(sel.clone() * (sel.clone() - Env::constant(F::one())));
+    }
+}
+
+/// This function generates the constraints for the whole IVC circuit.
+pub fn constrain_ivc<F, Ff, Env>(env: &mut Env)
+where
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, IVCColumn>
+        + LookupCap<F, IVCColumn, IVCLookupTable<Ff>>
+        + HybridCopyCap<F, IVCColumn>,
+{
+    constrain_selectors(env);
+
+    // The code below calls the constraint methods, which internally record
+    // the constraints for the corresponding blocks. Before each call,
+    // we prefix the constraint with `selector(block_num)*` so that
+    // the constraints created in block `block_num` have the form
+    // `selector(block_num) * C(X)` and not just `C(X)`.
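+    // For example, every constraint C(X) recorded by `constrain_inputs`
+    // below ends up in the form
+    //   fold_iteration * BlockSel(0) * C(X)
+    // because of the assert mapper installed right before the call.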
+ + let fold_iteration = env.read_column(IVCColumn::FoldIteration); + let s0 = env.read_column(IVCColumn::BlockSel(0)); + env.set_assert_mapper(Box::new(move |x| fold_iteration.clone() * s0.clone() * x)); + constrain_inputs(env); + + let fold_iteration = env.read_column(IVCColumn::FoldIteration); + + let s1 = env.read_column(IVCColumn::BlockSel(1)); + env.set_assert_mapper(Box::new(move |x| fold_iteration.clone() * s1.clone() * x)); + { + let mut env = SubEnvColumn::new(env, IVCHashLens {}); + poseidon_8_56_5_3_2::interpreter::apply_permutation(&mut env, &PoseidonBN254Parameters); + } + + let fold_iteration = env.read_column(IVCColumn::FoldIteration); + let s2 = env.read_column(IVCColumn::BlockSel(2)); + env.set_assert_mapper(Box::new(move |x| fold_iteration.clone() * s2.clone() * x)); + constrain_scalars(env); + + let fold_iteration = env.read_column(IVCColumn::FoldIteration); + let s3 = env.read_column(IVCColumn::BlockSel(3)); + env.set_assert_mapper(Box::new(move |x| fold_iteration.clone() * s3.clone() * x)); + constrain_ecadds(env); + + let fold_iteration = env.read_column(IVCColumn::FoldIteration); + let s4 = env.read_column(IVCColumn::BlockSel(4)); + env.set_assert_mapper(Box::new(move |x| fold_iteration.clone() * s4.clone() * x)); + constrain_challenges(env); + + let fold_iteration = env.read_column(IVCColumn::FoldIteration); + let s5 = env.read_column(IVCColumn::BlockSel(5)); + env.set_assert_mapper(Box::new(move |x| fold_iteration.clone() * s5.clone() * x)); + constrain_u(env); + + env.set_assert_mapper(Box::new(move |x| x)); +} diff --git a/ivc/src/ivc/helpers.rs b/ivc/src/ivc/helpers.rs new file mode 100644 index 0000000000..4507833bb7 --- /dev/null +++ b/ivc/src/ivc/helpers.rs @@ -0,0 +1,45 @@ +use ark_ff::PrimeField; +use kimchi_msm::{ + circuit_design::ColAccessCap, + columns::ColumnIndexer, + serialization::interpreter::{ + combine_limbs_m_to_n, LIMB_BITSIZE_LARGE, LIMB_BITSIZE_SMALL, N_LIMBS_LARGE, N_LIMBS_SMALL, + }, +}; + +use super::{LIMB_BITSIZE_XLARGE, N_LIMBS_XLARGE}; + +/// Helper. Combines small limbs into big limbs. +pub fn combine_large_to_xlarge>( + x: [Env::Variable; N_LIMBS_LARGE], +) -> [Env::Variable; N_LIMBS_XLARGE] { + combine_limbs_m_to_n::< + N_LIMBS_LARGE, + { N_LIMBS_XLARGE }, + LIMB_BITSIZE_LARGE, + { LIMB_BITSIZE_XLARGE }, + F, + Env::Variable, + _, + >(|f| Env::constant(f), x) +} + +/// Helper. Combines 17x15bit limbs into 1 native field element. +pub fn combine_small_to_full>( + x: [Env::Variable; N_LIMBS_SMALL], +) -> Env::Variable { + let [res] = + combine_limbs_m_to_n::( + |f| Env::constant(f), + x, + ); + res +} + +// TODO double-check it works +/// Helper. Combines large limbs into one element. Computation is over the field. +pub fn combine_large_to_full_field(x: [Ff; N_LIMBS_LARGE]) -> Ff { + let [res] = + combine_limbs_m_to_n::(|f| f, x); + res +} diff --git a/ivc/src/ivc/interpreter.rs b/ivc/src/ivc/interpreter.rs new file mode 100644 index 0000000000..a1d7784d20 --- /dev/null +++ b/ivc/src/ivc/interpreter.rs @@ -0,0 +1,937 @@ +// Interpreter for IVC circuit (for folding). 
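+//
+// Rough shape of the witness-building flow implemented in this file:
+// `process_inputs` decomposes the left/right/output commitments into the
+// limb columns of the first block, `process_hashes` absorbs them with the
+// Poseidon gadget to derive the challenges `r` and `phi`, and
+// `write_scalars_row` lays out the powers `phi^i`, `r*phi^i`, `r^2*phi^i`
+// and `r^3*phi^i` consumed by the EC-addition block.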
+ +// Imports only used for docs +#[cfg(doc)] +use crate::poseidon_8_56_5_3_2::bn254::{ + NB_COLUMNS as IVC_POSEIDON_NB_COLUMNS, NB_CONSTRAINTS as IVC_POSEIDON_NB_CONSTRAINTS, +}; + +use crate::{ + ivc::{ + columns::{block_height, total_height, IVCColumn, IVCFECLens, IVCHashLens, N_BLOCKS}, + constraints::{ + constrain_challenges, constrain_ecadds, constrain_inputs, constrain_scalars, + constrain_u, + }, + lookups::{IVCFECLookupLens, IVCLookupTable}, + }, + poseidon_8_56_5_3_2::{ + bn254::{ + PoseidonBN254Parameters, NB_TOTAL_ROUND as IVC_POSEIDON_NB_TOTAL_ROUND, + STATE_SIZE as IVC_POSEIDON_STATE_SIZE, + }, + interpreter::{poseidon_circuit, PoseidonParams}, + }, +}; +use ark_ff::PrimeField; +use kimchi_msm::{ + circuit_design::{ + capabilities::write_column_array_const, + composition::{SubEnvColumn, SubEnvLookup}, + ColWriteCap, DirectWitnessCap, HybridCopyCap, LookupCap, MultiRowReadCap, + }, + fec::interpreter::ec_add_circuit, + serialization::interpreter::{ + limb_decompose_ff, LIMB_BITSIZE_LARGE, LIMB_BITSIZE_SMALL, N_LIMBS_LARGE, N_LIMBS_SMALL, + }, +}; +use num_bigint::BigUint; +use std::marker::PhantomData; + +use super::{ + columns::N_FSEL_IVC, helpers::combine_large_to_full_field, LIMB_BITSIZE_XLARGE, N_LIMBS_XLARGE, +}; + +pub fn write_inputs_row( + env: &mut Env, + target_comms: &[(Ff, Ff); N_COL_TOTAL], + row_num_local: usize, +) -> (Vec, Vec, Vec) +where + F: PrimeField, + Ff: PrimeField, + Env: ColWriteCap, +{ + let small_limbs_1 = limb_decompose_ff::( + &target_comms[row_num_local].0, + ); + let small_limbs_2 = limb_decompose_ff::( + &target_comms[row_num_local].1, + ); + let small_limbs_1_2: [_; 2 * N_LIMBS_SMALL] = small_limbs_1 + .into_iter() + .chain(small_limbs_2) + .collect::>() + .try_into() + .unwrap(); + let large_limbs_1 = limb_decompose_ff::( + &target_comms[row_num_local].0, + ); + let large_limbs_2 = limb_decompose_ff::( + &target_comms[row_num_local].1, + ); + let large_limbs_1_2: [_; 2 * N_LIMBS_LARGE] = large_limbs_1 + .into_iter() + .chain(large_limbs_2) + .collect::>() + .try_into() + .unwrap(); + let xlarge_limbs_1 = limb_decompose_ff::( + &target_comms[row_num_local].0, + ); + let xlarge_limbs_2 = limb_decompose_ff::( + &target_comms[row_num_local].1, + ); + let xlarge_limbs_1_2: [_; 2 * N_LIMBS_XLARGE] = xlarge_limbs_1 + .into_iter() + .chain(xlarge_limbs_2) + .collect::>() + .try_into() + .unwrap(); + + write_column_array_const(env, &small_limbs_1_2, IVCColumn::Block1Input); + write_column_array_const(env, &large_limbs_1_2, IVCColumn::Block1InputRepacked75); + write_column_array_const(env, &xlarge_limbs_1_2, IVCColumn::Block1InputRepacked150); + + ( + small_limbs_1_2.to_vec(), + large_limbs_1_2.to_vec(), + xlarge_limbs_1_2.to_vec(), + ) +} + +/// `comms` is lefts, rights, and outs. Returns the packed commitments +/// in three different representations. 
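+/// For each commitment these are: 2*17 15-bit limbs, 2*4 75-bit limbs and
+/// 2*2 150-bit limbs (one set per coordinate), matching the Block1 columns
+/// described in [crate::ivc::columns].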
+#[allow(clippy::type_complexity)] +pub fn process_inputs( + env: &mut Env, + fold_iteration: usize, + comms: [Box<[(Ff, Ff); N_COL_TOTAL]>; 3], +) -> ( + Box<[[[F; 2 * N_LIMBS_SMALL]; N_COL_TOTAL]; 3]>, + Box<[[[F; 2 * N_LIMBS_LARGE]; N_COL_TOTAL]; 3]>, + Box<[[[F; 2 * N_LIMBS_XLARGE]; N_COL_TOTAL]; 3]>, +) +where + F: PrimeField, + Ff: PrimeField, + Env: MultiRowReadCap + LookupCap>, +{ + let mut comms_limbs_s: [Vec>; 3] = std::array::from_fn(|_| vec![]); + let mut comms_limbs_l: [Vec>; 3] = std::array::from_fn(|_| vec![]); + let mut comms_limbs_xl: [Vec>; 3] = std::array::from_fn(|_| vec![]); + + for _block_row_i in 0..block_height::(0) { + let row_num = env.curr_row(); + + env.write_column( + IVCColumn::FoldIteration, + &Env::constant(F::from(fold_iteration as u64)), + ); + + let (target_comms, row_num_local, comtype) = if row_num < N_COL_TOTAL { + (&comms[0], row_num, 0) + } else if row_num < 2 * N_COL_TOTAL { + (&comms[1], row_num - N_COL_TOTAL, 1) + } else { + (&comms[2], row_num - 2 * N_COL_TOTAL, 2) + }; + + let (limbs_small, limbs_large, limbs_xlarge) = + write_inputs_row(env, target_comms, row_num_local); + + comms_limbs_s[comtype].push(limbs_small); + comms_limbs_l[comtype].push(limbs_large); + comms_limbs_xl[comtype].push(limbs_xlarge); + + constrain_inputs(env); + + env.next_row(); + } + + ( + o1_utils::array::vec_to_boxed_array3(comms_limbs_s.to_vec()), + o1_utils::array::vec_to_boxed_array3(comms_limbs_l.to_vec()), + o1_utils::array::vec_to_boxed_array3(comms_limbs_xl.to_vec()), + ) +} + +// TODO We need to have alpha +// TODO We need to hash i (or i+1)? +// TODO We need to hash T_0 and T_1? +// FIXME: we do not initialize correctly the sponge state +// FIXME: when starting a new row, we do only use the output of the previous +// hash state. We might want to use the whole state and compute (s0 + i0, s1 + +// i1, s2). See comments in [crate::ivc::ivc::columns] +/// Instantiates the IVC circuit for folding. +/// `N_COL_TOTAL` is the total number of columns required by the IVC circuit + +/// the application. +/// +/// The input `comms_xlarge` contains the commitments to the left, right and +/// output instances (hence the outer size of 3). Each nested `N_COL_TOTAL` +/// contains the 4 limbs of 150 bits encoding the two coordinates of a +/// commitment. +/// +/// For instance, if there are 2 columns, the parameter `comms_xlarge` will be 6 +/// elliptic curve points `C_L_1, C_R_1, C_O_1, C_L_2, C_R_2, C_O_2`. +/// Each elliptic curve point is represented in affine coordinates by two values +/// `(x, y)` in the base field of the curve. We split each coordinate into 2 +/// chunks of 150 bits. 
+///
+/// We have therefore 12 scalar field elements, as follows:
+/// ```text
+///        C_L_1                           C_L_2
+/// (x_l_1_0_150, x_l_1_150_255) | (x_l_2_0_150, x_l_2_150_255) |
+/// (y_l_1_0_150, y_l_1_150_255) | (y_l_2_0_150, y_l_2_150_255) |
+///        C_R_1                           C_R_2
+/// (x_r_1_0_150, x_r_1_150_255) | (x_r_2_0_150, x_r_2_150_255) |
+/// (y_r_1_0_150, y_r_1_150_255) | (y_r_2_0_150, y_r_2_150_255) |
+///        C_O_1                           C_O_2
+/// (x_o_1_0_150, x_o_1_150_255) | (x_o_2_0_150, x_o_2_150_255) |
+/// (y_o_1_0_150, y_o_1_150_255) | (y_o_2_0_150, y_o_2_150_255) |
+/// ```
+///
+/// We will absorb in the following order:
+/// ```text
+/// ---- Left commitments ---
+/// - `(x_l_1_0_150, x_l_1_150_255)`
+/// - `(y_l_1_0_150, y_l_1_150_255)`
+/// - `(x_l_2_0_150, x_l_2_150_255)`
+/// - `(y_l_2_0_150, y_l_2_150_255)`
+/// ---- Right commitments ---
+/// - `(x_r_1_0_150, x_r_1_150_255)`
+/// - `(y_r_1_0_150, y_r_1_150_255)`
+/// - `(x_r_2_0_150, x_r_2_150_255)`
+/// - `(y_r_2_0_150, y_r_2_150_255)`
+/// ---- Output commitments ---
+/// - `(x_o_1_0_150, x_o_1_150_255)`
+/// - `(y_o_1_0_150, y_o_1_150_255)`
+/// - `(x_o_2_0_150, x_o_2_150_255)`
+/// - `(y_o_2_0_150, y_o_2_150_255)`
+/// ```
+/// For this:
+/// 1. we get the previous state of the sponge `(s0, s1, s2)`, and we copy it
+///    into the first three columns of the row.
+/// 2. we constrain the next two columns to be equal to
+///    `(s0 + i1, s1 + i2)`. The columns 3, 4 and 5 (1-based) are then equal to
+///    `(s2, s0 + i1, s1 + i2)`. These will be considered as the input of the
+///    Poseidon gadget.
+/// This function therefore requires 2 + [IVC_POSEIDON_NB_COLUMNS] columns,
+/// without counting the block selector.
+/// It also introduces [IVC_POSEIDON_NB_CONSTRAINTS] + 2 constraints, and
+/// therefore requires [IVC_POSEIDON_NB_CONSTRAINTS] + 2 alphas.
+pub fn process_hashes<F, PParams, Env, const N_COL_TOTAL: usize, const N_CHALS: usize>(
+    env: &mut Env,
+    fold_iteration: usize,
+    poseidon_params: &PParams,
+    comms_xlarge: &[[[F; 2 * N_LIMBS_XLARGE]; N_COL_TOTAL]; 3],
+) -> (Env::Variable, Env::Variable, Env::Variable)
+where
+    F: PrimeField,
+    PParams: PoseidonParams<F, IVC_POSEIDON_STATE_SIZE, IVC_POSEIDON_NB_TOTAL_ROUND>,
+    Env: MultiRowReadCap<F, IVCColumn> + HybridCopyCap<F, IVCColumn>,
+{
+    // These should be some proper seeds. Passed from the outside
+    let sponge_l_init: F = F::zero();
+    let sponge_r_init: F = F::zero();
+    let sponge_o_init: F = F::zero();
+    let sponge_lr_init: F = F::zero();
+    let sponge_lro_init: F = F::zero();
+
+    let mut prev_hash_output = Env::constant(F::zero());
+    let mut hash_l = Env::constant(F::zero());
+    let mut hash_r = Env::constant(F::zero());
+    let mut hash_o = Env::constant(F::zero());
+    let mut r = Env::constant(F::zero());
+    let mut phi = Env::constant(F::zero());
+
+    // Relative position in the hashing block
+    for block_row_i in 0..block_height::<N_COL_TOTAL, N_CHALS>(1) {
+        env.write_column(
+            IVCColumn::FoldIteration,
+            &Env::constant(F::from(fold_iteration as u64)),
+        );
+
+        // On the first 6 * N_COL_TOTAL rows, we process the commitments,
+        // computing h_l, h_r, h_o independently.
+        if block_row_i < 6 * N_COL_TOTAL {
+            // Left, right, or output:
+            // we process first all left, after that all right, and after that
+            // all output. Can be 0 (left), 1 (right) or 2 (output).
+            let comm_type = block_row_i / (2 * N_COL_TOTAL);
+            // The commitment we target. Commitment i is processed in hash rows
+            // 2*i and 2*i+1.
+ // With our example above, we have: + // block_row_i = 0 or 1 -> 0 + // block_row_i = 2 or 3 -> 1 + // block_row_i = 4 or 5 -> 0 + // block_row_i = 6 or 7 -> 1 + // block_row_i = 8 or 9 -> 0 + // block_row_i = 10 or 11 -> 1 + let comm_i = (block_row_i % (2 * N_COL_TOTAL)) / 2; + + // Selecting the coordinate of the elliptic curve point, x or y + let (input1, input2) = if block_row_i % 2 == 0 { + ( + // x_[0..150] + comms_xlarge[comm_type][comm_i][0], + // x_[150..255] + comms_xlarge[comm_type][comm_i][1], + ) + } else { + ( + // y_[0..150] + comms_xlarge[comm_type][comm_i][2], + // y_[150..255] + comms_xlarge[comm_type][comm_i][3], + ) + }; + + // FIXME: we want to do s0 + input1, s1 + input2, s3 + // where s0, s1, s3 is the previous hash state. + let input3 = if block_row_i == 0 { + Env::constant(sponge_l_init) + } else if block_row_i == 2 * N_COL_TOTAL { + Env::constant(sponge_r_init) + } else if block_row_i == 4 * N_COL_TOTAL { + Env::constant(sponge_o_init) + } else { + prev_hash_output.clone() + }; + + // Run the actual computation. We keep the last output. + // FIXME: we want to keep the whole state for the next call. + let [_, _, output] = poseidon_circuit( + &mut SubEnvColumn::new(env, IVCHashLens {}), + poseidon_params, + [Env::constant(input1), Env::constant(input2), input3], + ); + + if block_row_i == 2 * N_COL_TOTAL { + // TODO we must somehow assert this hash_l is part of + // the "right strict instance". This is H_i in Nova. + hash_l = output; + } else if block_row_i == 4 * N_COL_TOTAL { + hash_r = output; + } else if block_row_i == 6 * N_COL_TOTAL { + // TODO we must somehow assert this hash_o is part of + // the "output relaxed instance". This is H_{i+1} in Nova. + // This one should be in the public input? + hash_o = output; + } else { + prev_hash_output = output; + } + } else if block_row_i == 6 * N_COL_TOTAL { + // Computing r + let [_, _, r_res] = poseidon_circuit( + &mut SubEnvColumn::new(env, IVCHashLens {}), + poseidon_params, + [ + hash_l.clone(), + hash_r.clone(), + Env::constant(sponge_lr_init), + ], + ); + r = r_res; + } else if block_row_i == 6 * N_COL_TOTAL + 1 { + // Computing phi + let [_, _, phi_res] = poseidon_circuit( + &mut SubEnvColumn::new(env, IVCHashLens {}), + poseidon_params, + [hash_o.clone(), r.clone(), Env::constant(sponge_lro_init)], + ); + phi = phi_res; + } + + env.next_row(); + } + + (hash_r, r, phi) +} + +pub fn write_scalars_row( + env: &mut Env, + r_f: F, + phi_f: F, + phi_prev_power_f: F, +) -> ( + F, + [F; N_LIMBS_SMALL], + [F; N_LIMBS_SMALL], + [F; N_LIMBS_SMALL], + [F; N_LIMBS_SMALL], +) +where + F: PrimeField, + Env: ColWriteCap, +{ + let phi_curr_power_f = phi_prev_power_f * phi_f; + let phi_curr_power_r_f = phi_prev_power_f * phi_f * r_f; + let phi_curr_power_r2_f = phi_prev_power_f * phi_f * r_f * r_f; + let phi_curr_power_r3_f = phi_prev_power_f * phi_f * r_f * r_f * r_f; + + env.write_column(IVCColumn::Block3ConstPhi, &Env::constant(phi_f)); + env.write_column(IVCColumn::Block3ConstR, &Env::constant(r_f)); + env.write_column(IVCColumn::Block3PhiPow, &Env::constant(phi_curr_power_f)); + env.write_column(IVCColumn::Block3PhiPowR, &Env::constant(phi_curr_power_r_f)); + env.write_column( + IVCColumn::Block3PhiPowR2, + &Env::constant(phi_curr_power_r2_f), + ); + env.write_column( + IVCColumn::Block3PhiPowR3, + &Env::constant(phi_curr_power_r3_f), + ); + + // TODO check that limb_decompose_ff works with + let phi_curr_power_f_limbs = + limb_decompose_ff::(&phi_curr_power_f); + write_column_array_const(env, &phi_curr_power_f_limbs, 
IVCColumn::Block3PhiPowLimbs); + + let phi_curr_power_r_f_limbs = + limb_decompose_ff::(&phi_curr_power_r_f); + write_column_array_const( + env, + &phi_curr_power_r_f_limbs, + IVCColumn::Block3PhiPowRLimbs, + ); + + let phi_curr_power_r2_f_limbs = + limb_decompose_ff::(&phi_curr_power_r2_f); + write_column_array_const( + env, + &phi_curr_power_r2_f_limbs, + IVCColumn::Block3PhiPowR2Limbs, + ); + + let phi_curr_power_r3_f_limbs = + limb_decompose_ff::(&phi_curr_power_r3_f); + write_column_array_const( + env, + &phi_curr_power_r3_f_limbs, + IVCColumn::Block3PhiPowR3Limbs, + ); + + ( + phi_curr_power_f, + phi_curr_power_f_limbs, + phi_curr_power_r_f_limbs, + phi_curr_power_r2_f_limbs, + phi_curr_power_r3_f_limbs, + ) +} + +/// Contains vectors of scalars in small limb representations. +/// Generic consts don't allow +1, so vectors not arrays. `N` is +/// `N_COL_TOTAL`. +pub struct ScalarLimbs { + /// ϕ^i, i ∈ [N+1] + pub phi_limbs: Vec<[F; N_LIMBS_SMALL]>, + /// r·ϕ^i, i ∈ [N+1] + pub phi_r_limbs: Vec<[F; N_LIMBS_SMALL]>, + /// r^2·ϕ^{N+1} + pub phi_np1_r2_limbs: [F; N_LIMBS_SMALL], + /// r^3·ϕ^{N+1} + pub phi_np1_r3_limbs: [F; N_LIMBS_SMALL], +} + +/// Processes scalars. Returns a vector of limbs of (powers of) scalars produced. +pub fn process_scalars( + env: &mut Env, + fold_iteration: usize, + r: F, + phi: F, +) -> ScalarLimbs +where + F: PrimeField, + Ff: PrimeField, + Env: DirectWitnessCap + LookupCap>, +{ + let mut phi_prev_power_f = F::one(); + let mut phi_limbs = vec![]; + let mut phi_r_limbs = vec![]; + + // FIXME constrain these two, they are not in circuit yet + let mut phi_np1_r2_limbs = [F::zero(); N_LIMBS_SMALL]; + let mut phi_np1_r3_limbs = [F::zero(); N_LIMBS_SMALL]; + + for block_row_i in 0..block_height::(2) { + env.write_column( + IVCColumn::FoldIteration, + &Env::constant(F::from(fold_iteration as u64)), + ); + + let ( + phi_prev_power_f_new, + phi_curr_power_f_limbs, + phi_curr_power_r_f_limbs, + phi_curr_power_r2_f_limbs, + phi_curr_power_r3_f_limbs, + ) = write_scalars_row(env, r, phi, phi_prev_power_f); + + phi_prev_power_f = phi_prev_power_f_new; + phi_limbs.push(phi_curr_power_f_limbs); + phi_r_limbs.push(phi_curr_power_r_f_limbs); + + if block_row_i == N_COL_TOTAL + 1 { + phi_np1_r2_limbs = phi_curr_power_r2_f_limbs; + phi_np1_r3_limbs = phi_curr_power_r3_f_limbs; + } + + // Checking our constraints + constrain_scalars(env); + + env.next_row(); + } + + ScalarLimbs { + phi_limbs, + phi_r_limbs, + phi_np1_r2_limbs, + phi_np1_r3_limbs, + } +} + +pub fn process_ecadds( + env: &mut Env, + fold_iteration: usize, + scalar_limbs: ScalarLimbs, + comms_large: &[[[F; 2 * N_LIMBS_LARGE]; N_COL_TOTAL]; 3], + error_terms: [(Ff, Ff); 3], // E_L, E_R, E_O + t_terms: [(Ff, Ff); 2], // T_0, T_1 +) where + F: PrimeField, + Ff: PrimeField, + Env: DirectWitnessCap + LookupCap>, +{ + // Compute error and t terms limbs. + let error_terms_large: Box<[[F; 2 * N_LIMBS_LARGE]; 3]> = o1_utils::array::vec_to_boxed_array2( + error_terms + .iter() + .map(|(x, y)| { + limb_decompose_ff::(x) + .into_iter() + .chain(limb_decompose_ff::(y)) + .collect() + }) + .collect(), + ); + let t_terms_large: Box<[[F; 2 * N_LIMBS_LARGE]; 2]> = o1_utils::array::vec_to_boxed_array2( + t_terms + .iter() + .map(|(x, y)| { + limb_decompose_ff::(x) + .into_iter() + .chain(limb_decompose_ff::(y)) + .collect() + }) + .collect(), + ); + + let stub_bucket = { + // FIXME This is a STUB right now it uses randomly generated points (not even on curve) + // Must use bucket input which is looked up. 
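+        // (Sketch of the intended behaviour, following the MSM RFC variant
+        // mentioned in the crate documentation: the second input point of
+        // each FEC addition should be a bucket point fetched from a runtime
+        // lookup table, addressed by the 15-bit coefficient limb of the
+        // current row, instead of the random stub sampled below.)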
+
+        let mut rng = rand::thread_rng();
+        let stub_x = <Ff as ark_ff::UniformRand>::rand(&mut rng);
+        let stub_y = <Ff as ark_ff::UniformRand>::rand(&mut rng);
+        let stub_x_large: [F; N_LIMBS_LARGE] =
+            limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&stub_x);
+        let stub_y_large: [F; N_LIMBS_LARGE] =
+            limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&stub_y);
+        (stub_x_large, stub_y_large)
+    };
+
+    // TODO FIXME STUBBED. multiply by r. For now these are just C_{R,i}, they must be {r * C_{R,i}}
+    //let r_hat_large: Box<[[F; 2 * N_LIMBS_LARGE]; N_COL_TOTAL]> = Box::new(comms_large[1]);
+    let r_hat_large: Box<[[F; 2 * N_LIMBS_LARGE]; N_COL_TOTAL]> = {
+        let mut rng = rand::thread_rng();
+        let r_hat_x = <Ff as ark_ff::UniformRand>::rand(&mut rng);
+        let r_hat_y = <Ff as ark_ff::UniformRand>::rand(&mut rng);
+        let r_hat_x_large: [F; N_LIMBS_LARGE] =
+            limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&r_hat_x);
+        let r_hat_y_large: [F; N_LIMBS_LARGE] =
+            limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&r_hat_y);
+        let decomposition: [F; 2 * N_LIMBS_LARGE] = r_hat_x_large
+            .into_iter()
+            .chain(r_hat_y_large)
+            .collect::<Vec<_>>()
+            .try_into()
+            .unwrap();
+        o1_utils::array::vec_to_boxed_array(vec![decomposition; N_COL_TOTAL])
+    };
+
+    // E_R' = r·T_0 + r^2·T_1 + r^3·E_R
+    // FIXME for now stubbed and just equal to E_L
+    let error_term_rprime_large: [F; 2 * N_LIMBS_LARGE] = error_terms_large[0];
+
+    for block_row_i in 0..block_height::<N_COL_TOTAL, N_CHALS>(3) {
+        env.write_column(
+            IVCColumn::FoldIteration,
+            &Env::constant(F::from(fold_iteration as u64)),
+        );
+
+        // Number of the commitment we're processing, ∈ [N]
+        let com_i = block_row_i % N_COL_TOTAL;
+        // Coefficient limb we're processing for C_L/C_R/C_O, ∈ [k = 17]
+        let coeff_num_1 = (block_row_i / N_COL_TOTAL) % N_LIMBS_SMALL;
+        // Coefficient limb we're processing for error terms, ∈ [k = 17]
+        let coeff_num_2 = if block_row_i >= 35 * N_COL_TOTAL {
+            (block_row_i - 35 * N_COL_TOTAL) % N_LIMBS_SMALL
+        } else {
+            0
+        };
+
+        // First FEC input point, P.
+        let (xp_limbs, yp_limbs, coeff) = if block_row_i < 17 * N_COL_TOTAL {
+            // R hat, with ϕ^i
+            (
+                r_hat_large[com_i][..N_LIMBS_LARGE].try_into().unwrap(),
+                r_hat_large[com_i][N_LIMBS_LARGE..].try_into().unwrap(),
+                scalar_limbs.phi_limbs[com_i][coeff_num_1],
+            )
+        } else if block_row_i < 34 * N_COL_TOTAL {
+            // Our main C_R commitment input, with r·ϕ^i
+            (
+                comms_large[1][com_i][..N_LIMBS_LARGE].try_into().unwrap(),
+                comms_large[1][com_i][N_LIMBS_LARGE..].try_into().unwrap(),
+                scalar_limbs.phi_r_limbs[com_i][coeff_num_1],
+            )
+        } else if block_row_i < 35 * N_COL_TOTAL {
+            // FIXME add a minus!
+            // no bucketing, no coefficient, no RAM.
Only -R hat + ( + r_hat_large[com_i][..N_LIMBS_LARGE].try_into().unwrap(), + r_hat_large[com_i][N_LIMBS_LARGE..].try_into().unwrap(), + F::zero(), + ) + } else if block_row_i < 35 * N_COL_TOTAL + 17 { + // FIXME add a minus + // -E_R', with coeff ϕ^{n+1} + ( + error_term_rprime_large[..N_LIMBS_LARGE].try_into().unwrap(), + error_term_rprime_large[N_LIMBS_LARGE..].try_into().unwrap(), + scalar_limbs.phi_limbs[N_COL_TOTAL][coeff_num_2], + ) + } else if block_row_i < 35 * N_COL_TOTAL + 2 * 17 { + // T_0, with coeff r · ϕ^{n+1} + ( + t_terms_large[0][..N_LIMBS_LARGE].try_into().unwrap(), + t_terms_large[0][N_LIMBS_LARGE..].try_into().unwrap(), + scalar_limbs.phi_r_limbs[N_COL_TOTAL][coeff_num_2], + ) + } else if block_row_i < 35 * N_COL_TOTAL + 3 * 17 { + // T_1, with coeff r^2 · ϕ^{n+1} + ( + t_terms_large[1][..N_LIMBS_LARGE].try_into().unwrap(), + t_terms_large[1][N_LIMBS_LARGE..].try_into().unwrap(), + scalar_limbs.phi_np1_r2_limbs[coeff_num_2], + ) + } else if block_row_i < 35 * N_COL_TOTAL + 4 * 17 { + // E_R, with coeff r^3 · ϕ^{n+1} + ( + error_terms_large[1][..N_LIMBS_LARGE].try_into().unwrap(), + error_terms_large[1][N_LIMBS_LARGE..].try_into().unwrap(), + scalar_limbs.phi_np1_r3_limbs[coeff_num_2], + ) + } else if block_row_i == 35 * N_COL_TOTAL + 4 * 17 { + // E_L, no bucketing, no coeff + ( + error_terms_large[0][..N_LIMBS_LARGE].try_into().unwrap(), + error_terms_large[0][N_LIMBS_LARGE..].try_into().unwrap(), + F::zero(), + ) + } else { + panic!("Dead case"); + }; + + // Second FEC input point, Q. + let (xq_limbs, yq_limbs) = if block_row_i < 34 * N_COL_TOTAL { + stub_bucket + } else if block_row_i < 35 * N_COL_TOTAL { + // C_{L,i} commitments + ( + comms_large[0][com_i][..N_LIMBS_LARGE].try_into().unwrap(), + comms_large[0][com_i][N_LIMBS_LARGE..].try_into().unwrap(), + ) + } else if block_row_i < 35 * N_COL_TOTAL + 4 * 17 { + stub_bucket + } else if block_row_i == 35 * N_COL_TOTAL + 4 * 17 { + // E_R' + ( + error_term_rprime_large[..N_LIMBS_LARGE].try_into().unwrap(), + error_term_rprime_large[N_LIMBS_LARGE..].try_into().unwrap(), + ) + } else { + panic!("Dead case"); + }; + + env.write_column(IVCColumn::Block4Coeff, &Env::constant(coeff)); + + // TODO These two should be used when RAMLookups are enabled. + env.write_column(IVCColumn::Block4Input2AccessTime, &Env::constant(F::zero())); + env.write_column(IVCColumn::Block4OutputAccessTime, &Env::constant(F::zero())); + + let limbs_f_to_ff = |limbs: &[F; N_LIMBS_LARGE]| { + limbs + .iter() + .map(|f| Ff::from(BigUint::try_from(*f).unwrap())) + .collect::>() + .try_into() + .unwrap() + }; + let xp_limbs_ff: [Ff; N_LIMBS_LARGE] = limbs_f_to_ff(&xp_limbs); + let yp_limbs_ff: [Ff; N_LIMBS_LARGE] = limbs_f_to_ff(&yp_limbs); + let xq_limbs_ff: [Ff; N_LIMBS_LARGE] = limbs_f_to_ff(&xq_limbs); + let yq_limbs_ff: [Ff; N_LIMBS_LARGE] = limbs_f_to_ff(&yq_limbs); + + let xp = combine_large_to_full_field(xp_limbs_ff); + let xq = combine_large_to_full_field(xq_limbs_ff); + let yp = combine_large_to_full_field(yp_limbs_ff); + let yq = combine_large_to_full_field(yq_limbs_ff); + + let (xr, yr) = ec_add_circuit( + &mut SubEnvLookup::new( + &mut SubEnvColumn::new(env, IVCFECLens {}), + IVCFECLookupLens(PhantomData), + ), + xp, + yp, + xq, + yq, + ); + + // repacking results into 75 bits. 
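+        // (Each coordinate fits in N_LIMBS_LARGE = 4 limbs of 75 bits, so
+        // the 8 `Block4OutputRepacked` columns hold x in limbs 0..4 and y in
+        // limbs 4..8, as written just below.)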
+ let xr_limbs_large: [F; N_LIMBS_LARGE] = + limb_decompose_ff::(&xr); + let yr_limbs_large: [F; N_LIMBS_LARGE] = + limb_decompose_ff::(&yr); + + write_column_array_const(env, &xr_limbs_large, IVCColumn::Block4OutputRepacked); + write_column_array_const(env, &yr_limbs_large, |i| { + IVCColumn::Block4OutputRepacked(4 + i) + }); + + constrain_ecadds::(env); + + env.next_row(); + } +} + +#[allow(clippy::needless_range_loop)] +pub fn process_challenges( + env: &mut Env, + fold_iteration: usize, + h_r: F, + chal_l: &[F; N_CHALS], + r: F, +) where + F: PrimeField, + Env: MultiRowReadCap, +{ + let mut curr_alpha_r_pow: F = F::one(); + + for block_row_i in 0..block_height::(4) { + env.write_column( + IVCColumn::FoldIteration, + &Env::constant(F::from(fold_iteration as u64)), + ); + + curr_alpha_r_pow *= h_r; + + env.write_column(IVCColumn::Block5ConstHr, &Env::constant(h_r)); + env.write_column(IVCColumn::Block5ConstR, &Env::constant(r)); + env.write_column( + IVCColumn::Block5ChalLeft, + &Env::constant(chal_l[block_row_i]), + ); + env.write_column(IVCColumn::Block5ChalRight, &Env::constant(curr_alpha_r_pow)); + env.write_column( + IVCColumn::Block5ChalOutput, + &Env::constant(curr_alpha_r_pow * r + chal_l[block_row_i]), + ); + + constrain_challenges(env); + + env.next_row() + } +} + +#[allow(clippy::needless_range_loop)] +pub fn process_u( + env: &mut Env, + fold_iteration: usize, + u_l: F, + r: F, +) where + F: PrimeField, + Env: MultiRowReadCap, +{ + env.write_column( + IVCColumn::FoldIteration, + &Env::constant(F::from(fold_iteration as u64)), + ); + + env.write_column(IVCColumn::Block6ConstR, &Env::constant(r)); + env.write_column(IVCColumn::Block6ULeft, &Env::constant(u_l)); + env.write_column(IVCColumn::Block6UOutput, &Env::constant(u_l + r)); + + constrain_u(env); + + env.next_row(); +} + +/// Builds selectors for the IVC circuit. +/// The round constants for Poseidon are not added in this function, and must be +/// done separately. +/// The size of the array is the total number of public values required for the +/// IVC. Therefore, it includes the potential round constants required by +/// the hash function. +#[allow(clippy::needless_range_loop)] +pub fn build_fixed_selectors( + domain_size: usize, +) -> [Vec; N_FSEL_IVC] { + // Selectors can be only generated for BN254G1 for now, because + // that's what Poseidon works with. + use ark_ff::{One, Zero}; + use kimchi_msm::Fp; + + assert!( + total_height::() < domain_size, + "IVC circuit (height {:?}) cannot be fit into domain size ({domain_size})", + total_height::(), + ); + + // 3*N + 6*N+2 + N+1 + 35*N + 5 + N_CHALS + 1 = + // 45N + 9 + N_CHALS + let mut selectors: [Vec; N_FSEL_IVC] = + core::array::from_fn(|_| vec![Fp::zero(); domain_size]); + let mut curr_row = 0; + for block_i in 0..N_BLOCKS { + for _i in 0..block_height::(block_i) { + assert!( + curr_row < domain_size, + "The domain size is too small to handle the IVC circuit" + ); + selectors[block_i][curr_row] = Fp::one(); + curr_row += 1; + } + } + + for i in N_BLOCKS..N_FSEL_IVC - N_BLOCKS { + PoseidonBN254Parameters + .constants() + .iter() + .enumerate() + .for_each(|(_round, rcs)| { + rcs.iter().enumerate().for_each(|(_state_index, rc)| { + selectors[i] = vec![*rc; domain_size]; + }); + }); + } + + selectors +} + +/// Instantiates the IVC circuit for folding. L is relaxed (folded) +/// instance, and R is strict (new) instance that is being relaxed at +/// this step. `N_COL_TOTAL` is the total number of columns for IVC + APP. 
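+/// At a high level, the sub-circuits invoked by this function chain as
+/// follows (informal sketch of the data flow, not a normative interface):
+/// ```text
+/// process_inputs     : commitments         -> limb representations
+/// process_hashes     : 150-bit limbs       -> (h_r, r, phi)
+/// process_scalars    : (r, phi)            -> limbs of phi^i, r·phi^i, ...
+/// process_ecadds     : limbs + error terms -> folded commitments (FEC adds)
+/// process_challenges : (h_r, r, chal_l)    -> folded challenges
+/// process_u          : (u_l, r)            -> folded u
+/// ```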
+/// `N_CHALS` is the number of challenges, which also contains the alphas used
+/// to combine constraints, see the [top-level documentation in
+/// folding](folding::expressions).
+/// The number of commitments is the total number, and the commitments to the
+/// previous IVC columns are expected to be included as well.
+// FIXME: we must accept the scaled right commitments and the right instance
+// commitments
+// FIXME: Env should implement something like an IVCCapability trait, which
+// contains the sponge for instance, and the buckets for the MSM. All the data
+// points used here should be saved inside it, and this function should only
+// take the environment as an argument.
+// FIXME: the fold_iteration variable should be inside the environment
+#[allow(clippy::too_many_arguments)]
+pub fn ivc_circuit<F, Ff, Env, PParams, const N_COL_TOTAL: usize, const N_CHALS: usize>(
+    env: &mut Env,
+    fold_iteration: usize,
+    comms_left: Box<[(Ff, Ff); N_COL_TOTAL]>,
+    comms_right: Box<[(Ff, Ff); N_COL_TOTAL]>,
+    comms_out: Box<[(Ff, Ff); N_COL_TOTAL]>,
+    error_terms: [(Ff, Ff); 3], // E_L, E_R, E_O
+    t_terms: [(Ff, Ff); 2],     // T_0, T_1
+    u_l: F,                     // part of the relaxed instance.
+    chal_l: Box<[F; N_CHALS]>,  // challenges
+    poseidon_params: &PParams,
+    domain_size: usize,
+) where
+    F: PrimeField,
+    Ff: PrimeField,
+    PParams: PoseidonParams<F, IVC_POSEIDON_STATE_SIZE, IVC_POSEIDON_NB_TOTAL_ROUND>,
+    Env: DirectWitnessCap<F, IVCColumn>
+        + HybridCopyCap<F, IVCColumn>
+        + LookupCap<F, IVCColumn, IVCLookupTable<Ff>>,
+{
+    assert!(
+        total_height::<N_COL_TOTAL, N_CHALS>() < domain_size,
+        "IVC circuit (height {:?}) cannot be fit into domain size ({domain_size})",
+        total_height::<N_COL_TOTAL, N_CHALS>(),
+    );
+
+    assert!(chal_l.len() == N_CHALS);
+
+    let (_comms_small, comms_large, comms_xlarge) = process_inputs::<_, _, _, N_COL_TOTAL, N_CHALS>(
+        env,
+        fold_iteration,
+        [comms_left, comms_right, comms_out],
+    );
+    let (hash_r_var, r_var, phi_var) = process_hashes::<_, _, _, N_COL_TOTAL, N_CHALS>(
+        env,
+        fold_iteration,
+        poseidon_params,
+        &comms_xlarge,
+    );
+    let r: F = Env::variable_to_field(r_var);
+    let phi: F = Env::variable_to_field(phi_var);
+    let hash_r: F = Env::variable_to_field(hash_r_var);
+    let scalar_limbs =
+        process_scalars::<_, Ff, _, N_COL_TOTAL, N_CHALS>(env, fold_iteration, r, phi);
+    process_ecadds::<_, Ff, _, N_COL_TOTAL, N_CHALS>(
+        env,
+        fold_iteration,
+        scalar_limbs,
+        &comms_large,
+        error_terms,
+        t_terms,
+    );
+    process_challenges::<_, _, N_COL_TOTAL, N_CHALS>(env, fold_iteration, hash_r, &chal_l, r);
+    process_u::<_, _, N_COL_TOTAL>(env, fold_iteration, u_l, r);
+}
+
+/// Base case IVC circuit, completely turned off.
+/// For the base case, we set the fold iteration to 0 and do not perform any
+/// computation.
+/// As each constraint is multiplied by the fold iteration, this simulates a
+/// "deactivation" of the IVC circuit.
+// FIXME: this is not the final version.
+pub fn ivc_circuit_base_case<F, Env, const N_COL_TOTAL: usize, const N_CHALS: usize>(
+    env: &mut Env,
+    domain_size: usize,
+) where
+    F: PrimeField,
+    Env: DirectWitnessCap<F, IVCColumn> + HybridCopyCap<F, IVCColumn>,
+{
+    assert!(
+        total_height::<N_COL_TOTAL, N_CHALS>() < domain_size,
+        "IVC circuit (height {:?}) cannot be fit into domain size ({domain_size})",
+        total_height::<N_COL_TOTAL, N_CHALS>(),
+    );
+
+    // Assuming tables are initialized to zero we don't even have to do this.
+ let fold_iteration = 0; + for block_i in 0..N_BLOCKS { + for _i in 0..block_height::(block_i) { + env.write_column( + IVCColumn::FoldIteration, + &Env::constant(F::from(fold_iteration as u64)), + ); + env.next_row(); + } + } +} diff --git a/ivc/src/ivc/lookups.rs b/ivc/src/ivc/lookups.rs new file mode 100644 index 0000000000..b9475317f0 --- /dev/null +++ b/ivc/src/ivc/lookups.rs @@ -0,0 +1,108 @@ +use ark_ff::PrimeField; +use kimchi_msm::{ + circuit_design::composition::MPrism, fec::lookups as feclookup, logup::LookupTableID, + serialization::lookups as serlookup, +}; +use std::marker::PhantomData; + +/// Enumeration of concrete lookup tables used in serialization circuit. +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] +pub enum IVCLookupTable { + SerLookupTable(serlookup::LookupTable), +} + +impl LookupTableID for IVCLookupTable { + fn to_u32(&self) -> u32 { + match self { + Self::SerLookupTable(lt) => lt.to_u32(), + } + } + + fn from_u32(value: u32) -> Self { + if value < 4 { + Self::SerLookupTable(serlookup::LookupTable::from_u32(value)) + } else { + panic!("Invalid lookup table id") + } + } + + /// All tables are fixed tables. + fn is_fixed(&self) -> bool { + true + } + + fn runtime_create_column(&self) -> bool { + panic!("No runtime tables specified"); + } + + fn length(&self) -> usize { + match self { + Self::SerLookupTable(lt) => lt.length(), + } + } + + /// Converts a value to its index in the fixed table. + fn ix_by_value(&self, value: &[F]) -> Option { + match self { + Self::SerLookupTable(lt) => lt.ix_by_value(value), + } + } + + fn all_variants() -> Vec { + serlookup::LookupTable::::all_variants() + .into_iter() + .map(IVCLookupTable::SerLookupTable) + .collect() + } +} + +impl IVCLookupTable { + /// Provides a full list of entries for the given table. 
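+    /// (Concretely these are the serialization range-check tables:
+    /// `RangeCheck15`, `RangeCheck14Abs`, `RangeCheck9Abs` and
+    /// `RangeCheckFfHighest`, as enumerated in the prism below.)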
+    pub fn entries<F: PrimeField>(&self, domain_d1_size: u64) -> Option<Vec<Vec<F>>> {
+        match self {
+            Self::SerLookupTable(lt) => lt.entries(domain_d1_size),
+        }
+    }
+}
+
+pub struct IVCFECLookupLens<Ff>(pub PhantomData<Ff>);
+
+impl<Ff: PrimeField> MPrism for IVCFECLookupLens<Ff> {
+    type Source = IVCLookupTable<Ff>;
+    type Target = feclookup::LookupTable<Ff>;
+
+    fn traverse(&self, source: Self::Source) -> Option<Self::Target> {
+        match source {
+            IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheck15) => {
+                Some(feclookup::LookupTable::RangeCheck15)
+            }
+            IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheck14Abs) => {
+                Some(feclookup::LookupTable::RangeCheck14Abs)
+            }
+            IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheck9Abs) => {
+                Some(feclookup::LookupTable::RangeCheck9Abs)
+            }
+            IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheckFfHighest(p)) => {
+                Some(feclookup::LookupTable::RangeCheckFfHighest(p))
+            }
+            _ => None,
+        }
+    }
+
+    fn re_get(&self, target: Self::Target) -> Self::Source {
+        match target {
+            feclookup::LookupTable::RangeCheck15 => {
+                IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheck15)
+            }
+            feclookup::LookupTable::RangeCheck14Abs => {
+                IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheck14Abs)
+            }
+            feclookup::LookupTable::RangeCheck9Abs => {
+                IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheck9Abs)
+            }
+            feclookup::LookupTable::RangeCheckFfHighest(p) => {
+                IVCLookupTable::SerLookupTable(serlookup::LookupTable::RangeCheckFfHighest(p))
+            }
+        }
+    }
+}
diff --git a/ivc/src/ivc/mod.rs b/ivc/src/ivc/mod.rs
new file mode 100644
index 0000000000..95e0f805d3
--- /dev/null
+++ b/ivc/src/ivc/mod.rs
@@ -0,0 +1,272 @@
+pub mod columns;
+pub mod constraints;
+pub mod helpers;
+pub mod interpreter;
+pub mod lookups;
+
+use self::columns::N_BLOCKS;
+use crate::poseidon_8_56_5_3_2::bn254::NB_CONSTRAINTS as N_CONSTRAINTS_POSEIDON;
+
+/// The biggest packing variant for foreign field. Used for hashing. 150-bit limbs.
+pub const LIMB_BITSIZE_XLARGE: usize = 150;
+
+/// The biggest packing format, 2 limbs.
+pub const N_LIMBS_XLARGE: usize = 2;
+
+/// Number of additional columns that a reduction to degree 2 will
+/// require.
+/// A regression test is available in the tests directory, under the name
+/// `test_regression_additional_columns_reduction_to_degree_2`.
+pub const N_ADDITIONAL_WIT_COL_QUAD: usize = 335;
+
+/// Number of constraints used by the IVC circuit.
+pub const N_CONSTRAINTS: usize = N_CONSTRAINTS_POSEIDON + N_BLOCKS + 61;
+
+/// Number of alphas needed for the IVC circuit, equal to the number
+/// of constraints per row.
+///
+/// TODO: We can do with fewer challenges, just N_BLOCKS +
+/// max(constraints_in_block) = N_BLOCKS + N_CONSTRAINTS_POSEIDON.
+///
+/// Supposing the Poseidon circuit has the highest number of
+/// constraints, for now. We also add the number of "blocks" in the
+/// IVC circuit as there will be [N_BLOCKS] alphas required to
+/// aggregate all blocks on each row.
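+///
+/// With [N_BLOCKS] = 6 (the six `process_*` blocks of the interpreter) and
+/// the Poseidon gadget contributing 432 constraints, this works out to
+/// `432 + 6 + 61 = 499`, the figure checked by the
+/// `test_regression_ivc_constraints` regression test below.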
+pub const N_ALPHAS: usize = N_CONSTRAINTS; + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + ivc::{ + columns::{IVCColumn, N_BLOCKS, N_FSEL_IVC}, + constraints::constrain_ivc, + interpreter::{build_fixed_selectors, ivc_circuit}, + lookups::IVCLookupTable, + }, + poseidon_8_56_5_3_2::bn254::PoseidonBN254Parameters, + }; + use ark_ff::{UniformRand, Zero}; + use kimchi_msm::{ + circuit_design::{ + composition::{IdMPrism, MPrism}, + ConstraintBuilderEnv, SubEnvLookup, WitnessBuilderEnv, + }, + columns::ColumnIndexer, + logup::LookupTableID, + Ff1, Fp, + }; + use o1_utils::box_array; + use rand::{CryptoRng, RngCore}; + + use super::N_ALPHAS; + + // Total number of columns in IVC and Application circuits. + pub const TEST_N_COL_TOTAL: usize = IVCColumn::N_COL + 50; + + // IVC can process more challenges than just alphas, generally. + // However we do not have any in this test. + pub const TEST_N_CHALS: usize = N_ALPHAS; + + pub const TEST_DOMAIN_SIZE: usize = 1 << 15; + + type IVCWitnessBuilderEnvRaw = WitnessBuilderEnv< + Fp, + IVCColumn, + { ::N_COL - N_BLOCKS }, + { ::N_COL - N_BLOCKS }, + 0, + N_FSEL_IVC, + LT, + >; + + /// Generic IVC circuit builder. + fn build_ivc_circuit< + RNG: RngCore + CryptoRng, + LT: LookupTableID, + L: MPrism>, + >( + rng: &mut RNG, + domain_size: usize, + fold_iteration: usize, + lt_lens: L, + ) -> IVCWitnessBuilderEnvRaw { + let mut witness_env = IVCWitnessBuilderEnvRaw::::create(); + + let mut comms_left: Box<_> = box_array![(Ff1::zero(),Ff1::zero()); TEST_N_COL_TOTAL]; + let mut comms_right: Box<_> = box_array![(Ff1::zero(),Ff1::zero()); TEST_N_COL_TOTAL]; + let mut comms_output: Box<_> = box_array![(Ff1::zero(),Ff1::zero()); TEST_N_COL_TOTAL]; + + for i in 0..TEST_N_COL_TOTAL { + comms_left[i] = ( + ::rand(rng), + ::rand(rng), + ); + comms_right[i] = ( + ::rand(rng), + ::rand(rng), + ); + comms_output[i] = ( + ::rand(rng), + ::rand(rng), + ); + } + + println!("Building fixed selectors"); + + let fixed_selectors: [Vec; N_FSEL_IVC] = + build_fixed_selectors::(domain_size); + + witness_env.set_fixed_selectors(fixed_selectors.to_vec()); + + println!("Calling the IVC circuit"); + // TODO add nonzero E/T values. + ivc_circuit::<_, _, _, _, TEST_N_COL_TOTAL, TEST_N_CHALS>( + &mut SubEnvLookup::new(&mut witness_env, lt_lens), + fold_iteration, + comms_left, + comms_right, + comms_output, + [(Ff1::zero(), Ff1::zero()); 3], + [(Ff1::zero(), Ff1::zero()); 2], + Fp::zero(), + Box::new( + (*vec![Fp::zero(); TEST_N_CHALS].into_boxed_slice()) + .try_into() + .unwrap(), + ), + &PoseidonBN254Parameters, + TEST_DOMAIN_SIZE, + ); + + witness_env + } + + #[test] + /// Tests if building the IVC circuit succeeds when using the general case + /// (i.e. fold_iteration != 0). 
+ pub fn heavy_test_ivc_circuit_general_case() { + let mut rng = o1_utils::tests::make_test_rng(None); + build_ivc_circuit::<_, IVCLookupTable, _>( + &mut rng, + 1 << 15, + 1, + IdMPrism::>::default(), + ); + } + + #[test] + fn test_regression_ivc_constraints() { + let mut constraint_env = ConstraintBuilderEnv::>::create(); + constrain_ivc::(&mut constraint_env); + let constraints = constraint_env.get_relation_constraints(); + + let mut constraints_degrees = HashMap::new(); + + // Regression testing for the number of constraints and their degree + { + // 67 + 432 (Poseidon) + assert_eq!(constraints.len(), 499); + constraints.iter().for_each(|c| { + let degree = c.degree(1, 0); + *constraints_degrees.entry(degree).or_insert(0) += 1; + }); + + assert_eq!(constraints_degrees.get(&1), None); + assert_eq!(constraints_degrees.get(&2), Some(&6)); + assert_eq!(constraints_degrees.get(&3), Some(&215)); + assert_eq!(constraints_degrees.get(&4), Some(&278)); + assert_eq!(constraints_degrees.get(&5), None); + + // Maximum degree is 5 + // - fold_iteration increases by one + // - the public selectors increase by one + assert!(constraints.iter().all(|c| c.degree(1, 0) <= 5)); + } + } + + #[test] + /// Completeness test for the IVC circuit in the general case (i.e. + /// fold_iteration != 0). + fn heavy_test_completeness_ivc_general_case() { + let fold_iteration = 1; + + let mut rng = o1_utils::tests::make_test_rng(None); + + let domain_size = 1 << 15; + + let witness_env = build_ivc_circuit::<_, IVCLookupTable, _>( + &mut rng, + domain_size, + fold_iteration, + IdMPrism::>::default(), + ); + let relation_witness = witness_env.get_relation_witness(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::>::create(); + constrain_ivc::(&mut constraint_env); + let constraints = constraint_env.get_relation_constraints(); + + let fixed_selectors: Box<[Vec; N_FSEL_IVC]> = Box::new(build_fixed_selectors::< + TEST_N_COL_TOTAL, + TEST_N_CHALS, + >(domain_size)); + + kimchi_msm::test::test_completeness_generic_no_lookups::< + { IVCColumn::N_COL - N_BLOCKS }, + { IVCColumn::N_COL - N_BLOCKS }, + 0, + N_FSEL_IVC, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } + + #[test] + /// Completeness test for the IVC circuit in the base case (i.e. + /// fold_iteration = 0). + fn heavy_test_completeness_ivc_base_case() { + let fold_iteration = 0; + + let mut rng = o1_utils::tests::make_test_rng(None); + + let domain_size = 1 << 15; + + let witness_env = build_ivc_circuit::<_, IVCLookupTable, _>( + &mut rng, + domain_size, + fold_iteration, + IdMPrism::>::default(), + ); + let relation_witness = witness_env.get_relation_witness(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::>::create(); + constrain_ivc::(&mut constraint_env); + let constraints = constraint_env.get_relation_constraints(); + + let fixed_selectors: Box<[Vec; N_FSEL_IVC]> = Box::new(build_fixed_selectors::< + TEST_N_COL_TOTAL, + TEST_N_CHALS, + >(domain_size)); + + kimchi_msm::test::test_completeness_generic_no_lookups::< + { IVCColumn::N_COL - N_BLOCKS }, + { IVCColumn::N_COL - N_BLOCKS }, + 0, + N_FSEL_IVC, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } +} diff --git a/ivc/src/lib.rs b/ivc/src/lib.rs new file mode 100644 index 0000000000..1037832192 --- /dev/null +++ b/ivc/src/lib.rs @@ -0,0 +1,218 @@ +//! This crate provides a circuit to achieve Incremental Verifiable +//! 
Computation (IVC) based on a variant of the folding scheme described in the
+//! paper [Nova](https://eprint.iacr.org/2021/370.pdf). For the rest of this
+//! document, we suppose that the curve is BN254.
+//!
+//! Note that in the rest of this document, we will mix the terms
+//! "relaxation/relaxed" with "homogeneisation/homogeneous" polynomials. The
+//! reason is explained in the [folding] library. The term "running relaxed
+//! instance" can be rephrased as "an accumulator of evaluations of the
+//! homogeneised polynomials describing the computation and commitments to them".
+//!
+//! ## A recap of the Nova IVC scheme
+//!
+//! The [Nova paper](https://eprint.iacr.org/2021/370.pdf) describes an IVC
+//! scheme to be used with the folding scheme described in the same paper.
+//!
+//! The IVC scheme uses different notations for different concepts:
+//! - `F` is the function to be computed. In our codebase, we call it the
+//! "application circuit".
+//! - `F'` is the augmented function with the verifier computation. In our
+//! codebase, we call it the "IVC circuit + APP circuit", or the "joint
+//! circuit".
+//! - `U_i` is the accumulator of the function `F'` up to step `i`. For clarity,
+//! we will sometimes use the notation `acc_i` to denote the accumulator. The
+//! accumulator in the [folding] library is called the "left instance". It is a
+//! "running relaxed instance".
+//! - `u_i` is the current instance of the function `F'` to be folded into the
+//! accumulator. In the [folding] library, it is called the "right instance".
+//! - `U_(i + 1)` is the accumulator of the function `F'` up to step `i + 1`,
+//! i.e. the accumulation of `U_i` and `u_i`. In the [folding] library, it is
+//! called the "output instance" or the "folded output". The "folded output" and
+//! the "left instance" should both be seen as accumulators, and the right
+//! instance should be seen as the difference between the two.
+//!
+//! The IVC scheme described in Fig 4 works as follows, for a given step `i`:
+//! - In the circuit (also called "the augmented function `F'`"), the prover
+//! will compute a hash of the left instance `acc_i`. It is a way to "compress"
+//! the information of the previous computation. It is described by the notation
+//! `u_i.x ?= Hash(i, z0, zi, U_i)` in Fig 4, on the left.
+//! The value `u_i.x` is provided as a public input to the circuit. In the code,
+//! we could summarize it with the following pseudo-code:
+//! ```text
+//! left_commitments.into_iter().for_each(|comm| {
+//!   env.process_hash(comm, LeftCommitment)
+//! })
+//! env.assert_eq(env.output_left, env.hash_state[LeftCommitment])
+//! ```
+//! - The prover will also compute the expected hash for the next iteration, by
+//! computing the hash of the output instance `U_(i + 1)`. It is described by
+//! the notation `u_(i + 1).x ?= Hash(i + 1, z0, zi, U_(i + 1))` in Fig 4, on
+//! the right. Note that `U_(i + 1)` will be, at step `i + 1`, the new left
+//! input. The value of the hash computed at step `i` will be compared at step
+//! `i + 1`. In the code, we could summarize it with the following pseudo-code:
+//! ```text
+//! output_commitments.into_iter().for_each(|comm| {
+//!   env.process_hash(comm, OutputCommitment)
+//! })
+//! env.assert_eq(env.output_left, env.hash_state[OutputCommitment])
+//! ```
+//!
+//! The order of the execution is encoded in the fact that the hash contains the
+//! step `i` when we check the left input and `i + 1` when we compress the
+//! folded output.
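+//!
+//! Summarising one step `i` in the Fig 4 notation (informal sketch):
+//! ```text
+//! check:   u_i.x     == Hash(i,     z0, zi, U_i)
+//! produce: u_(i+1).x := Hash(i + 1, z0, zi, U_(i + 1)),
+//!          where U_(i + 1) = NIFS.V(U_i, u_i)
+//! ```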
+//! The fact that the prover runs the computation on the same initial input is
+//! enforced by adding the initial value `z0` into the hash for both hash
+//! computations.
+//! Therefore, the Nova IVC scheme encodes a **sequential** and **incremental**
+//! computation, which can be verified later using a SNARK on the accumulated
+//! instance.
+//!
+//! The correct accumulation of the commitments and of the scalars is done in
+//! the middle square, i.e. `acc_(i + 1) = NIFS.V(acc_i, u_i)`. It performs
+//! foreign field elliptic curve additions and scalar additions.
+//!
+// The documentation below is outdated. Please update it after bifolding is
+// implemented.
+//! ## Our design
+//!
+//! The circuit is implemented using the generic interpreter provided by the
+//! crate [kimchi_msm].
+//! The particularity of the implementation is that it doesn't use a cycle of
+//! curves, aims to be as generic as possible, defers the scalar
+//! multiplications so they can be computed in bulk, and relies on the Poseidon
+//! hash. Our argument will mostly be based on the state of the hash after we
+//! have executed the whole computation.
+//!
+//! The IVC circuit is divided into different sections/sets of constraints
+//! described by multivariate polynomials, each
+//! activated by (public) selectors that are defined at setup time.
+//! The number of columns of the original computation defines the shape of the
+//! circuit, as it determines the values of the selectors.
+//!
+//! First, the folding scheme we use is described in the crate
+//! [folding](folding::expressions) and is based on degree-3 polynomials.
+//! We suppose that each constraint is reduced to degree 2, and the third
+//! degree is used to encode the aggregation of constraints.
+//!
+//! In the [Nova paper](https://eprint.iacr.org/2021/370.pdf), to provide
+//! incremental verifiable computation, the authors propose a folding scheme
+//! where the verifier has to compute the following:
+//! ```text
+//! // Accumulation of the homogeneous value `u`:
+//! u'' = u + r u'
+//! // Accumulation of all the PIOP challenges:
+//! for each challenge c_i:
+//!   c_i'' = c_i + r c_i'
+//! for each alpha_i (aggregation of constraints):
+//!   alpha_i'' = alpha_i + r alpha_i'
+//! // Accumulation of the blinders for the commitment:
+//! blinder'' = blinder + r + r^2 + r^3 blinder'
+//! // Accumulation of the error terms (scalar multiplication)
+//! E = E1 + r T0 + r^2 T1 + r^3 E2
+//! // Randomized accumulation of the instance commitments (scalar multiplication)
+//! for i in 0..N_COLUMNS
+//!   (C_i)_O = (C_i)_L + r (C_i)_R
+//! ```
+//!
+//! The accumulation of the challenges, the homogeneous value and the blinders
+//! is done trivially, as they are scalar field values.
+//! After that, the verifier has to perform the foreign field elliptic curve
+//! additions and the scalar multiplications `(r T0)`, `(r^2 T1)`, `(r^3 E2)`
+//! and `(r (C_i)_R)` for each column.
+//!
+//! First, we decide to defer the computations of
+//! the scalar multiplications, and we reduce the verifier work to computing only
+//! the foreign field elliptic curve additions. Therefore, the verifier already
+//! has access to the
+//! results of `r T0`, `r^2 T1`, `r^3 E2` and `r (C_i)_R`, and must only perform
+//! the additions. We call the commitment `r T0` the "scaled" commitment to `T0`,
+//! and the same for the others. The commitments `(C_i)_L`, `(C_i)_R` and `(C_i)_O`
+//! are called the "instance commitments".
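+//!
+//! With the scaled commitments provided as inputs, the in-circuit work of the
+//! verifier reduces to (sketch):
+//! ```text
+//! // received, proved correct elsewhere: r T0, r^2 T1, r^3 E2, r (C_i)_R
+//! E       = E1 + (r T0) + (r^2 T1) + (r^3 E2)  // three FEC additions
+//! (C_i)_O = (C_i)_L + (r (C_i)_R)              // one FEC addition per column
+//! ```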
+//! To perform the foreign field elliptic curve additions, we split the
+//! commitments into 17 chunks of 15 bits and use additive lookups as described
+//! in [kimchi_msm::logup]. These 17 chunks will be used later to compute the
+//! scalar multiplications, using a variant of the scheme described in [the MSM
+//! RFC](https://github.com/o1-labs/rfcs/blob/main/0013-efficient-msms-for-non-native-pickles-verification.md).
+//!
+//! The first component of our circuit is a hash function.
+//! We decided to use the Poseidon hash function, and implemented a generic one
+//! using the generic interpreter in [crate::poseidon_8_56_5_3_2]. The Poseidon
+//! instance we decided to go with is the traditional full/partial rounds
+//! construction. For a security level of 128 bits, a substitution box of 5 and
+//! a state of 3 elements, we need 8 full rounds and 56 partial rounds over the
+//! BN254 scalar field.
+//! Therefore, we can "absorb" 2 field elements per row (the third element is
+//! kept as a buffer in the sponge construction).
+//!
+//! A single Poseidon hash costs 432 constraints and requires 435 columns, in
+//! addition to 192 round constants treated as constant "selectors" in our
+//! constraints.
+//! Note that having 432 constraints per Poseidon hash means that we must also
+//! have 432 constraints to accumulate the challenges used to combine the
+//! constraints. It is worth noting that these alphas are the same for all
+//! Poseidon hashes.
+//! Combining constraints is done by using another selector, on a single row,
+//! see below (TODO).
+//! Note that the columns used by the Poseidon hash must also be aggregated
+//! while running the IVC.
+//!
+//! The Poseidon hash is used to absorb all the instance commitments, the
+//! challenges and the scaled commitments.
+//! In our circuit, we start by absorbing all the elliptic curve points. A
+//! particularity is that we will use different Poseidon instances for each
+//! "side" of the addition. We do assume (*at the moment*) that
+//! each point is encoded as 2 field elements in the field of the circuit
+//! (FIXME: there is a negligible probability that it wraps around, as the
+//! coordinates are in the base field).
+// We will change this soon, by absorbing the coordinate x, and the sign of y.
+// If we compute on one row the hash and on the next row the ECADD, the sign of
+// y can be computed on the ECADD row, and accessed by the Poseidon constraints.
+//! For a given set of coordinates `(x, y)`, and supposing an initial state of
+//! our permutation `(s0, s1, s2)`, we will compute on a single row the
+//! absorption of `(x, y)`, which consists of updating the state `(s0, s1, s2)`
+//! to `(s0 + x, s1 + y, s2)`.
+//! For each side, we will initialize a new Poseidon state, and we will keep
+//! absorbing each column of the circuit.
+//!
+//! We end up with the following shape:
+//! ```text
+//! | q_poseidon | s0 | s1 | s2 | s0 + x | s1 + y | ... | s0' | s1' | s2' | side |
+//! ```
+//! where `(s0', s1', s2')` is the final state of the Poseidon permutation after
+//! the execution of all the rounds, and `q_poseidon` is a public selector set
+//! to `1` on the rows where Poseidon needs to be executed, and `0` otherwise.
+//!
+//! ## IVC circuit: base case
+//!
+//! The base case of the IVC circuit is supported by multiplying each constraint
+//! by a `fold_iteration` column. As it is `0` for the base case, the whole
+//! circuit is "turned off".
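+//! Concretely, every constraint `C` of the joint circuit is replaced by
+//! (sketch):
+//! ```text
+//! fold_iteration * C = 0   // trivially satisfied when fold_iteration = 0
+//! ```
+//!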
We do not want to keep this in the final version, and aim to replace it with +//! a method that does not increase the degree of the constraints by one. The +//! column value is also under-constrained for now. +//! +//! The code for the base case is handled in +//! [crate::ivc::interpreter::ivc_circuit_base_case] and the corresponding +//! constraints are in [crate::ivc::constraints::constrain_ivc]. The fold +//! iteration column is set at each row by each process_* function in the +//! interpreter. + +pub mod expr_eval; +pub mod ivc; +pub mod plonkish_lang; +/// Poseidon hash function with 55 full rounds, 0 partial rounds, sbox 7, a +/// state of 3 elements and constraints of degree 2 +pub mod poseidon_55_0_7_3_2; +/// Poseidon hash function with 55 full rounds, 0 partial rounds, sbox 7, +/// a state of 3 elements and constraints of degree 7 +pub mod poseidon_55_0_7_3_7; +/// Poseidon hash function with 8 full rounds, 56 partial rounds, sbox 5, a +/// state of 3 elements and constraints of degree 2 +pub mod poseidon_8_56_5_3_2; +/// Poseidon parameters for 55 full rounds, 0 partial rounds, sbox 7, a state of +/// 3 elements +pub mod poseidon_params_55_0_7_3; +pub mod prover; +pub mod verifier; diff --git a/ivc/src/plonkish_lang.rs b/ivc/src/plonkish_lang.rs new file mode 100644 index 0000000000..4cfcc8cd0a --- /dev/null +++ b/ivc/src/plonkish_lang.rs @@ -0,0 +1,268 @@ +/// Provides definition of plonkish language related instance, +/// witness, and tools to work with them. The IVC is specialized for +/// exactly the plonkish language. +use ark_ff::{FftField, Field, One}; +use ark_poly::{Evaluations, Radix2EvaluationDomain as R2D}; +use folding::{instance_witness::Foldable, Alphas, Instance, Witness}; +use itertools::Itertools; +use kimchi::{self, circuits::berkeley_columns::BerkeleyChallengeTerm}; +use kimchi_msm::{columns::Column, witness::Witness as GenericWitness}; +use mina_poseidon::FqSponge; +use poly_commitment::{ + commitment::{absorb_commitment, CommitmentCurve}, + PolyComm, SRS, +}; +use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use std::ops::Index; +use strum_macros::{EnumCount as EnumCountMacro, EnumIter}; + +/// Vector field over F. Something like a vector. +pub trait CombinableEvals: PartialEq { + fn e_as_slice(&self) -> &[F]; + fn e_as_mut_slice(&mut self) -> &mut [F]; +} + +impl CombinableEvals for Evaluations> { + fn e_as_slice(&self) -> &[F] { + self.evals.as_slice() + } + fn e_as_mut_slice(&mut self) -> &mut [F] { + self.evals.as_mut_slice() + } +} + +impl CombinableEvals for Vec { + fn e_as_slice(&self) -> &[F] { + self.as_slice() + } + fn e_as_mut_slice(&mut self) -> &mut [F] { + self.as_mut_slice() + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct PlonkishWitnessGeneric { + pub witness: GenericWitness, + // This does not have to be part of the witness... can be a static + // precompiled object. 
+ pub fixed_selectors: GenericWitness, + pub phantom: std::marker::PhantomData, +} + +pub type PlonkishWitness = + PlonkishWitnessGeneric>>; + +impl> Foldable + for PlonkishWitnessGeneric +{ + fn combine(mut a: Self, b: Self, challenge: F) -> Self { + for (a, b) in (*a.witness.cols).iter_mut().zip(*(b.witness.cols)) { + for (a, b) in (a.e_as_mut_slice()).iter_mut().zip(b.e_as_slice()) { + *a += *b * challenge; + } + } + assert!(a.fixed_selectors == b.fixed_selectors); + a + } +} + +impl< + const N_COL: usize, + const N_FSEL: usize, + Curve: CommitmentCurve, + Evals: CombinableEvals, + > Witness for PlonkishWitnessGeneric +{ +} + +impl> Index + for PlonkishWitnessGeneric +{ + type Output = [F]; + + /// Map a column alias to the corresponding witness column. + fn index(&self, index: Column) -> &Self::Output { + match index { + Column::Relation(i) => self.witness.cols[i].e_as_slice(), + Column::FixedSelector(i) => self.fixed_selectors[i].e_as_slice(), + other => panic!("Invalid column index: {other:?}"), + } + } +} + +// for selectors, () in this case as we have none +impl Index<()> + for PlonkishWitness +{ + type Output = [F]; + + fn index(&self, _index: ()) -> &Self::Output { + unreachable!() + } +} + +#[derive(PartialEq, Eq, Clone, Debug)] +pub struct PlonkishInstance< + G: CommitmentCurve, + const N_COL: usize, + const N_CHALS: usize, + const N_ALPHAS: usize, +> { + pub commitments: [G; N_COL], + pub challenges: [G::ScalarField; N_CHALS], + pub alphas: Alphas, + pub blinder: G::ScalarField, +} + +impl + Foldable for PlonkishInstance +{ + fn combine(a: Self, b: Self, challenge: G::ScalarField) -> Self { + Self { + commitments: std::array::from_fn(|i| { + (a.commitments[i] + b.commitments[i].mul(challenge)).into() + }), + challenges: std::array::from_fn(|i| a.challenges[i] + challenge * b.challenges[i]), + alphas: Alphas::combine(a.alphas, b.alphas, challenge), + blinder: a.blinder + challenge * b.blinder, + } + } +} + +impl + Instance for PlonkishInstance +{ + fn to_absorb(&self) -> (Vec, Vec) { + // FIXME: check!!!! + let mut scalars = Vec::new(); + let mut points = Vec::new(); + points.extend(self.commitments); + scalars.extend(self.challenges); + scalars.extend(self.alphas.clone().powers()); + (scalars, points) + } + + fn get_alphas(&self) -> &Alphas { + &self.alphas + } + + fn get_blinder(&self) -> G::ScalarField { + self.blinder + } +} + +// Implementation for 3 challenges; only for now. 
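+// (The `combine` implementations above realise the folding map
+// X_O = X_L + challenge * X_R componentwise: on commitments via scalar
+// multiplication in the group, and on challenges, alphas and the blinder in
+// the scalar field.)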
+impl + PlonkishInstance +{ + pub fn from_witness< + EFqSponge: FqSponge, + Srs: SRS + std::marker::Sync, + >( + w: &GenericWitness>>, + fq_sponge: &mut EFqSponge, + srs: &Srs, + domain: R2D, + ) -> Self { + let blinder = G::ScalarField::one(); + + let commitments: GenericWitness> = w + .into_par_iter() + .map(|w| { + let blinder = PolyComm::new(vec![blinder; 1]); + let unblinded = srs.commit_evaluations_non_hiding(domain, w); + srs.mask_custom(unblinded, &blinder).unwrap().commitment + }) + .collect(); + + // Absorbing commitments + (&commitments).into_iter().for_each(|c| { + assert!(c.len() == 1); + absorb_commitment(fq_sponge, c) + }); + + let commitments: [G; N_COL] = commitments + .into_iter() + .map(|c| c.get_first_chunk()) + .collect_vec() + .try_into() + .unwrap(); + + let beta = fq_sponge.challenge(); + let gamma = fq_sponge.challenge(); + let joint_combiner = fq_sponge.challenge(); + let challenges = [beta, gamma, joint_combiner]; + + let alpha = fq_sponge.challenge(); + let alphas = Alphas::new_sized(alpha, N_ALPHAS); + + Self { + commitments, + challenges, + alphas, + blinder, + } + } + + pub fn verify_from_witness>( + &self, + fq_sponge: &mut EFqSponge, + ) -> Result<(), String> { + (self.blinder == G::ScalarField::one()) + .then_some(()) + .ok_or("Blinder must be one")?; + + // Absorbing commitments + self.commitments + .iter() + .for_each(|c| absorb_commitment(fq_sponge, &PolyComm { chunks: vec![*c] })); + + let beta = fq_sponge.challenge(); + let gamma = fq_sponge.challenge(); + let joint_combiner = fq_sponge.challenge(); + + (self.challenges == [beta, gamma, joint_combiner]) + .then_some(()) + .ok_or("Challenges do not match the expected result")?; + + let alpha = fq_sponge.challenge(); + + (self.alphas == Alphas::new_sized(alpha, N_ALPHAS)) + .then_some(()) + .ok_or("Alphas do not match the expected result")?; + + Ok(()) + } +} + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, EnumIter, EnumCountMacro)] +pub enum PlonkishChallenge { + Beta, + Gamma, + JointCombiner, +} + +impl From for PlonkishChallenge { + fn from(chal: BerkeleyChallengeTerm) -> Self { + match chal { + BerkeleyChallengeTerm::Beta => PlonkishChallenge::Beta, + BerkeleyChallengeTerm::Gamma => PlonkishChallenge::Gamma, + BerkeleyChallengeTerm::JointCombiner => PlonkishChallenge::JointCombiner, + BerkeleyChallengeTerm::Alpha => panic!("Alpha not allowed in folding expressions"), + } + } +} + +impl Index + for PlonkishInstance +{ + type Output = G::ScalarField; + + fn index(&self, index: PlonkishChallenge) -> &Self::Output { + match index { + PlonkishChallenge::Beta => &self.challenges[0], + PlonkishChallenge::Gamma => &self.challenges[1], + PlonkishChallenge::JointCombiner => &self.challenges[2], + } + } +} diff --git a/ivc/src/poseidon_55_0_7_3_2/columns.rs b/ivc/src/poseidon_55_0_7_3_2/columns.rs new file mode 100644 index 0000000000..c9a6791422 --- /dev/null +++ b/ivc/src/poseidon_55_0_7_3_2/columns.rs @@ -0,0 +1,67 @@ +/// The column layout will be as follow, supposing a state size of 3 elements: +/// | C1 | C2 | C3 | C4 | C5 | C6 | ... | C_(k) | C_(k + 1) | C_(k + 2) | +/// |--- |----|----|-----|-----|-----|-----|-------|-----------|-----------| +/// | x | y | z | x' | y' | z' | ... | x'' | y'' | z'' | +/// | MDS \circ SBOX | | MDS \circ SBOX | +/// |-----------------| |-------------------------------| +/// Divided in 5 +/// blocks of degree 2 +/// constraints +/// where (x', y', z') = MDS(x^7, y^7, z^7), i.e. 
the result of the linear layer.
+/// We will have, for N full rounds:
+/// - `3` input columns
+/// - `5 * N * 3` round columns, indexed by the round number and the index in
+///   the state; `N` is the number of rounds.
+use kimchi_msm::columns::{Column, ColumnIndexer};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum PoseidonColumn<const STATE_SIZE: usize, const NB_FULL_ROUND: usize> {
+    Input(usize),
+    // nb round, state
+    // we use the constraints:
+    // y = x * x -> x^2 -> i
+    // y' = y * y -> x^4 -> i + 1
+    // y'' = y * y' -> x^6 -> i + 2
+    // z = y'' * x -> x^7 -> i + 3
+    // z * MDS -> i + 4
+    // --> 5 * state, nb round
+    Round(usize, usize),
+    RoundConstant(usize, usize),
+}
+
+impl<const STATE_SIZE: usize, const NB_FULL_ROUND: usize> ColumnIndexer
+    for PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>
+{
+    // - STATE_SIZE input columns
+    // - for each round:
+    //   - STATE_SIZE state columns for x^2 -> x * x
+    //   - STATE_SIZE state columns for x^4 -> x^2 * x^2
+    //   - STATE_SIZE state columns for x^6 -> x^4 * x^2
+    //   - STATE_SIZE state columns for x^7 -> x^6 * x
+    //   - STATE_SIZE state columns for x^7 * MDS(., L)
+    // - STATE_SIZE * NB_FULL_ROUND constants
+    const N_COL: usize = STATE_SIZE + 5 * NB_FULL_ROUND * STATE_SIZE;
+
+    fn to_column(self) -> Column {
+        match self {
+            PoseidonColumn::Input(i) => {
+                assert!(i < STATE_SIZE);
+                Column::Relation(i)
+            }
+            PoseidonColumn::Round(round, state_index) => {
+                assert!(state_index < 5 * STATE_SIZE);
+                // We start at round 0
+                assert!(round < NB_FULL_ROUND);
+                let idx = STATE_SIZE + (round * 5 * STATE_SIZE + state_index);
+                Column::Relation(idx)
+            }
+            PoseidonColumn::RoundConstant(round, state_index) => {
+                assert!(state_index < STATE_SIZE);
+                // We start at round 0
+                assert!(round < NB_FULL_ROUND);
+                let idx = round * STATE_SIZE + state_index;
+                Column::FixedSelector(idx)
+            }
+        }
+    }
+}
diff --git a/ivc/src/poseidon_55_0_7_3_2/interpreter.rs b/ivc/src/poseidon_55_0_7_3_2/interpreter.rs
new file mode 100644
index 0000000000..029e2744b3
--- /dev/null
+++ b/ivc/src/poseidon_55_0_7_3_2/interpreter.rs
@@ -0,0 +1,194 @@
+//! Implement an interpreter for a specific instance of the Poseidon inner permutation.
+//! The Poseidon construction is defined in the paper ["Poseidon: A New Hash
+//! Function"](https://eprint.iacr.org/2019/458.pdf).
+//! The Poseidon instance works on a state of size `STATE_SIZE` and is designed
+//! to work only with full rounds. As a reminder, the Poseidon permutation is a
+//! mapping from `F^STATE_SIZE` to `F^STATE_SIZE`.
+//! The user is responsible for providing the correct number of full rounds for
+//! the given field and the state.
+//! Also, the substitution box exponent is hard-coded to `7`. The user must
+//! verify that `7` is coprime with `p - 1`, where `p` is the order of the
+//! field.
+//! The constants and the matrix can be generated using the file
+//! `poseidon/src/pasta/params.sage`.
+
+use crate::poseidon_55_0_7_3_2::columns::PoseidonColumn;
+use ark_ff::PrimeField;
+use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap};
+use num_bigint::BigUint;
+use num_integer::Integer;
+
+/// Represents the parameters of the instance of the Poseidon permutation.
+/// Constants are the round constants for each round, and MDS is the matrix used
+/// by the linear layer.
+/// The type is parametrized by the field, the state size, and the number of full rounds.
+/// Note that the parameters are only for instances using full rounds.
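+/// For the instance implemented in this module, it is instantiated with a
+/// state size of 3 and 55 full rounds, matching the `poseidon_55_0_7_3_2`
+/// naming (55 full rounds, 0 partial rounds, sbox 7, state of 3 elements,
+/// constraints of degree 2).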
+// IMPROVEME merge constants and mds in a flat array, to use the CPU cache
+// IMPROVEME generalise init_state for more than 3 elements
+pub trait PoseidonParams<F: PrimeField, const STATE_SIZE: usize, const NB_FULL_ROUNDS: usize> {
+    fn constants(&self) -> [[F; STATE_SIZE]; NB_FULL_ROUNDS];
+    fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE];
+}
+
+/// Populates and checks one Poseidon invocation.
+pub fn poseidon_circuit<
+    F: PrimeField,
+    const STATE_SIZE: usize,
+    const NB_FULL_ROUND: usize,
+    PARAMETERS,
+    Env,
+>(
+    env: &mut Env,
+    param: &PARAMETERS,
+    init_state: [Env::Variable; STATE_SIZE],
+) -> [Env::Variable; STATE_SIZE]
+where
+    F: PrimeField,
+    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
+    Env: ColWriteCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
+        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
+{
+    // Write inputs
+    init_state.iter().enumerate().for_each(|(i, value)| {
+        env.write_column(PoseidonColumn::Input(i), value);
+    });
+
+    // Create, write, and constrain all other columns.
+    apply_permutation(env, param)
+}
+
+/// Apply the whole permutation of Poseidon to the state.
+/// The environment has to be initialized with the input values.
+pub fn apply_permutation<
+    F: PrimeField,
+    const STATE_SIZE: usize,
+    const NB_FULL_ROUND: usize,
+    PARAMETERS,
+    Env,
+>(
+    env: &mut Env,
+    param: &PARAMETERS,
+) -> [Env::Variable; STATE_SIZE]
+where
+    F: PrimeField,
+    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
+    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
+        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
+{
+    // Checking that p - 1 is coprime with 7 as it has to be the case for the sbox
+    {
+        let one = BigUint::from(1u64);
+        let p: BigUint = TryFrom::try_from(<F as PrimeField>::MODULUS).unwrap();
+        let p_minus_one = p - one.clone();
+        let seven = BigUint::from(7u64);
+        assert_eq!(p_minus_one.gcd(&seven), one);
+    }
+
+    let mut final_state: [Env::Variable; STATE_SIZE] =
+        std::array::from_fn(|_| Env::constant(F::zero()));
+
+    for i in 0..NB_FULL_ROUND {
+        let state: [PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>; STATE_SIZE] = {
+            if i == 0 {
+                std::array::from_fn(PoseidonColumn::Input)
+            } else {
+                let prev_round = i - 1;
+                // Previous outputs are in index 4, 9, and 14 if we have 3 elements
+                std::array::from_fn(|j| PoseidonColumn::Round(prev_round, j * 5 + 4))
+            }
+        };
+        let round_res = compute_one_round::<F, STATE_SIZE, NB_FULL_ROUND, PARAMETERS, Env>(
+            env, param, i, &state,
+        );
+
+        if i == NB_FULL_ROUND - 1 {
+            final_state = round_res
+        }
+    }
+
+    final_state
+}
+
+/// Compute one round of the Poseidon permutation
+fn compute_one_round<
+    F: PrimeField,
+    const STATE_SIZE: usize,
+    const NB_FULL_ROUND: usize,
+    PARAMETERS,
+    Env,
+>(
+    env: &mut Env,
+    param: &PARAMETERS,
+    round: usize,
+    elements: &[PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>; STATE_SIZE],
+) -> [Env::Variable; STATE_SIZE]
+where
+    F: PrimeField,
+    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
+    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
+        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
+{
+    // We start at round 0
+    // This implementation mimics the version described in
+    // poseidon_block_cipher in the mina_poseidon crate.
+ assert!( + round < NB_FULL_ROUND, + "The round index {:} is higher than the number of full rounds encoded in the type", + round + ); + // Applying sbox + // For a state transition from (x, y, z) to (x', y', z'), we use the + // following columns shape: + // x^2, x^4, x^6, x^7, x', y^2, y^4, y^6, y^7, y', z^2, z^4, z^6, z^7, z') + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 + let state: Vec = elements + .iter() + .enumerate() + .map(|(i, var_col)| { + let var = env.read_column(*var_col); + // x^2 + let var_square_col = PoseidonColumn::Round(round, 5 * i); + let var_square = env.hcopy(&(var.clone() * var.clone()), var_square_col); + let var_four_col = PoseidonColumn::Round(round, 5 * i + 1); + let var_four = env.hcopy(&(var_square.clone() * var_square.clone()), var_four_col); + let var_six_col = PoseidonColumn::Round(round, 5 * i + 2); + let var_six = env.hcopy(&(var_four.clone() * var_square.clone()), var_six_col); + let var_seven_col = PoseidonColumn::Round(round, 5 * i + 3); + env.hcopy(&(var_six.clone() * var.clone()), var_seven_col) + }) + .collect(); + + // Applying the linear layer + let mds = PoseidonParams::mds(param); + let state: Vec = mds + .into_iter() + .map(|m| { + state + .clone() + .into_iter() + .zip(m) + .fold(Env::constant(F::zero()), |acc, (s_i, mds_i_j)| { + Env::constant(mds_i_j) * s_i.clone() + acc.clone() + }) + }) + .collect(); + + // Adding the round constants + let state: Vec = state + .iter() + .enumerate() + .map(|(i, var)| { + let rc = env.read_column(PoseidonColumn::RoundConstant(round, i)); + var.clone() + rc + }) + .collect(); + + let res_state: Vec = state + .iter() + .enumerate() + .map(|(i, res)| env.hcopy(res, PoseidonColumn::Round(round, 5 * i + 4))) + .collect(); + + res_state + .try_into() + .expect("Resulting state must be of STATE_SIZE length") +} diff --git a/ivc/src/poseidon_55_0_7_3_2/mod.rs b/ivc/src/poseidon_55_0_7_3_2/mod.rs new file mode 100644 index 0000000000..11e8d81b57 --- /dev/null +++ b/ivc/src/poseidon_55_0_7_3_2/mod.rs @@ -0,0 +1,180 @@ +//! Specialised circuit for Poseidon where we have maximum degree 2 constraints. 
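The chain of `hcopy` calls above keeps every constraint at degree 2 by materialising x^2, x^4, x^6 and x^7 as separate columns. A toy sketch of that multiplication chain, assuming arithmetic modulo a 61-bit Mersenne prime as a stand-in for the real scalar field (only the shape of the chain matters here):

```rust
// The four degree-2 steps used to constrain x^7, checked against a direct
// exponentiation. Arithmetic modulo 2^61 - 1 is a toy stand-in for Fp.
const P: u128 = (1u128 << 61) - 1;

fn mul(a: u128, b: u128) -> u128 {
    (a * b) % P
}

fn pow7_chain(x: u128) -> u128 {
    let x2 = mul(x, x);   // column 5*i:     x^2 = x * x
    let x4 = mul(x2, x2); // column 5*i + 1: x^4 = x^2 * x^2
    let x6 = mul(x4, x2); // column 5*i + 2: x^6 = x^4 * x^2
    mul(x6, x)            // column 5*i + 3: x^7 = x^6 * x
}

fn main() {
    let x: u128 = 123_456_789;
    // Reference value: multiply x by itself seven times in total.
    let mut expected = 1u128;
    for _ in 0..7 {
        expected = mul(expected, x);
    }
    assert_eq!(pow7_chain(x), expected);
}
```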
+ +pub mod columns; +pub mod interpreter; + +#[cfg(test)] +mod tests { + use crate::{ + poseidon_55_0_7_3_2::{columns::PoseidonColumn, interpreter, interpreter::PoseidonParams}, + poseidon_params_55_0_7_3, + poseidon_params_55_0_7_3::PlonkSpongeConstantsIVC, + }; + use ark_ff::{UniformRand, Zero}; + use kimchi_msm::{ + circuit_design::{ColAccessCap, ConstraintBuilderEnv, WitnessBuilderEnv}, + columns::ColumnIndexer, + lookups::DummyLookupTable, + Fp, + }; + use mina_poseidon::permutation::poseidon_block_cipher; + + pub struct PoseidonBN254Parameters; + + pub const STATE_SIZE: usize = 3; + pub const NB_FULL_ROUND: usize = 55; + type TestPoseidonColumn = PoseidonColumn; + pub const N_COL: usize = TestPoseidonColumn::N_COL; + pub const N_DSEL: usize = 0; + pub const N_FSEL: usize = 165; + + impl PoseidonParams for PoseidonBN254Parameters { + fn constants(&self) -> [[Fp; STATE_SIZE]; NB_FULL_ROUND] { + let rc = &poseidon_params_55_0_7_3::static_params().round_constants; + std::array::from_fn(|i| std::array::from_fn(|j| Fp::from(rc[i][j]))) + } + + fn mds(&self) -> [[Fp; STATE_SIZE]; STATE_SIZE] { + let mds = &poseidon_params_55_0_7_3::static_params().mds; + std::array::from_fn(|i| std::array::from_fn(|j| Fp::from(mds[i][j]))) + } + } + + type PoseidonWitnessBuilderEnv = WitnessBuilderEnv< + Fp, + TestPoseidonColumn, + { ::N_COL }, + { ::N_COL }, + N_DSEL, + N_FSEL, + DummyLookupTable, + >; + + #[test] + /// Tests that poseidon circuit is correctly formed (witness + /// generation + constraints match) and matches the CPU + /// specification of Poseidon. Fast to run, can be used for + /// debugging. + pub fn test_poseidon_circuit() { + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_size = 1 << 4; + + let mut witness_env: PoseidonWitnessBuilderEnv = WitnessBuilderEnv::create(); + + // Write constants + { + let rc = PoseidonBN254Parameters.constants(); + rc.iter().enumerate().for_each(|(round, rcs)| { + rcs.iter().enumerate().for_each(|(state_index, rc)| { + let rc = vec![*rc; domain_size]; + witness_env.set_fixed_selector_cix( + PoseidonColumn::RoundConstant(round, state_index), + rc, + ) + }); + }); + } + + // Generate random inputs at each row + for _row in 0..domain_size { + let x: Fp = Fp::rand(&mut rng); + let y: Fp = Fp::rand(&mut rng); + let z: Fp = Fp::rand(&mut rng); + + interpreter::poseidon_circuit(&mut witness_env, &PoseidonBN254Parameters, [x, y, z]); + + // Check internal consistency of our circuit: that our + // computed values match the CPU-spec implementation of + // Poseidon. + { + let exp_output: Vec = { + let mut state: Vec = vec![x, y, z]; + poseidon_block_cipher::( + poseidon_params_55_0_7_3::static_params(), + &mut state, + ); + state + }; + let x_col: PoseidonColumn = + PoseidonColumn::Round(NB_FULL_ROUND - 1, 4); + let y_col: PoseidonColumn = + PoseidonColumn::Round(NB_FULL_ROUND - 1, 9); + let z_col: PoseidonColumn = + PoseidonColumn::Round(NB_FULL_ROUND - 1, 14); + assert_eq!(witness_env.read_column(x_col), exp_output[0]); + assert_eq!(witness_env.read_column(y_col), exp_output[1]); + assert_eq!(witness_env.read_column(z_col), exp_output[2]); + } + + witness_env.next_row(); + } + } + + #[test] + /// Checks that poseidon circuit can be proven and verified. Big domain. 
+ pub fn heavy_test_completeness() { + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_size: usize = 1 << 15; + + let (relation_witness, fixed_selectors) = { + let mut witness_env: PoseidonWitnessBuilderEnv = WitnessBuilderEnv::create(); + + let mut fixed_selectors: [Vec; N_FSEL] = + std::array::from_fn(|_| vec![Fp::zero(); 1]); + // Write constants + { + let rc = PoseidonBN254Parameters.constants(); + rc.iter().enumerate().for_each(|(round, rcs)| { + rcs.iter().enumerate().for_each(|(state_index, rc)| { + witness_env.set_fixed_selector_cix( + PoseidonColumn::RoundConstant(round, state_index), + vec![*rc; domain_size], + ); + fixed_selectors[round * STATE_SIZE + state_index] = vec![*rc; domain_size]; + }); + }); + } + + // Generate random inputs at each row + for _row in 0..domain_size { + let x: Fp = Fp::rand(&mut rng); + let y: Fp = Fp::rand(&mut rng); + let z: Fp = Fp::rand(&mut rng); + + interpreter::poseidon_circuit( + &mut witness_env, + &PoseidonBN254Parameters, + [x, y, z], + ); + + witness_env.next_row(); + } + + ( + witness_env.get_relation_witness(domain_size), + fixed_selectors, + ) + }; + + let constraints = { + let mut constraint_env = ConstraintBuilderEnv::::create(); + interpreter::apply_permutation(&mut constraint_env, &PoseidonBN254Parameters); + let constraints = constraint_env.get_constraints(); + + // We have 825 constraints in total + assert_eq!(constraints.len(), 5 * STATE_SIZE * NB_FULL_ROUND); + // Maximum degree of the constraints is 2 + assert_eq!(constraints.iter().map(|c| c.degree(1, 0)).max().unwrap(), 2); + + constraints + }; + + kimchi_msm::test::test_completeness_generic_no_lookups::( + constraints, + Box::new(fixed_selectors), + relation_witness, + domain_size, + &mut rng, + ); + } +} diff --git a/ivc/src/poseidon_55_0_7_3_7/columns.rs b/ivc/src/poseidon_55_0_7_3_7/columns.rs new file mode 100644 index 0000000000..cf62625369 --- /dev/null +++ b/ivc/src/poseidon_55_0_7_3_7/columns.rs @@ -0,0 +1,50 @@ +//! The column layout will be as follow, supposing a state size of 3 elements: +//! | C1 | C2 | C3 | C4 | C5 | C6 | ... | C_(k) | C_(k + 1) | C_(k + 2) | +//! |--- |----|----|-----|-----|-----|-----|-------|-----------|-----------| +//! | x | y | z | x' | y' | z' | ... | x'' | y'' | z'' | +//! | MDS \circ SBOX | | MDS \circ SBOX | +//! |-----------------| |-------------------------------| +//! where (x', y', z') = MDS(x^7, y^7, z^7), i.e. the result of the linear layer +//! We will have, for N full rounds: +//! - `3` input columns +//! - `N * 3` round columns, indexed by the round number and the index in the state, +//! the number of rounds. +//! The round constants are added as fixed selectors. 
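This degree-7 variant trades columns for constraint degree against the degree-2 variant above. A short illustrative computation with the constants used in this diff (plain arithmetic, no library calls):

```rust
// Column/constraint trade-off between the two poseidon_55_0_7_3 variants.
fn main() {
    const STATE_SIZE: usize = 3;
    const NB_FULL_ROUND: usize = 55;

    // Degree-2 variant: 5 witness cells and 5 constraints per state
    // element per round.
    let deg2_cols = STATE_SIZE + 5 * NB_FULL_ROUND * STATE_SIZE;
    let deg2_constraints = 5 * STATE_SIZE * NB_FULL_ROUND;
    assert_eq!((deg2_cols, deg2_constraints), (828, 825));

    // Degree-7 variant: one witness cell per state element per round; the
    // whole sbox + MDS + round-constant step is one degree-7 constraint.
    let deg7_cols = STATE_SIZE + NB_FULL_ROUND * STATE_SIZE;
    let deg7_constraints = STATE_SIZE * NB_FULL_ROUND;
    assert_eq!((deg7_cols, deg7_constraints), (168, 165));

    // Both variants share the same 165 fixed selectors of round constants.
    assert_eq!(NB_FULL_ROUND * STATE_SIZE, 165);
}
```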
+ +use kimchi_msm::columns::{Column, ColumnIndexer}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum PoseidonColumn { + Input(usize), + Round(usize, usize), + RoundConstant(usize, usize), +} + +impl ColumnIndexer + for PoseidonColumn +{ + const N_COL: usize = STATE_SIZE + NB_FULL_ROUND * STATE_SIZE; + + fn to_column(self) -> Column { + match self { + PoseidonColumn::Input(i) => { + assert!(i < STATE_SIZE); + Column::Relation(i) + } + PoseidonColumn::Round(round, state_index) => { + assert!(state_index < STATE_SIZE); + // We start round 0 + assert!(round < NB_FULL_ROUND); + let idx = STATE_SIZE + (round * STATE_SIZE + state_index); + Column::Relation(idx) + } + PoseidonColumn::RoundConstant(round, state_index) => { + assert!(state_index < STATE_SIZE); + // We start round 0 + assert!(round < NB_FULL_ROUND); + let idx = round * STATE_SIZE + state_index; + Column::FixedSelector(idx) + } + } + } +} diff --git a/ivc/src/poseidon_55_0_7_3_7/interpreter.rs b/ivc/src/poseidon_55_0_7_3_7/interpreter.rs new file mode 100644 index 0000000000..706f36eb4a --- /dev/null +++ b/ivc/src/poseidon_55_0_7_3_7/interpreter.rs @@ -0,0 +1,181 @@ +//! Implement an interpreter for a specific instance of the Poseidon inner permutation. +//! The Poseidon construction is defined in the paper ["Poseidon: A New Hash +//! Function"](https://eprint.iacr.org/2019/458.pdf). +//! The Poseidon instance works on a state of size `STATE_SIZE` and is designed +//! to work only with full rounds. As a reminder, the Poseidon permutation is a +//! mapping from `F^STATE_SIZE` to `F^STATE_SIZE`. +//! The user is responsible to provide the correct number of full rounds for the +//! given field and the state. +//! Also, it is hard-coded that the substitution is `7`. The user must verify +//! that `7` is coprime with `p - 1` where `p` is the order the field. +//! The constants and matrix can be generated the file +//! `poseidon/src/pasta/params.sage` + +use crate::poseidon_55_0_7_3_7::columns::PoseidonColumn; +use ark_ff::PrimeField; +use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap}; +use num_bigint::BigUint; +use num_integer::Integer; + +/// Represents the parameters of the instance of the Poseidon permutation. +/// Constants are the round constants for each round, and MDS is the matrix used +/// by the linear layer. +/// The type is parametrized by the field, the state size, and the number of full rounds. +/// Note that the parameters are only for instances using full rounds. +// IMPROVEME merge constants and mds in a flat array, to use the CPU cache +// IMPROVEME generalise init_state for more than 3 elements +pub trait PoseidonParams { + fn constants(&self) -> [[F; STATE_SIZE]; NB_FULL_ROUNDS]; + fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE]; +} + +/// Populates and checks one poseidon invocation. +pub fn poseidon_circuit< + F: PrimeField, + const STATE_SIZE: usize, + const NB_FULL_ROUND: usize, + PARAMETERS, + Env, +>( + env: &mut Env, + param: &PARAMETERS, + init_state: [Env::Variable; STATE_SIZE], +) -> [Env::Variable; STATE_SIZE] +where + F: PrimeField, + PARAMETERS: PoseidonParams, + Env: ColWriteCap> + + HybridCopyCap>, +{ + // Write inputs + init_state.iter().enumerate().for_each(|(i, value)| { + env.write_column(PoseidonColumn::Input(i), value); + }); + + // Create, write, and constrain all other columns. + apply_permutation(env, param) +} + +/// Apply the whole permutation of Poseidon to the state. +/// The environment has to be initialized with the input values. 
+pub fn apply_permutation< + F: PrimeField, + const STATE_SIZE: usize, + const NB_FULL_ROUND: usize, + PARAMETERS, + Env, +>( + env: &mut Env, + param: &PARAMETERS, +) -> [Env::Variable; STATE_SIZE] +where + F: PrimeField, + PARAMETERS: PoseidonParams, + Env: ColAccessCap> + + HybridCopyCap>, +{ + // Checking that p - 1 is coprime with 7 as it has to be the case for the sbox + { + let one = BigUint::from(1u64); + let p: BigUint = TryFrom::try_from(::MODULUS).unwrap(); + let p_minus_one = p - one.clone(); + let seven = BigUint::from(7u64); + assert_eq!(p_minus_one.gcd(&seven), one); + } + + let mut final_state: [Env::Variable; STATE_SIZE] = + std::array::from_fn(|_| Env::constant(F::zero())); + + for i in 0..NB_FULL_ROUND { + let state: [PoseidonColumn; STATE_SIZE] = { + if i == 0 { + std::array::from_fn(PoseidonColumn::Input) + } else { + std::array::from_fn(|j| PoseidonColumn::Round(i - 1, j)) + } + }; + let round_res = compute_one_round::( + env, param, i, &state, + ); + + if i == NB_FULL_ROUND - 1 { + final_state = round_res + } + } + + final_state +} + +/// Compute one round the Poseidon permutation +fn compute_one_round< + F: PrimeField, + const STATE_SIZE: usize, + const NB_FULL_ROUND: usize, + PARAMETERS, + Env, +>( + env: &mut Env, + param: &PARAMETERS, + round: usize, + elements: &[PoseidonColumn; STATE_SIZE], +) -> [Env::Variable; STATE_SIZE] +where + F: PrimeField, + PARAMETERS: PoseidonParams, + Env: ColAccessCap> + + HybridCopyCap>, +{ + // We start at round 0 + // This implementation mimicks the version described in + // poseidon_block_cipher in the mina_poseidon crate. + assert!( + round < NB_FULL_ROUND, + "The round index {:} is higher than the number of full rounds encoded in the type", + round + ); + // Applying sbox + let state: Vec = elements + .iter() + .map(|x| { + let x_col = env.read_column(*x); + let x_square = x_col.clone() * x_col.clone(); + let x_four = x_square.clone() * x_square.clone(); + x_four.clone() * x_square.clone() * x_col.clone() + }) + .collect(); + + // Applying the linear layer + let mds = PoseidonParams::mds(param); + let state: Vec = mds + .into_iter() + .map(|m| { + state + .clone() + .into_iter() + .zip(m) + .fold(Env::constant(F::zero()), |acc, (s_i, mds_i_j)| { + Env::constant(mds_i_j) * s_i.clone() + acc.clone() + }) + }) + .collect(); + + // Adding the round constants + let state: Vec = state + .iter() + .enumerate() + .map(|(i, x)| { + let rc = env.read_column(PoseidonColumn::RoundConstant(round, i)); + x.clone() + rc + }) + .collect(); + + let res_state: Vec = state + .iter() + .enumerate() + .map(|(i, res)| env.hcopy(res, PoseidonColumn::Round(round, i))) + .collect(); + + res_state + .try_into() + .expect("Resulting state must be of STATE_SIZE length") +} diff --git a/ivc/src/poseidon_55_0_7_3_7/mod.rs b/ivc/src/poseidon_55_0_7_3_7/mod.rs new file mode 100644 index 0000000000..bfebe16ba6 --- /dev/null +++ b/ivc/src/poseidon_55_0_7_3_7/mod.rs @@ -0,0 +1,183 @@ +pub mod columns; +pub mod interpreter; + +#[cfg(test)] +mod tests { + use crate::{ + poseidon_55_0_7_3_7::{columns::PoseidonColumn, interpreter, interpreter::PoseidonParams}, + poseidon_params_55_0_7_3, + poseidon_params_55_0_7_3::PlonkSpongeConstantsIVC, + }; + use ark_ff::{UniformRand, Zero}; + use kimchi_msm::{ + circuit_design::{ColAccessCap, ConstraintBuilderEnv, WitnessBuilderEnv}, + columns::ColumnIndexer, + lookups::DummyLookupTable, + Fp, + }; + use mina_poseidon::permutation::poseidon_block_cipher; + + pub struct PoseidonBN254Parameters; + + pub const STATE_SIZE: usize = 
3; + pub const NB_FULL_ROUND: usize = 55; + type TestPoseidonColumn = PoseidonColumn; + pub const N_COL: usize = TestPoseidonColumn::N_COL; + pub const N_DSEL: usize = 0; + pub const N_FSEL: usize = 165; + + impl PoseidonParams for PoseidonBN254Parameters { + fn constants(&self) -> [[Fp; STATE_SIZE]; NB_FULL_ROUND] { + let rc = &poseidon_params_55_0_7_3::static_params().round_constants; + std::array::from_fn(|i| std::array::from_fn(|j| Fp::from(rc[i][j]))) + } + + fn mds(&self) -> [[Fp; STATE_SIZE]; STATE_SIZE] { + let mds = &poseidon_params_55_0_7_3::static_params().mds; + std::array::from_fn(|i| std::array::from_fn(|j| Fp::from(mds[i][j]))) + } + } + + type PoseidonWitnessBuilderEnv = WitnessBuilderEnv< + Fp, + TestPoseidonColumn, + { ::N_COL }, + { ::N_COL }, + N_DSEL, + N_FSEL, + DummyLookupTable, + >; + + #[test] + /// Tests that poseidon circuit is correctly formed (witness + /// generation + constraints match) and matches the CPU + /// specification of Poseidon. Fast to run, can be used for + /// debugging. + pub fn test_poseidon_circuit() { + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_size = 1 << 4; + + let mut witness_env: PoseidonWitnessBuilderEnv = WitnessBuilderEnv::create(); + + // Write constants + { + let rc = PoseidonBN254Parameters.constants(); + rc.iter().enumerate().for_each(|(round, rcs)| { + rcs.iter().enumerate().for_each(|(state_index, rc)| { + let rc = vec![*rc; domain_size]; + witness_env.set_fixed_selector_cix( + PoseidonColumn::RoundConstant(round, state_index), + rc, + ) + }); + }); + } + + // Generate random inputs at each row + for _row in 0..domain_size { + let x: Fp = Fp::rand(&mut rng); + let y: Fp = Fp::rand(&mut rng); + let z: Fp = Fp::rand(&mut rng); + + interpreter::poseidon_circuit(&mut witness_env, &PoseidonBN254Parameters, [x, y, z]); + + // Check internal consistency of our circuit: that our + // computed values match the CPU-spec implementation of + // Poseidon. + { + let exp_output: Vec = { + let mut state: Vec = vec![x, y, z]; + poseidon_block_cipher::( + poseidon_params_55_0_7_3::static_params(), + &mut state, + ); + state + }; + let x_col: PoseidonColumn = + PoseidonColumn::Round(NB_FULL_ROUND - 1, 0); + let y_col: PoseidonColumn = + PoseidonColumn::Round(NB_FULL_ROUND - 1, 1); + let z_col: PoseidonColumn = + PoseidonColumn::Round(NB_FULL_ROUND - 1, 2); + assert_eq!(witness_env.read_column(x_col), exp_output[0]); + assert_eq!(witness_env.read_column(y_col), exp_output[1]); + assert_eq!(witness_env.read_column(z_col), exp_output[2]); + } + + witness_env.next_row(); + } + } + + #[test] + /// Checks that poseidon circuit can be proven and verified. Big domain. 
+ pub fn heavy_test_completeness() { + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_size: usize = 1 << 4; + + let (relation_witness, fixed_selectors) = { + let mut witness_env: PoseidonWitnessBuilderEnv = WitnessBuilderEnv::create(); + + let mut fixed_selectors: [Vec; N_FSEL] = + std::array::from_fn(|_| vec![Fp::zero(); 1]); + // Write constants + { + let rc = PoseidonBN254Parameters.constants(); + rc.iter().enumerate().for_each(|(round, rcs)| { + rcs.iter().enumerate().for_each(|(state_index, rc)| { + witness_env.set_fixed_selector_cix( + PoseidonColumn::RoundConstant(round, state_index), + vec![*rc; domain_size], + ); + fixed_selectors[round * STATE_SIZE + state_index] = vec![*rc; domain_size]; + }); + }); + } + + // Generate random inputs at each row + for _row in 0..domain_size { + let x: Fp = Fp::rand(&mut rng); + let y: Fp = Fp::rand(&mut rng); + let z: Fp = Fp::rand(&mut rng); + + interpreter::poseidon_circuit( + &mut witness_env, + &PoseidonBN254Parameters, + [x, y, z], + ); + + witness_env.next_row(); + } + + ( + witness_env.get_relation_witness(domain_size), + fixed_selectors, + ) + }; + + let constraints = { + let mut constraint_env = ConstraintBuilderEnv::::create(); + interpreter::apply_permutation(&mut constraint_env, &PoseidonBN254Parameters); + let constraints = constraint_env.get_constraints(); + + // Constraints properties check. For this test, we do have 165 constraints + assert_eq!(constraints.len(), STATE_SIZE * NB_FULL_ROUND); + // Maximum degree of the constraints + assert_eq!(constraints.iter().map(|c| c.degree(1, 0)).max().unwrap(), 7); + // We only have degree 7 constraints + constraints + .iter() + .map(|c| c.degree(1, 0)) + .for_each(|d| assert_eq!(d, 7)); + + constraints + }; + + kimchi_msm::test::test_completeness_generic_no_lookups::( + constraints, + Box::new(fixed_selectors), + relation_witness, + domain_size, + &mut rng, + ); + } +} diff --git a/ivc/src/poseidon_8_56_5_3_2/bn254/mod.rs b/ivc/src/poseidon_8_56_5_3_2/bn254/mod.rs new file mode 100644 index 0000000000..073689ec4f --- /dev/null +++ b/ivc/src/poseidon_8_56_5_3_2/bn254/mod.rs @@ -0,0 +1,1064 @@ +//! Poseidon parameters that can be used by [crate::poseidon_8_56_5_3_2] over +//! the scalar field of BN254, for a security level of 128 bits. + +use crate::poseidon_8_56_5_3_2::columns::PoseidonColumn; +use ark_bn254::Fr; +use kimchi_msm::columns::ColumnIndexer; +use mina_poseidon::{constants::SpongeConstants, poseidon::ArithmeticSpongeParams}; +use once_cell::sync::Lazy; + +/// The number of field elements in the state +pub const STATE_SIZE: usize = 3; + +/// Number of full rounds +pub const NB_FULL_ROUND: usize = 8; + +/// Number of partial rounds +pub const NB_PARTIAL_ROUND: usize = 56; + +/// Total number of rounds, including partial and full. +pub const NB_TOTAL_ROUND: usize = NB_FULL_ROUND + NB_PARTIAL_ROUND; + +/// Number of round constants +pub const NB_ROUND_CONSTANTS: usize = NB_TOTAL_ROUND * STATE_SIZE; + +/// The number of constraints required by this gadget +pub const NB_CONSTRAINTS: usize = 432; + +/// The maximum degree of a constraint +pub const MAX_DEGREE: u64 = 2; + +pub type Column = PoseidonColumn; + +/// The number of columns in the poseidon permutation +// This is an alias to ease the reference to it in the code. +pub const NB_COLUMNS: usize = Column::N_COL; + +// FIXME: move into mina_poseidon when this code is ready to go into production. 
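The constants declared above fit together; a quick illustrative check with the values inlined from this file:

```rust
// Sanity arithmetic for the poseidon_8_56_5_3_2 gadget constants.
fn main() {
    const STATE_SIZE: usize = 3;
    const NB_FULL_ROUND: usize = 8;
    const NB_PARTIAL_ROUND: usize = 56;

    // NB_TOTAL_ROUND
    assert_eq!(NB_FULL_ROUND + NB_PARTIAL_ROUND, 64);

    // NB_ROUND_CONSTANTS: one constant per state element per round,
    // whether the round is full or partial.
    assert_eq!((NB_FULL_ROUND + NB_PARTIAL_ROUND) * STATE_SIZE, 192);

    // NB_CONSTRAINTS: 4 degree-2 constraints per state element in a full
    // round; 4 for the sbox element plus STATE_SIZE - 1 pass-through
    // constraints in a partial round.
    let nb_constraints =
        4 * STATE_SIZE * NB_FULL_ROUND + (4 + STATE_SIZE - 1) * NB_PARTIAL_ROUND;
    assert_eq!(nb_constraints, 432);
}
```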
+/* Generated by the following command: +```shell +sage params.sage --rounds 64 rust 3 bn254 > bn254_poseidon.rs +``` +with the commit 85cbccd4266cdb567fd47ffc54c3fa52543c2c51 +where the curves have been changed in the script params.sage to use bn254 + */ + +use std::str::FromStr; + +use super::interpreter::PoseidonParams; + +fn params() -> ArithmeticSpongeParams { + ArithmeticSpongeParams { + mds: vec![ + vec![ + Fr::from_str( + "1891083243990574305685895570197511851713934835398236923905694056149924753068", + ) + .unwrap(), + Fr::from_str( + "12425575093536208349525223497695086347893624421418924655002789600937388509542", + ) + .unwrap(), + Fr::from_str( + "21738922686690608832323816580123167366266110538717945605098093992702895344376", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "27684418222856506172738658695423428482598065952308803192811766786876275297242", + ) + .unwrap(), + Fr::from_str( + "15829269591837939027267285551701704696044414402373072091502623580282737399890", + ) + .unwrap(), + Fr::from_str( + "19252639924245456560245438416349235444106039029142654756933998759642658686319", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "13569069999110563102763825881967666333546046803842950209911791358856018301650", + ) + .unwrap(), + Fr::from_str( + "27981915092643975197764375502939826253878858038100319462170032597200154977862", + ) + .unwrap(), + Fr::from_str( + "8885190793627186923687683451640133227063861709370969757580340555722307526098", + ) + .unwrap(), + ], + ], + round_constants: vec![ + vec![ + Fr::from_str( + "22979498492811028599388207262497610781398114192113450876459774028198681978580", + ) + .unwrap(), + Fr::from_str( + "3462163333362550903390325785926437039762839525013511939547720766741499409856", + ) + .unwrap(), + Fr::from_str( + "12308193519164615772431903590941145851045953550307364845912877146027882177970", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "26635724949260230089658132165964484272715279452239432758187323659534196378572", + ) + .unwrap(), + Fr::from_str( + "16384206429233192859678459425673999470857159372638200975562270773516251379588", + ) + .unwrap(), + Fr::from_str( + "19416611223573031732302519381697999456593181436167069017836818060955725864200", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "27016929334480358351753199307567711107921448016886807295662950322547508172113", + ) + .unwrap(), + Fr::from_str( + "11238098416885388471337306072943475212729724613485264776842071924751017567771", + ) + .unwrap(), + Fr::from_str( + "26357617544526716389824924588303395055989615117170456588197310726656210667591", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "13522402812268713007260660858020591835416325370361256561462576999381232564525", + ) + .unwrap(), + Fr::from_str( + "17993297334719348473740719243191272244929728431491444848946396448123051191677", + ) + .unwrap(), + Fr::from_str( + "14266894438972358618281326760350513234022317540856366683582613440301108328203", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "15544565783214950187472074309645319815959207946976778588709413987755227716658", + ) + .unwrap(), + Fr::from_str( + "11459281677950411823472431204851572448489387232119046190993717723024935465646", + ) + .unwrap(), + Fr::from_str( + "21931774631910775084872817505478355551260333558147704485801041562136825822689", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "23930462756185355677176913878440535939755511465091906422358167934102141974413", + ) + .unwrap(), + Fr::from_str( + "24777337052169818536181344215824603464067642843117484360747523672092746770770", + ) + 
.unwrap(), + Fr::from_str( + "9723694385274179134284666123116305587583263720877786033125569799811141709147", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "21334124835555176052046135550571447667154056413650326169040093998967534013839", + ) + .unwrap(), + Fr::from_str( + "28178412293035677936519661468561659327776642473949161402580740126141114716595", + ) + .unwrap(), + Fr::from_str( + "23261949722991482098067992934452303478826592879142057585426055746184730329977", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "7525745512005637899980788850937553026084357875660084470654426987818115307169", + ) + .unwrap(), + Fr::from_str( + "11196989505533608046175719825672890499899138989727925138952469167443760521626", + ) + .unwrap(), + Fr::from_str( + "25069470445604580535270406510013059947440356347927688567873255716571855782556", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "746152031855445071661969786731764979416640168084300832548286844362915267733", + ) + .unwrap(), + Fr::from_str( + "14500783258490511332342759610086566441467851535274198353685874901881586203920", + ) + .unwrap(), + Fr::from_str( + "12586553744079534087671160799108871165444925427796015728451282191808004632029", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "2658978799169405016067965116007625196452718808329137140199842584938385590196", + ) + .unwrap(), + Fr::from_str( + "25980326596677328357808416201926774899770147152776274221261358547719778020156", + ) + .unwrap(), + Fr::from_str( + "24508235842219447789227397734868597077124899664487104942155698316761683262563", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "20191280011693017888798491354094094243127547574655352220297687980090230216", + ) + .unwrap(), + Fr::from_str( + "27372133858594248121174511881254229698687209419322362469155503967311874023911", + ) + .unwrap(), + Fr::from_str( + "24479529308646273673625530554247818183967069980890544961747034465992605211812", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "23978441594464348354765134391283311312307598042652828493772997453014454901206", + ) + .unwrap(), + Fr::from_str( + "11664556276797417846866489492418594702491171278433997565894976597592370230667", + ) + .unwrap(), + Fr::from_str( + "12270771377697756787250169216048909236321853168489305289596185098410215090484", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "8811136672424023763408452444568991590125177469609096631745680177504609711487", + ) + .unwrap(), + Fr::from_str( + "20237676190929489436588870470052869226876887618934149899920104260655948094739", + ) + .unwrap(), + Fr::from_str( + "20062974513054622329783259390735561568588514582980747668306658460292592667248", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "798776603612212418444623010074034446472782213575824096964910597458981693714", + ) + .unwrap(), + Fr::from_str( + "8758094627630360838980606896109464566485363275799929171438278126952202465833", + ) + .unwrap(), + Fr::from_str( + "707487852423957926868336579455035906997963205205141270331975346408432067044", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "202779927146339323474786313977213080711135813758190854284590663556011706836", + ) + .unwrap(), + Fr::from_str( + "20391231297043803259194308176992776361609861855662006565804325535077866583400", + ) + .unwrap(), + Fr::from_str( + "13103971641300442794695006804822894722919592148173662271765398624597649213516", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "11887088560178314335635682366268022017575768082106271832073442174179560940002", + ) + .unwrap(), + Fr::from_str( + 
"9497730497365106303373876625863922615426121795921218372965815099779016166015", + ) + .unwrap(), + Fr::from_str( + "13255889376415173861282249540978269441170571583954046228817145377643690062643", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "21797601683155849424606801821462621112628210830522816165936272187765788522806", + ) + .unwrap(), + Fr::from_str( + "24581830980057907171669975479232671391615633395687956862981481636043684252818", + ) + .unwrap(), + Fr::from_str( + "3541295787421703076701812027030126532836938797821201403520460081703194798007", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "26648315008823981808527573963042312855617262185318718006368079643838521696342", + ) + .unwrap(), + Fr::from_str( + "25709698375978466998813029191490908445909426399743598850530240905007542181236", + ) + .unwrap(), + Fr::from_str( + "15108470522637839917783948187626705626608074131603801620229663133367773769409", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "28021329994881905447188243064573681540852456674731988825066269508526569461230", + ) + .unwrap(), + Fr::from_str( + "23931699377713654292108411755172835281506863304630736051021853367123546501725", + ) + .unwrap(), + Fr::from_str( + "16723886233609641431662771575929234875591162154288587674425439500724219825212", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "2671311012893846775573899497331547132295274951055892772678260375558714942406", + ) + .unwrap(), + Fr::from_str( + "10473922685305341900412641696196941263307819634043185788067854748950602802383", + ) + .unwrap(), + Fr::from_str( + "5111630741260296655990184310611375231177307348451724203930709456972805973880", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "4992057345914438551815056581049158703461724962129889522457212594798779955393", + ) + .unwrap(), + Fr::from_str( + "21183229986193453383825269196190741932488867108677665668392098260004852050858", + ) + .unwrap(), + Fr::from_str( + "17083812614758825262603118558173710997332640429545889370504202808090339572857", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "26920976402012774776865974353864956131500106517245044898291304474770065076213", + ) + .unwrap(), + Fr::from_str( + "8733966472516018163838600897336665253176449185681528084604422362542186741640", + ) + .unwrap(), + Fr::from_str( + "12872170036700590319474830060138441606481805188061693688276584984859227049202", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "4477722790415428221489180851061526086953856916136558152290587660464310195129", + ) + .unwrap(), + Fr::from_str( + "15013641992838646488575621859957900283128972996570425739867595472354843675113", + ) + .unwrap(), + Fr::from_str( + "11110048915932820585149019832900719558018746589803602931460479672600239301992", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "6386524788949068680243545386229548182250656575626566757865736117588702942149", + ) + .unwrap(), + Fr::from_str( + "10817607938125396761228235991828983933462215897289776505148575850147631400593", + ) + .unwrap(), + Fr::from_str( + "17865086662065068946652014543896618689007438069494510844490187921781045033543", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "26739078526540815991525797409513121042395926527702073499652548281297625028042", + ) + .unwrap(), + Fr::from_str( + "22105063856923216930774038075969680004134092465525615222672386192561909491643", + ) + .unwrap(), + Fr::from_str( + "6547073217674180994961882255834734219166276822047545566909303528752619346965", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + 
"25888791229334359634310626508983935867495449391423881771352427385021616079608", + ) + .unwrap(), + Fr::from_str( + "5728419725627720701717078886669712201564103341166911632619636881997293379662", + ) + .unwrap(), + Fr::from_str( + "9577751993330236726740958969117670757888267786661142736615315540557935073808", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "18485877910552384007717083804696131794323032392196584185820990234315883649039", + ) + .unwrap(), + Fr::from_str( + "1769516528997465058766917692166940530540162775010485894812511264651123420964", + ) + .unwrap(), + Fr::from_str( + "386757200774461941643538807053842280717539800465562760112828539321993775241", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "3316509296104361799525651631938052468498330218596415774527777363167138087617", + ) + .unwrap(), + Fr::from_str( + "13729336006888298565425261097968915916390598196131379830434588594171666569512", + ) + .unwrap(), + Fr::from_str( + "16148239410408669906403652096870758949643310812745307334909798696471189947927", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "21823520381432477314141727574561655219860240640515827154913290917309686356497", + ) + .unwrap(), + Fr::from_str( + "28068200504050762109788374571880725807356066145249496235193375357140395644419", + ) + .unwrap(), + Fr::from_str( + "5091226074924237835610573037013208388130454756337337471122368918009398753447", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "6515120964938381342071855012952056094433032260551298595194434149594056652176", + ) + .unwrap(), + Fr::from_str( + "2340842607486511800425610472421888578518367747461866490788654350550767741883", + ) + .unwrap(), + Fr::from_str( + "28166492224976583257960903700053982900186421019635512376218580362545306276709", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "3712612816523606936094514234541515605991691986951168127960369409943944902903", + ) + .unwrap(), + Fr::from_str( + "1152766033584194531070871869834974085415579739949782153652057784926343256899", + ) + .unwrap(), + Fr::from_str( + "8765049913334393206963752695629450392338936491569237833753514286462581790561", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "1601627912182619663670418257843886466719103377571015176803695116602531025179", + ) + .unwrap(), + Fr::from_str( + "18098024103033142368624034583201165961009040474643655548434700066455514231331", + ) + .unwrap(), + Fr::from_str( + "21928220871250005366949251966887490261658471717243624436236070782233710958193", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "13295211045409383825433911169815486079945004241473963274225158733643991813121", + ) + .unwrap(), + Fr::from_str( + "10102619007330980171750476017504785290001093245315837544464105419145430593293", + ) + .unwrap(), + Fr::from_str( + "209066254001010781273163283043011468251658035504800687498237342325693440716", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "10569794888394925789900696666438098928482919555802848731145904974932210713381", + ) + .unwrap(), + Fr::from_str( + "24519507563490032456526532812360147850362262124850480124499654753387364654494", + ) + .unwrap(), + Fr::from_str( + "8177920417424677597775227574376448009357347539898615178798501628859714007479", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "12067616893727942346831314147294036302857666398663726324584342747296065928710", + ) + .unwrap(), + Fr::from_str( + "16971111464439978238379247781540292780056241741803063434856766170743124461469", + ) + .unwrap(), + Fr::from_str( + "19754437344667989781187357721513315145328407420825515279265102170857620051268", + ) + 
.unwrap(), + ], + vec![ + Fr::from_str( + "3975594217367336645402283024719769072885792075344511606692741601510641156311", + ) + .unwrap(), + Fr::from_str( + "19040197452680989990741508256669934716515775259429067120683774529031109727966", + ) + .unwrap(), + Fr::from_str( + "5634578132978895965594894841292458192813357921721236220653749372316851034659", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "24312649336217150643502647288233882695985124865143983808786503810476675628308", + ) + .unwrap(), + Fr::from_str( + "26190380549799665209932387790089553938565301230293821415792947418738438572475", + ) + .unwrap(), + Fr::from_str( + "28207972947754207739680139011960703172165585550947706768480292386338937193745", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "14451766174292343678890357279454517773013585124369849119851032684103228653072", + ) + .unwrap(), + Fr::from_str( + "12181367175079876176988175575030826371836485595254655739186027469233503901695", + ) + .unwrap(), + Fr::from_str( + "10991645916339084020684503132290495819766466380987797648576877855177656702574", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "21313503925572929158381750262537945058966279485804307864340280517601895617812", + ) + .unwrap(), + Fr::from_str( + "1550527311092834162328201290652099251984068161341884726521396319041776382263", + ) + .unwrap(), + Fr::from_str( + "12003435213367105899404267897476505638234793525266930670911913539817221417924", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "8914436165324002549401721460186086215484314929358430322367050538121529812441", + ) + .unwrap(), + Fr::from_str( + "27832471390260088958712195064264693263013262224866234154815027609517686567070", + ) + .unwrap(), + Fr::from_str( + "21985130074442575837699661478687045882883906445505825293630722207419298642563", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "3121108818032693608292705100552135299104657945137655223341160358361023747970", + ) + .unwrap(), + Fr::from_str( + "27854040945265461368879245943811705519815133202941163906839617464472048050107", + ) + .unwrap(), + Fr::from_str( + "16002511017673785381646235433633055570232704209925362749110958022614294680393", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "9391205603129777266623979897548135033992809302803677537718465809304634939535", + ) + .unwrap(), + Fr::from_str( + "24297079789956858850141557550669917040800093660605463170583536807967359376943", + ) + .unwrap(), + Fr::from_str( + "5541548341946936547043936820389981449139467811019058131623682269330484215783", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "7190170322642292762509602175876149041297305605840922902837779590457938918669", + ) + .unwrap(), + Fr::from_str( + "2027727849652978298961694136962688734430323753419089111303416976388312626148", + ) + .unwrap(), + Fr::from_str( + "402427768909234974448400451691124558802645007494232545425431267044631491274", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "12797987239774923416407369376953974586493255155241527032821517134381740228905", + ) + .unwrap(), + Fr::from_str( + "13333249780205062899642181453873590081512016926206507817605750877019442371089", + ) + .unwrap(), + Fr::from_str( + "22188565325128643960736450273241242961083822226450593824571763485758109842353", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "14776177603689091833670811465610858113310593319381602563122002687195119058247", + ) + .unwrap(), + Fr::from_str( + "8091773219661648644611506313794854881311037176009459530767471191252204309868", + ) + .unwrap(), + Fr::from_str( + 
"11572415535524327542083010017274275840592842320730873456621979565559067913733", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "20923164127054504021838810951918321616973675807578349217848643497149066166007", + ) + .unwrap(), + Fr::from_str( + "23916764349365855067645867029751406599241457684665189633084153863091374186634", + ) + .unwrap(), + Fr::from_str( + "15824998809651622678290890820327874961632638426240138370022501121781103446259", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "24637783239914661001463701818528140355269066399069119799243934837974141806351", + ) + .unwrap(), + Fr::from_str( + "960857870893264392654072242483170132703826969766439993624797963496069781505", + ) + .unwrap(), + Fr::from_str( + "26437805484456742255124383441148924774337202744224954418128369699282030985816", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "26578738662896906130683946198236194771349279034514315009985381494637755178012", + ) + .unwrap(), + Fr::from_str( + "4824844368626589074800225473359998312724599509010698030845972726192270761413", + ) + .unwrap(), + Fr::from_str( + "1279633418489206280063439936145640252901253476940505080338091886704297503137", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "13710991567078829309950552225511825775559273430088991705888338422645023107567", + ) + .unwrap(), + Fr::from_str( + "24668872485331698368369872321174853677667368883929365377153552889902839876059", + ) + .unwrap(), + Fr::from_str( + "13750456500419553884447154012886160987332428475286924998031161614878499276461", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "7897870706671989949903359523135485593775134275719857967175422051001928288650", + ) + .unwrap(), + Fr::from_str( + "24837463715854409438259925936940667100471584776258950891486907118659586222782", + ) + .unwrap(), + Fr::from_str( + "15102320430753094455411262873908179792479117498565463936483853362951093377404", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "20684012922073862986880613075424856809319565915848039596293835293542621074230", + ) + .unwrap(), + Fr::from_str( + "27853137663239148089212917142958191707337132972791514344244938292818076156215", + ) + .unwrap(), + Fr::from_str( + "21515833961121090836372549249533024521086358011705741464637998484545867651639", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "14927367861301244473471972423309886137778884679987552642033604382283866452496", + ) + .unwrap(), + Fr::from_str( + "2234445563530370737450306710054836240206601416475326119659514396778117223266", + ) + .unwrap(), + Fr::from_str( + "7320243644229576716977260850447281726825012665363843781486677636270396893993", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "3648709991731171089283528046834957511485556517829186374119154464357293936649", + ) + .unwrap(), + Fr::from_str( + "8431771340355225974834111903440269408389608143822442356390982704086190421477", + ) + .unwrap(), + Fr::from_str( + "850707530944871742730572985296454989068982533587676445405227892149682249460", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "27968558232222653455324053739339866927384509292992637407525622495221665748064", + ) + .unwrap(), + Fr::from_str( + "18309438953837613260356082733742847603856465956586832431344593061564231402706", + ) + .unwrap(), + Fr::from_str( + "7558769148417330857693721923631008074203529622851527561622447590521963662006", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "3817411164684956972519434296876841710387488302709712662164166985673595410943", + ) + .unwrap(), + Fr::from_str( + 
"14529667139169881031943219432572774573785882712681678612418006263967413092263", + ) + .unwrap(), + Fr::from_str( + "23218827447973479235173323744034317336035671404247548580906272331188785389688", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "8651801510810146025412974160945288620309458551124978671355372318440299892446", + ) + .unwrap(), + Fr::from_str( + "22421908295524355984007285617835206887162355421815299667252838830702546015133", + ) + .unwrap(), + Fr::from_str( + "7160018919593574320964537152208736524155714355508700340218764463887638143455", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "723350002648579254243598219836182385851624402483518972524433485600657018308", + ) + .unwrap(), + Fr::from_str( + "1424479715330663945471022437902357236141169107604243794548089122048400883402", + ) + .unwrap(), + Fr::from_str( + "12380316486623809861878214203281027877179803297952235333898818121693577083127", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "13300264762908838749047406324479200233395191915779570411502803795135456010503", + ) + .unwrap(), + Fr::from_str( + "7182708579311809518589613836995895145881239038123095258480613105900991785981", + ) + .unwrap(), + Fr::from_str( + "3399116840288437658533170746195011940616209443265644590911430680876206950690", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "18075677167539342788969842697404710734192919515624038738684304089341638962797", + ) + .unwrap(), + Fr::from_str( + "774916241389256067817357617190556955727339372170470123567430330174958764878", + ) + .unwrap(), + Fr::from_str( + "23911159648093593276293911909659625889333565386584216455353262739940352889672", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "7861494016510010234313534888968189516898452462977986982488816163397319727747", + ) + .unwrap(), + Fr::from_str( + "15072759186032286031690480130876158018121236234681552923585561011002205274233", + ) + .unwrap(), + Fr::from_str( + "13329814760549908430040494615924435617177143177522675259969178808824256694834", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "7468723007810503327553686337411634190912418400882277241712552524430599692675", + ) + .unwrap(), + Fr::from_str( + "11304437011528059451078980881777692033687950145258622992372151025965772625841", + ) + .unwrap(), + Fr::from_str( + "24455414907879354554658170783358918522990892444576704866381766190660412826213", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "4123473255827760339537832283192656514249000347603384170048044945751397547862", + ) + .unwrap(), + Fr::from_str( + "13177230660974899915401210507892129885856976090284825863072745642109312338203", + ) + .unwrap(), + Fr::from_str( + "28859805634412811444238562480326194712347893192568982315953124889109926907412", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "7803019375719168318963557398085369618684087162676370220706203244371827257508", + ) + .unwrap(), + Fr::from_str( + "4029922906086312393665000628450269927106443351493821364980763691551262280792", + ) + .unwrap(), + Fr::from_str( + "11158633081229897821908038161319320155864649742858460176817619070002508020544", + ) + .unwrap(), + ], + vec![ + Fr::from_str( + "18052059944752105394121061539214545921115313064660268801792496020636120990680", + ) + .unwrap(), + Fr::from_str( + "28815470454875236045606443537792333795118166510830397224364317235056781483845", + ) + .unwrap(), + Fr::from_str( + "8678824483470329866907826063438633160791886719100084683778994189652170993635", + ) + .unwrap(), + ], + ], + } +} + +pub fn static_params() -> &'static ArithmeticSpongeParams { + static PARAMS: Lazy> = 
Lazy::new(params); + &PARAMS +} + +/// Constants used by the IVC circuit used by the folding scheme +/// This is meant to be only used for the IVC circuit, with the BN254 curve +/// It has not been tested with/for other curves. +// This must be moved into mina_poseidon later. +#[derive(Clone)] +pub struct PlonkSpongeConstantsIVC {} + +impl SpongeConstants for PlonkSpongeConstantsIVC { + const SPONGE_CAPACITY: usize = 1; + const SPONGE_WIDTH: usize = 3; + const SPONGE_RATE: usize = 2; + const PERM_ROUNDS_FULL: usize = 0; + const PERM_ROUNDS_PARTIAL: usize = 56; + const PERM_HALF_ROUNDS_FULL: usize = 4; + const PERM_SBOX: u32 = 5; + const PERM_FULL_MDS: bool = true; + const PERM_INITIAL_ARK: bool = false; +} + +pub struct PoseidonBN254Parameters; + +impl PoseidonParams for PoseidonBN254Parameters { + fn constants(&self) -> [[Fr; STATE_SIZE]; NB_TOTAL_ROUND] { + let rc = &static_params().round_constants; + std::array::from_fn(|i| std::array::from_fn(|j| Fr::from(rc[i][j]))) + } + + fn mds(&self) -> [[Fr; STATE_SIZE]; STATE_SIZE] { + let mds = &static_params().mds; + std::array::from_fn(|i| std::array::from_fn(|j| Fr::from(mds[i][j]))) + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::{static_params, PlonkSpongeConstantsIVC}; + use ark_bn254::Fr; + use kimchi::o1_utils::FieldHelpers; + use mina_poseidon::poseidon::{ArithmeticSponge as Poseidon, Sponge as _}; + + // Regression tests. Test vectors have been generated using the same + // code, commit cdca35851ff27bdaf9277388ea08d83c8a0a75fc + // This does not mean that the implementation is correct. + #[test] + fn test_poseidon() { + let mut hash = Poseidon::::new(static_params()); + let input: [Fr; 3] = [ + Fr::from_str("1").unwrap(), + Fr::from_str("1").unwrap(), + Fr::from_str("1").unwrap(), + ]; + let exp_output_str = [ + "05c09e2e5cf67833f4987c3011ae8c754b70dd8c28cbb619b421e012fad87c16", + "14b598974e69d9cbee4c6b809256152552ec6c008e59f5fe54f60491d8286c23", + "54e82d762fe2994c85fb4d4e1b22cb060b005327bba4265e486e82ae0c2a7125", + ]; + let exp_output = exp_output_str.map(|x| Fr::from_hex(x).unwrap()); + hash.absorb(&input); + assert_eq!(hash.state, exp_output); + } +} diff --git a/ivc/src/poseidon_8_56_5_3_2/columns.rs b/ivc/src/poseidon_8_56_5_3_2/columns.rs new file mode 100644 index 0000000000..c5893aab51 --- /dev/null +++ b/ivc/src/poseidon_8_56_5_3_2/columns.rs @@ -0,0 +1,89 @@ +/// The column layout will be as follow, supposing a state size of 3 elements: +/// | C1 | C2 | C3 | C4 | C5 | C6 | ... | C_(k) | C_(k + 1) | C_(k + 2) | +/// |--- |----|----|-----|-----|-----|-----|-------|-----------|-----------| +/// | x | y | z | x' | y' | z' | ... | x'' | y'' | z'' | +/// | MDS \circ SBOX | | MDS \circ SBOX | +/// |-----------------| |-------------------------------| +/// Divided in 4 +/// blocks of degree 2 +/// constraints +/// where (x', y', z') = MDS(x^5, y^5, z^5), i.e. 
the result of the linear layer
+use kimchi_msm::columns::{Column, ColumnIndexer};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum PoseidonColumn<
+    const STATE_SIZE: usize,
+    const NB_FULL_ROUND: usize,
+    const NB_PARTIAL_ROUND: usize,
+> {
+    Input(usize),
+    // We use the constraints:
+    // y   = x * x  -> x^2 -> i
+    // y'  = y * y  -> x^4 -> i + 1
+    // y'' = y' * x -> x^5 -> i + 2
+    // y'' * MDS    ->        i + 3
+    // --> 4 * STATE_SIZE columns per full round
+    FullRound(usize, usize),
+    // (round, idx), with idx in 0..(4 + STATE_SIZE - 1)
+    PartialRound(usize, usize),
+    RoundConstant(usize, usize),
+}
+
+impl<const STATE_SIZE: usize, const NB_FULL_ROUND: usize, const NB_PARTIAL_ROUND: usize>
+    ColumnIndexer for PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>
+{
+    // - STATE_SIZE input columns
+    // - for each partial round:
+    //   - 1 column for x^2 -> x * x
+    //   - 1 column for x^4 -> x^2 * x^2
+    //   - 1 column for x^5 -> x^4 * x
+    //   - 1 column for x^5 * MDS(., L)
+    //   - STATE_SIZE - 1 columns for the unchanged elements multiplied by the
+    //     MDS + rc
+    // - for each full round:
+    //   - STATE_SIZE state columns for x^2 -> x * x
+    //   - STATE_SIZE state columns for x^4 -> x^2 * x^2
+    //   - STATE_SIZE state columns for x^5 -> x^4 * x
+    //   - STATE_SIZE state columns for x^5 * MDS(., L)
+    // For the round constants, we have:
+    // - STATE_SIZE * (NB_PARTIAL_ROUND + NB_FULL_ROUND)
+    const N_COL: usize =
+        // input
+        STATE_SIZE
+        + 4 * NB_FULL_ROUND * STATE_SIZE // full rounds
+        + (4 + STATE_SIZE - 1) * NB_PARTIAL_ROUND // partial rounds
+        + STATE_SIZE * (NB_PARTIAL_ROUND + NB_FULL_ROUND); // fixed selectors
+
+    fn to_column(self) -> Column {
+        // Number of reductions for
+        // x -> x^2 -> x^4 -> x^5 -> x^5 * MDS
+        let nb_red = 4;
+        match self {
+            PoseidonColumn::Input(i) => {
+                assert!(i < STATE_SIZE);
+                Column::Relation(i)
+            }
+            PoseidonColumn::PartialRound(round, idx) => {
+                assert!(round < NB_PARTIAL_ROUND);
+                assert!(idx < nb_red + STATE_SIZE - 1);
+                let offset = STATE_SIZE;
+                let idx = offset + round * (nb_red + STATE_SIZE - 1) + idx;
+                Column::Relation(idx)
+            }
+            PoseidonColumn::FullRound(round, state_index) => {
+                assert!(state_index < nb_red * STATE_SIZE);
+                // We start at round 0
+                assert!(round < NB_FULL_ROUND);
+                let offset = STATE_SIZE + (NB_PARTIAL_ROUND * (nb_red + STATE_SIZE - 1));
+                let idx = offset + (round * nb_red * STATE_SIZE + state_index);
+                Column::Relation(idx)
+            }
+            PoseidonColumn::RoundConstant(round, state_index) => {
+                assert!(state_index < STATE_SIZE);
+                assert!(round < NB_FULL_ROUND + NB_PARTIAL_ROUND);
+                let idx = round * STATE_SIZE + state_index;
+                Column::FixedSelector(idx)
+            }
+        }
+    }
+}
diff --git a/ivc/src/poseidon_8_56_5_3_2/interpreter.rs b/ivc/src/poseidon_8_56_5_3_2/interpreter.rs
new file mode 100644
index 0000000000..f70ac17c0c
--- /dev/null
+++ b/ivc/src/poseidon_8_56_5_3_2/interpreter.rs
@@ -0,0 +1,324 @@
+//! Implement an interpreter for a specific instance of the Poseidon inner permutation.
+//! The Poseidon construction is defined in the paper ["Poseidon: A New Hash
+//! Function"](https://eprint.iacr.org/2019/458.pdf).
+//! The Poseidon instance works on a state of size `STATE_SIZE` and is designed
+//! to work with full and partial rounds. As a reminder, the Poseidon
+//! permutation is a mapping from `F^STATE_SIZE` to `F^STATE_SIZE`.
+//! The user is responsible for providing the correct number of full and partial
+//! rounds for the given field and state.
+//! Also, it is hard-coded that the substitution (sbox) is `x^5`. The user must
+//! verify that `5` is coprime with `p - 1`, where `p` is the order of the field.
+//! The constants and the matrix can be generated with the file
+//!
`poseidon/src/pasta/params.sage` + +use crate::poseidon_8_56_5_3_2::columns::PoseidonColumn; +use ark_ff::PrimeField; +use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap}; +use num_bigint::BigUint; +use num_integer::Integer; + +/// Represents the parameters of the instance of the Poseidon permutation. +/// Constants are the round constants for each round, and MDS is the matrix used +/// by the linear layer. +/// The type is parametrized by the field, the state size, and the total number +/// of rounds. +// IMPROVEME merge constants and mds in a flat array, to use the CPU cache +pub trait PoseidonParams { + fn constants(&self) -> [[F; STATE_SIZE]; NB_TOTAL_ROUNDS]; + fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE]; +} + +/// Populates and checks one poseidon invocation. +pub fn poseidon_circuit< + F: PrimeField, + const STATE_SIZE: usize, + const NB_FULL_ROUND: usize, + const NB_PARTIAL_ROUND: usize, + const NB_TOTAL_ROUND: usize, + PARAMETERS, + Env, +>( + env: &mut Env, + param: &PARAMETERS, + init_state: [Env::Variable; STATE_SIZE], +) -> [Env::Variable; STATE_SIZE] +where + F: PrimeField, + PARAMETERS: PoseidonParams, + Env: ColWriteCap> + + HybridCopyCap>, +{ + // Write inputs + init_state.iter().enumerate().for_each(|(i, value)| { + env.write_column(PoseidonColumn::Input(i), value); + }); + + // Create, write, and constrain all other columns. + apply_permutation(env, param) +} + +/// Apply the HADES-based Poseidon to the state. +/// The environment has to be initialized with the input values. +/// It mimicks the version described in the paper ["Poseidon: A New Hash +/// Function"](https://eprint.iacr.org/2019/458.pdf), figure 2. The construction +/// first starts with `NB_FULL_ROUND/2` full rounds, then `NB_PARTIAL_ROUND` +/// partial rounds, and finally `NB_FULL_ROUND/2` full rounds. 
+/// +/// Each full rounds consists of the following steps: +/// - adding the round constants on the whole state +/// - applying the sbox on the whole state +/// - applying the linear layer on the whole state +/// +/// Each partial round consists of the following steps: +/// - adding the round constants on the whole state +/// - applying the sbox on the first element of the state (FIXME: the +/// specification mentions the last element - map the implementation provided in +/// [mina_poseidon]) +/// - applying the linear layer on the whole state +pub fn apply_permutation< + F: PrimeField, + const STATE_SIZE: usize, + const NB_FULL_ROUND: usize, + const NB_PARTIAL_ROUND: usize, + const NB_TOTAL_ROUND: usize, + PARAMETERS, + Env, +>( + env: &mut Env, + param: &PARAMETERS, +) -> [Env::Variable; STATE_SIZE] +where + F: PrimeField, + PARAMETERS: PoseidonParams, + Env: ColAccessCap> + + HybridCopyCap>, +{ + // Checking that p - 1 is coprime with 5 as it has to be the case for the sbox + { + let one = BigUint::from(1u64); + let p: BigUint = TryFrom::try_from(::MODULUS).unwrap(); + let p_minus_one = p - one.clone(); + let five = BigUint::from(5u64); + assert_eq!(p_minus_one.gcd(&five), one); + } + + let mut state: [Env::Variable; STATE_SIZE] = + std::array::from_fn(|i| env.read_column(PoseidonColumn::Input(i))); + + // Full rounds + for i in 0..(NB_FULL_ROUND / 2) { + state = compute_one_full_round::< + F, + STATE_SIZE, + NB_FULL_ROUND, + NB_PARTIAL_ROUND, + NB_TOTAL_ROUND, + PARAMETERS, + Env, + >(env, param, i, &state); + } + + // Partial rounds + for i in 0..NB_PARTIAL_ROUND { + state = compute_one_partial_round::< + F, + STATE_SIZE, + NB_FULL_ROUND, + NB_PARTIAL_ROUND, + NB_TOTAL_ROUND, + PARAMETERS, + Env, + >(env, param, i, &state); + } + + // Remaining full rounds + for i in (NB_FULL_ROUND / 2)..NB_FULL_ROUND { + state = compute_one_full_round::< + F, + STATE_SIZE, + NB_FULL_ROUND, + NB_PARTIAL_ROUND, + NB_TOTAL_ROUND, + PARAMETERS, + Env, + >(env, param, i, &state); + } + + state +} + +/// Compute one full round the Poseidon permutation +fn compute_one_full_round< + F: PrimeField, + const STATE_SIZE: usize, + const NB_FULL_ROUND: usize, + const NB_PARTIAL_ROUND: usize, + const NB_TOTAL_ROUND: usize, + PARAMETERS, + Env, +>( + env: &mut Env, + param: &PARAMETERS, + round: usize, + state: &[Env::Variable; STATE_SIZE], +) -> [Env::Variable; STATE_SIZE] +where + PARAMETERS: PoseidonParams, + Env: ColAccessCap> + + HybridCopyCap>, +{ + // We start at round 0 + // This implementation mimicks the version described in + // poseidon_block_cipher in the mina_poseidon crate. 
+ assert!( + round < NB_FULL_ROUND, + "The round index {:} is higher than the number of full rounds encoded in the type", + round + ); + + // Adding the round constants + let state: Vec = state + .iter() + .enumerate() + .map(|(i, var)| { + let offset = { + if round < NB_FULL_ROUND / 2 { + 0 + } else { + NB_PARTIAL_ROUND + } + }; + let rc = env.read_column(PoseidonColumn::RoundConstant(offset + round, i)); + var.clone() + rc + }) + .collect(); + + // Applying sbox + // For a state transition from (x, y, z) to (x', y', z'), we use the + // following columns shape: + // x^2, x^4, x^5, x', y^2, y^4, y^5, y', z^2, z^4, z^5, z') + // 0 1 2 3 4 5 6 7 8 9 10 11 + let nb_red = 4; + let state: Vec = state + .iter() + .enumerate() + .map(|(i, var)| { + // x^2 + let var_square_col = PoseidonColumn::FullRound(round, nb_red * i); + let var_square = env.hcopy(&(var.clone() * var.clone()), var_square_col); + // x^4 + let var_four_col = PoseidonColumn::FullRound(round, nb_red * i + 1); + let var_four = env.hcopy(&(var_square.clone() * var_square.clone()), var_four_col); + // x^5 + let var_five_col = PoseidonColumn::FullRound(round, nb_red * i + 2); + env.hcopy(&(var_four.clone() * var.clone()), var_five_col) + }) + .collect(); + + // Applying the linear layer + let mds = PoseidonParams::mds(param); + let state: Vec = mds + .into_iter() + .map(|m| { + state + .clone() + .into_iter() + .zip(m) + .fold(Env::constant(F::zero()), |acc, (s_i, mds_i_j)| { + Env::constant(mds_i_j) * s_i.clone() + acc.clone() + }) + }) + .collect(); + + let res_state: Vec = state + .iter() + .enumerate() + .map(|(i, res)| env.hcopy(res, PoseidonColumn::FullRound(round, nb_red * i + 3))) + .collect(); + + res_state + .try_into() + .expect("Resulting state must be of state size (={STATE_SIZE}) length") +} + +/// Compute one partial round of the Poseidon permutation +fn compute_one_partial_round< + F: PrimeField, + const STATE_SIZE: usize, + const NB_FULL_ROUND: usize, + const NB_PARTIAL_ROUND: usize, + const NB_TOTAL_ROUND: usize, + PARAMETERS, + Env, +>( + env: &mut Env, + param: &PARAMETERS, + round: usize, + state: &[Env::Variable; STATE_SIZE], +) -> [Env::Variable; STATE_SIZE] +where + F: PrimeField, + PARAMETERS: PoseidonParams, + Env: ColAccessCap> + + HybridCopyCap>, +{ + // We start at round 0 + assert!( + round < NB_PARTIAL_ROUND, + "The round index {:} is higher than the number of partial rounds encoded in the type", + round + ); + + // Adding the round constants + let mut state: Vec = state + .iter() + .enumerate() + .map(|(i, var)| { + let offset = NB_FULL_ROUND / 2; + let rc = env.read_column(PoseidonColumn::RoundConstant(offset + round, i)); + var.clone() + rc + }) + .collect(); + + // Applying the sbox + // Apply on the first element of the state + // FIXME: the specification mentions the last element. However, this version + // maps the iimplementation in [poseidon]. 
+    {
+        let var = state[0].clone();
+        let var_square_col = PoseidonColumn::PartialRound(round, 0);
+        let var_square = env.hcopy(&(var.clone() * var.clone()), var_square_col);
+        // x^4
+        let var_four_col = PoseidonColumn::PartialRound(round, 1);
+        let var_four = env.hcopy(&(var_square.clone() * var_square.clone()), var_four_col);
+        // x^5
+        let var_five_col = PoseidonColumn::PartialRound(round, 2);
+        let var_five = env.hcopy(&(var_four.clone() * var.clone()), var_five_col);
+        state[0] = var_five;
+    }
+
+    // Applying the linear layer
+    let mds = PoseidonParams::mds(param);
+    let state: Vec<Env::Variable> = mds
+        .into_iter()
+        .map(|m| {
+            state
+                .clone()
+                .into_iter()
+                .zip(m)
+                .fold(Env::constant(F::zero()), |acc, (s_i, mds_i_j)| {
+                    Env::constant(mds_i_j) * s_i.clone() + acc.clone()
+                })
+        })
+        .collect();
+
+    let res_state: Vec<Env::Variable> = state
+        .iter()
+        .enumerate()
+        .map(|(i, res)| env.hcopy(res, PoseidonColumn::PartialRound(round, 3 + i)))
+        .collect();
+
+    res_state
+        .try_into()
+        .expect("Resulting state must be of state size (={STATE_SIZE}) length")
+}
diff --git a/ivc/src/poseidon_8_56_5_3_2/mod.rs b/ivc/src/poseidon_8_56_5_3_2/mod.rs
new file mode 100644
index 0000000000..66d7f6be84
--- /dev/null
+++ b/ivc/src/poseidon_8_56_5_3_2/mod.rs
@@ -0,0 +1,185 @@
+//! Specialised circuit for Poseidon where we have maximum degree 2 constraints.
+
+pub mod columns;
+pub mod interpreter;
+
+pub mod bn254;
+
+#[cfg(test)]
+mod tests {
+    use crate::poseidon_8_56_5_3_2::{
+        bn254::{
+            static_params, Column, PlonkSpongeConstantsIVC, PoseidonBN254Parameters, MAX_DEGREE,
+            NB_CONSTRAINTS, NB_FULL_ROUND, NB_PARTIAL_ROUND, NB_TOTAL_ROUND, STATE_SIZE,
+        },
+        columns::PoseidonColumn,
+        interpreter,
+        interpreter::PoseidonParams,
+    };
+    use ark_ff::{UniformRand, Zero};
+    use kimchi_msm::{
+        circuit_design::{ColAccessCap, ConstraintBuilderEnv, WitnessBuilderEnv},
+        columns::ColumnIndexer,
+        lookups::DummyLookupTable,
+        Fp,
+    };
+    use mina_poseidon::permutation::poseidon_block_cipher;
+
+    pub const N_COL: usize = Column::N_COL;
+    pub const N_DSEL: usize = 0;
+    pub const N_FSEL: usize = NB_TOTAL_ROUND * STATE_SIZE;
+
+    type PoseidonWitnessBuilderEnv = WitnessBuilderEnv<
+        Fp,
+        Column,
+        { <Column as ColumnIndexer>::N_COL },
+        { <Column as ColumnIndexer>::N_COL },
+        N_DSEL,
+        N_FSEL,
+        DummyLookupTable,
+    >;
+
+    /// Tests that the Poseidon circuit is correctly formed (witness
+    /// generation and constraints match) and that it matches the CPU
+    /// specification of Poseidon. Fast to run, can be used for
+    /// debugging.
+    #[test]
+    pub fn test_poseidon_circuit() {
+        let mut rng = o1_utils::tests::make_test_rng(None);
+        let domain_size = 1 << 4;
+
+        let mut witness_env: PoseidonWitnessBuilderEnv = WitnessBuilderEnv::create();
+        // Write constants
+        {
+            let rc = PoseidonBN254Parameters.constants();
+            rc.iter().enumerate().for_each(|(round, rcs)| {
+                rcs.iter().enumerate().for_each(|(state_index, rc)| {
+                    let rc = vec![*rc; domain_size];
+                    witness_env.set_fixed_selector_cix(
+                        PoseidonColumn::RoundConstant(round, state_index),
+                        rc,
+                    )
+                });
+            });
+        }
+
+        // Generate random inputs at each row
+        for _row in 0..domain_size {
+            let x: Fp = Fp::rand(&mut rng);
+            let y: Fp = Fp::rand(&mut rng);
+            let z: Fp = Fp::rand(&mut rng);
+
+            interpreter::poseidon_circuit(&mut witness_env, &PoseidonBN254Parameters, [x, y, z]);
+
+            // Check internal consistency of our circuit: that our
+            // computed values match the CPU-spec implementation of
+            // Poseidon.
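+            // With nb_red = 4 intermediate columns per state element, the
+            // output of the last full round lives in columns
+            // FullRound(NB_FULL_ROUND - 1, 4 * i + 3), i.e. columns 3, 7 and
+            // 11 for a state size of 3; those are the columns read below.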
+            {
+                let exp_output: Vec<Fp> = {
+                    let mut state: Vec<Fp> = vec![x, y, z];
+                    poseidon_block_cipher::<Fp, PlonkSpongeConstantsIVC>(
+                        static_params(),
+                        &mut state,
+                    );
+                    state
+                };
+                let x_col: PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND> =
+                    PoseidonColumn::FullRound(NB_FULL_ROUND - 1, 3);
+                let y_col: PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND> =
+                    PoseidonColumn::FullRound(NB_FULL_ROUND - 1, 7);
+                let z_col: PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND> =
+                    PoseidonColumn::FullRound(NB_FULL_ROUND - 1, 11);
+                assert_eq!(witness_env.read_column(x_col), exp_output[0]);
+                assert_eq!(witness_env.read_column(y_col), exp_output[1]);
+                assert_eq!(witness_env.read_column(z_col), exp_output[2]);
+            }
+
+            witness_env.next_row();
+        }
+    }
+
+    #[test]
+    /// Checks that the Poseidon circuit can be proven and verified. Big domain.
+    pub fn heavy_test_completeness() {
+        let mut rng = o1_utils::tests::make_test_rng(None);
+        let domain_size: usize = 1 << 15;
+
+        let (relation_witness, fixed_selectors) = {
+            let mut witness_env: PoseidonWitnessBuilderEnv = WitnessBuilderEnv::create();
+
+            let mut fixed_selectors: [Vec<Fp>; N_FSEL] =
+                std::array::from_fn(|_| vec![Fp::zero(); 1]);
+            // Write constants
+            {
+                let rc = PoseidonBN254Parameters.constants();
+                rc.iter().enumerate().for_each(|(round, rcs)| {
+                    rcs.iter().enumerate().for_each(|(state_index, rc)| {
+                        witness_env.set_fixed_selector_cix(
+                            PoseidonColumn::RoundConstant(round, state_index),
+                            vec![*rc; domain_size],
+                        );
+                        fixed_selectors[round * STATE_SIZE + state_index] = vec![*rc; domain_size];
+                    });
+                });
+            }
+
+            // Generate random inputs at each row
+            for _row in 0..domain_size {
+                let x: Fp = Fp::rand(&mut rng);
+                let y: Fp = Fp::rand(&mut rng);
+                let z: Fp = Fp::rand(&mut rng);
+
+                interpreter::poseidon_circuit(
+                    &mut witness_env,
+                    &PoseidonBN254Parameters,
+                    [x, y, z],
+                );
+
+                witness_env.next_row();
+            }
+
+            (
+                witness_env.get_relation_witness(domain_size),
+                fixed_selectors,
+            )
+        };
+
+        let constraints = {
+            let mut constraint_env = ConstraintBuilderEnv::<Fp, DummyLookupTable>::create();
+            interpreter::apply_permutation::<
+                Fp,
+                3,
+                NB_FULL_ROUND,
+                NB_PARTIAL_ROUND,
+                NB_TOTAL_ROUND,
+                _,
+                _,
+            >(&mut constraint_env, &PoseidonBN254Parameters);
+            let constraints = constraint_env.get_constraints();
+
+            // We have 432 constraints in total if the state size is 3, the
+            // number of full rounds is 8, and the number of partial rounds
+            // is 56.
+            assert_eq!(
+                constraints.len(),
+                4 * STATE_SIZE * NB_FULL_ROUND + (4 + STATE_SIZE - 1) * NB_PARTIAL_ROUND
+            );
+            assert_eq!(constraints.len(), NB_CONSTRAINTS);
+
+            // Maximum degree of the constraints is 2
+            assert_eq!(
+                constraints.iter().map(|c| c.degree(1, 0)).max().unwrap(),
+                MAX_DEGREE
+            );
+
+            constraints
+        };
+
+        kimchi_msm::test::test_completeness_generic_no_lookups::<N_COL, N_COL, N_DSEL, N_FSEL, _>(
+            constraints,
+            Box::new(fixed_selectors),
+            relation_witness,
+            domain_size,
+            &mut rng,
+        );
+    }
+}
diff --git a/ivc/src/poseidon_params_55_0_7_3.rs b/ivc/src/poseidon_params_55_0_7_3.rs
new file mode 100644
index 0000000000..397fbd0765
--- /dev/null
+++ b/ivc/src/poseidon_params_55_0_7_3.rs
@@ -0,0 +1,893 @@
+//! Poseidon parameters that can be used by [crate::poseidon_55_0_7_3_2] and
+//! [crate::poseidon_55_0_7_3_7] over the scalar field of BN254
+
+use ark_bn254::Fr as Fp;
+use mina_poseidon::{constants::SpongeConstants, poseidon::ArithmeticSpongeParams};
+use once_cell::sync::Lazy;
+
+// FIXME: move into mina_poseidon when this code is ready to go into production.
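+// Illustrative usage, mirroring the regression test at the bottom of this
+// file (`Poseidon` being `mina_poseidon::poseidon::ArithmeticSponge`):
+//
+//     let mut sponge = Poseidon::<Fp, PlonkSpongeConstantsIVC>::new(static_params());
+//     sponge.absorb(&[Fp::from(1u64), Fp::from(1u64), Fp::from(1u64)]);
+//     let digest = sponge.squeeze();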
+/* Generated by the following command: +```shell +sage params.sage --rounds 55 rust 3 bn254 > bn254_poseidon.rs +``` +with the commit 85cbccd4266cdb567fd47ffc54c3fa52543c2c51 +where the curves have been changed in the script params.sage to use bn254 +*/ + +use std::str::FromStr; + +fn params() -> ArithmeticSpongeParams { + ArithmeticSpongeParams { + mds: vec![ + vec![ + Fp::from_str( + "18254759465961548724473677898900554279129333700345398792534035897544813417380", + ) + .unwrap(), + Fp::from_str( + "9701727079493493822153587075727161870060130406601968554920630310837077525782", + ) + .unwrap(), + Fp::from_str( + "6448075384674872988709285295252735647564764677180219427814076333498731944341", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "10582678689390677133753881397356982896844997242148790399395649637165847896125", + ) + .unwrap(), + Fp::from_str( + "14953878658754362056567950679445798963011901944940518481654524320703963311366", + ) + .unwrap(), + Fp::from_str( + "2769559556749664962636128855242028623701912278738796979170458933982658616355", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "11085432328429990699367643357327464988188253029150852680916595311968413759302", + ) + .unwrap(), + Fp::from_str( + "6616328759095010539674576883698535512612418243442338373652055842385579197823", + ) + .unwrap(), + Fp::from_str( + "1788173826467567565151305865636675972489937758217944451266517527454717587391", + ) + .unwrap(), + ], + ], + round_constants: vec![ + vec![ + Fp::from_str( + "464317818058904040455944620932325370057318211982758073789794737262512069860", + ) + .unwrap(), + Fp::from_str( + "3462163333362550903390325785926437039762839525013511939547720766741499409856", + ) + .unwrap(), + Fp::from_str( + "12308193519164615772431903590941145851045953550307364845912877146027882177970", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "2714591816949287658718284157908523280274340884728966119095559324137312851163", + ) + .unwrap(), + Fp::from_str( + "16384206429233192859678459425673999470857159372638200975562270773516251379588", + ) + .unwrap(), + Fp::from_str( + "19416611223573031732302519381697999456593181436167069017836818060955725864200", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "17603455420764872031698321921973319158461544396609704072026904164018868607752", + ) + .unwrap(), + Fp::from_str( + "11238098416885388471337306072943475212729724613485264776842071924751017567771", + ) + .unwrap(), + Fp::from_str( + "20115085162632036284496420664439046277270577640351232755593699800626568913281", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "13522402812268713007260660858020591835416325370361256561462576999381232564525", + ) + .unwrap(), + Fp::from_str( + "17993297334719348473740719243191272244929728431491444848946396448123051191677", + ) + .unwrap(), + Fp::from_str( + "14266894438972358618281326760350513234022317540856366683582613440301108328203", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "15544565783214950187472074309645319815959207946976778588709413987755227716658", + ) + .unwrap(), + Fp::from_str( + "11459281677950411823472431204851572448489387232119046190993717723024935465646", + ) + .unwrap(), + Fp::from_str( + "3175427723782491063823141004980187933184665826684252846205862234135139398026", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "18104087100840279413272173524166428384373648597907924300377376914338365801413", + ) + .unwrap(), + Fp::from_str( + "11059348970814019325544810642103429754398805409212413029570905399054361253560", + ) + .unwrap(), + Fp::from_str( + 
"9723694385274179134284666123116305587583263720877786033125569799811141709147", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "21334124835555176052046135550571447667154056413650326169040093998967534013839", + ) + .unwrap(), + Fp::from_str( + "13186187149295477383503606350078912331255804347098276303051944780008615958248", + ) + .unwrap(), + Fp::from_str( + "1470033952086888296713916393108192307114009370637357884278948695504654950262", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "7525745512005637899980788850937553026084357875660084470654426987818115307169", + ) + .unwrap(), + Fp::from_str( + "11196989505533608046175719825672890499899138989727925138952469167443760521626", + ) + .unwrap(), + Fp::from_str( + "9101167944054900479650446113089291293767218643046467434823338481469624077969", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "746152031855445071661969786731764979416640168084300832548286844362915267733", + ) + .unwrap(), + Fp::from_str( + "14500783258490511332342759610086566441467851535274198353685874901881586203920", + ) + .unwrap(), + Fp::from_str( + "12586553744079534087671160799108871165444925427796015728451282191808004632029", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "2658978799169405016067965116007625196452718808329137140199842584938385590196", + ) + .unwrap(), + Fp::from_str( + "14945177760821130319202365009462588707535813273905143480365649428853002471274", + ) + .unwrap(), + Fp::from_str( + "17059544527948691563140405641038225702602878092594471881549757285581285907208", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "20191280011693017888798491354094094243127547574655352220297687980090230216", + ) + .unwrap(), + Fp::from_str( + "18502499034030581022389564056073081062360187604161458073890564428974602584477", + ) + .unwrap(), + Fp::from_str( + "8264062626605956514525302888463961111763485965406820048999276299668892074923", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "10792209418985110179364134772680195398048938179898831723416765403620730620462", + ) + .unwrap(), + Fp::from_str( + "11664556276797417846866489492418594702491171278433997565894976597592370230667", + ) + .unwrap(), + Fp::from_str( + "12270771377697756787250169216048909236321853168489305289596185098410215090484", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "8811136672424023763408452444568991590125177469609096631745680177504609711487", + ) + .unwrap(), + Fp::from_str( + "20237676190929489436588870470052869226876887618934149899920104260655948094739", + ) + .unwrap(), + Fp::from_str( + "20062974513054622329783259390735561568588514582980747668306658460292592667248", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "798776603612212418444623010074034446472782213575824096964910597458981693714", + ) + .unwrap(), + Fp::from_str( + "8758094627630360838980606896109464566485363275799929171438278126952202465833", + ) + .unwrap(), + Fp::from_str( + "707487852423957926868336579455035906997963205205141270331975346408432067044", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "202779927146339323474786313977213080711135813758190854284590663556011706836", + ) + .unwrap(), + Fp::from_str( + "20391231297043803259194308176992776361609861855662006565804325535077866583400", + ) + .unwrap(), + Fp::from_str( + "13103971641300442794695006804822894722919592148173662271765398624597649213516", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "11887088560178314335635682366268022017575768082106271832073442174179560940002", + ) + .unwrap(), + Fp::from_str( + "9497730497365106303373876625863922615426121795921218372965815099779016166015", 
+ ) + .unwrap(), + Fp::from_str( + "13255889376415173861282249540978269441170571583954046228817145377643690062643", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "21797601683155849424606801821462621112628210830522816165936272187765788522806", + ) + .unwrap(), + Fp::from_str( + "13791904812004141351280616602806766268926714029237345281809436730651246516674", + ) + .unwrap(), + Fp::from_str( + "3541295787421703076701812027030126532836938797821201403520460081703194798007", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "6772295061598565998757352284489337907928739308702535748450215960170676319556", + ) + .unwrap(), + Fp::from_str( + "19608658591798530616680061208581311706541541564609366646372793300896320095218", + ) + .unwrap(), + Fp::from_str( + "15108470522637839917783948187626705626608074131603801620229663133367773769409", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "6982167525494887125745884237873629770460477859443352710958203128353080003941", + ) + .unwrap(), + Fp::from_str( + "7877853416249147253292651939679543007752875545630493492558173940354891685347", + ) + .unwrap(), + Fp::from_str( + "16723886233609641431662771575929234875591162154288587674425439500724219825212", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "2671311012893846775573899497331547132295274951055892772678260375558714942406", + ) + .unwrap(), + Fp::from_str( + "10473922685305341900412641696196941263307819634043185788067854748950602802383", + ) + .unwrap(), + Fp::from_str( + "5111630741260296655990184310611375231177307348451724203930709456972805973880", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "4992057345914438551815056581049158703461724962129889522457212594798779955393", + ) + .unwrap(), + Fp::from_str( + "21183229986193453383825269196190741932488867108677665668392098260004852050858", + ) + .unwrap(), + Fp::from_str( + "17083812614758825262603118558173710997332640429545889370504202808090339572857", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "18125355879021821849041167225185950892076210587475775745349215815478764407922", + ) + .unwrap(), + Fp::from_str( + "8733966472516018163838600897336665253176449185681528084604422362542186741640", + ) + .unwrap(), + Fp::from_str( + "12872170036700590319474830060138441606481805188061693688276584984859227049202", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "4477722790415428221489180851061526086953856916136558152290587660464310195129", + ) + .unwrap(), + Fp::from_str( + "15013641992838646488575621859957900283128972996570425739867595472354843675113", + ) + .unwrap(), + Fp::from_str( + "11110048915932820585149019832900719558018746589803602931460479672600239301992", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "6386524788949068680243545386229548182250656575626566757865736117588702942149", + ) + .unwrap(), + Fp::from_str( + "10817607938125396761228235991828983933462215897289776505148575850147631400593", + ) + .unwrap(), + Fp::from_str( + "17865086662065068946652014543896618689007438069494510844490187921781045033543", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "21842035812228593226084835552836653742808044145939024628842446854237127923353", + ) + .unwrap(), + Fp::from_str( + "2471068742343144756503096572812849495213874183856231809598640313621938630669", + ) + .unwrap(), + Fp::from_str( + "6547073217674180994961882255834734219166276822047545566909303528752619346965", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "9562418092642273301602927692306423921540755207933680809887955184175060412865", + ) + .unwrap(), + Fp::from_str( + 
"5728419725627720701717078886669712201564103341166911632619636881997293379662", + ) + .unwrap(), + Fp::from_str( + "9577751993330236726740958969117670757888267786661142736615315540557935073808", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "18485877910552384007717083804696131794323032392196584185820990234315883649039", + ) + .unwrap(), + Fp::from_str( + "1769516528997465058766917692166940530540162775010485894812511264651123420964", + ) + .unwrap(), + Fp::from_str( + "386757200774461941643538807053842280717539800465562760112828539321993775241", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "3316509296104361799525651631938052468498330218596415774527777363167138087617", + ) + .unwrap(), + Fp::from_str( + "13729336006888298565425261097968915916390598196131379830434588594171666569512", + ) + .unwrap(), + Fp::from_str( + "16148239410408669906403652096870758949643310812745307334909798696471189947927", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "21823520381432477314141727574561655219860240640515827154913290917309686356497", + ) + .unwrap(), + Fp::from_str( + "4578651555280644323637866316527376720986123056791528481862274310799615790675", + ) + .unwrap(), + Fp::from_str( + "5091226074924237835610573037013208388130454756337337471122368918009398753447", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "6515120964938381342071855012952056094433032260551298595194434149594056652176", + ) + .unwrap(), + Fp::from_str( + "2340842607486511800425610472421888578518367747461866490788654350550767741883", + ) + .unwrap(), + Fp::from_str( + "14503247487728138668651412383800848326991400373338038830652551371427345571441", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "3712612816523606936094514234541515605991691986951168127960369409943944902903", + ) + .unwrap(), + Fp::from_str( + "1152766033584194531070871869834974085415579739949782153652057784926343256899", + ) + .unwrap(), + Fp::from_str( + "8765049913334393206963752695629450392338936491569237833753514286462581790561", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "1601627912182619663670418257843886466719103377571015176803695116602531025179", + ) + .unwrap(), + Fp::from_str( + "18098024103033142368624034583201165961009040474643655548434700066455514231331", + ) + .unwrap(), + Fp::from_str( + "11268744383531850444109289654319779756413360261210505934077738062212630536901", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "13295211045409383825433911169815486079945004241473963274225158733643991813121", + ) + .unwrap(), + Fp::from_str( + "10102619007330980171750476017504785290001093245315837544464105419145430593293", + ) + .unwrap(), + Fp::from_str( + "209066254001010781273163283043011468251658035504800687498237342325693440716", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "10569794888394925789900696666438098928482919555802848731145904974932210713381", + ) + .unwrap(), + Fp::from_str( + "3479060041588228962773080916545344859794556889989789923050334615578110393377", + ) + .unwrap(), + Fp::from_str( + "8177920417424677597775227574376448009357347539898615178798501628859714007479", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "12067616893727942346831314147294036302857666398663726324584342747296065928710", + ) + .unwrap(), + Fp::from_str( + "16971111464439978238379247781540292780056241741803063434856766170743124461469", + ) + .unwrap(), + Fp::from_str( + "19754437344667989781187357721513315145328407420825515279265102170857620051268", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + 
"3975594217367336645402283024719769072885792075344511606692741601510641156311", + ) + .unwrap(), + Fp::from_str( + "19040197452680989990741508256669934716515775259429067120683774529031109727966", + ) + .unwrap(), + Fp::from_str( + "5634578132978895965594894841292458192813357921721236220653749372316851034659", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "18016203270230065093894089716032049117997200515456531317821733896468154005026", + ) + .unwrap(), + Fp::from_str( + "2432210503610876869795949218217418458948353953591635398452243060116570483824", + ) + .unwrap(), + Fp::from_str( + "13732529162121190544870742850789344400294092206745459683809641438775485491009", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "14451766174292343678890357279454517773013585124369849119851032684103228653072", + ) + .unwrap(), + Fp::from_str( + "12181367175079876176988175575030826371836485595254655739186027469233503901695", + ) + .unwrap(), + Fp::from_str( + "10991645916339084020684503132290495819766466380987797648576877855177656702574", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "21313503925572929158381750262537945058966279485804307864340280517601895617812", + ) + .unwrap(), + Fp::from_str( + "1550527311092834162328201290652099251984068161341884726521396319041776382263", + ) + .unwrap(), + Fp::from_str( + "12003435213367105899404267897476505638234793525266930670911913539817221417924", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "8914436165324002549401721460186086215484314929358430322367050538121529812441", + ) + .unwrap(), + Fp::from_str( + "5912060284310139977854439979850774712081952153779342442587915555451388976602", + ) + .unwrap(), + Fp::from_str( + "18562577306045458592699989513868153553084370626914798022460227144415813035912", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "3121108818032693608292705100552135299104657945137655223341160358361023747970", + ) + .unwrap(), + Fp::from_str( + "17009269623907524632744383402475624293106226516559465345873667100464648970276", + ) + .unwrap(), + Fp::from_str( + "16002511017673785381646235433633055570232704209925362749110958022614294680393", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "9391205603129777266623979897548135033992809302803677537718465809304634939535", + ) + .unwrap(), + Fp::from_str( + "18388689572534660670130774712714127815208938853203887656653830866982720723174", + ) + .unwrap(), + Fp::from_str( + "5541548341946936547043936820389981449139467811019058131623682269330484215783", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "7190170322642292762509602175876149041297305605840922902837779590457938918669", + ) + .unwrap(), + Fp::from_str( + "2027727849652978298961694136962688734430323753419089111303416976388312626148", + ) + .unwrap(), + Fp::from_str( + "402427768909234974448400451691124558802645007494232545425431267044631491274", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "12797987239774923416407369376953974586493255155241527032821517134381740228905", + ) + .unwrap(), + Fp::from_str( + "13333249780205062899642181453873590081512016926206507817605750877019442371089", + ) + .unwrap(), + Fp::from_str( + "18589810567518770657334682134081256712119747946842941531219734231027968237759", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "14776177603689091833670811465610858113310593319381602563122002687195119058247", + ) + .unwrap(), + Fp::from_str( + "8091773219661648644611506313794854881311037176009459530767471191252204309868", + ) + .unwrap(), + Fp::from_str( + "11572415535524327542083010017274275840592842320730873456621979565559067913733", + ) + 
.unwrap(), + ], + vec![ + Fp::from_str( + "20923164127054504021838810951918321616973675807578349217848643497149066166007", + ) + .unwrap(), + Fp::from_str( + "13778673838153545444828783204107784921806288587132367999597554094415915790790", + ) + .unwrap(), + Fp::from_str( + "15824998809651622678290890820327874961632638426240138370022501121781103446259", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "4090916114781784760754024121473951637398586212781078739533487486471895169595", + ) + .unwrap(), + Fp::from_str( + "960857870893264392654072242483170132703826969766439993624797963496069781505", + ) + .unwrap(), + Fp::from_str( + "20907571141369581631669592493277131511110363500792783458107745525298320150137", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "12109635595469802832249294970134397092397779126665661972278751053960740447966", + ) + .unwrap(), + Fp::from_str( + "4824844368626589074800225473359998312724599509010698030845972726192270761413", + ) + .unwrap(), + Fp::from_str( + "1279633418489206280063439936145640252901253476940505080338091886704297503137", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "13710991567078829309950552225511825775559273430088991705888338422645023107567", + ) + .unwrap(), + Fp::from_str( + "21744331847293235251226457402120455816051177980544192245842389157344912734176", + ) + .unwrap(), + Fp::from_str( + "13750456500419553884447154012886160987332428475286924998031161614878499276461", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "7897870706671989949903359523135485593775134275719857967175422051001928288650", + ) + .unwrap(), + Fp::from_str( + "11530512859482871748707639867674853723860540856854102438746375696832108722419", + ) + .unwrap(), + Fp::from_str( + "15102320430753094455411262873908179792479117498565463936483853362951093377404", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "20684012922073862986880613075424856809319565915848039596293835293542621074230", + ) + .unwrap(), + Fp::from_str( + "11143754623250183108051550852169137897571161411757151918630704123953225725876", + ) + .unwrap(), + Fp::from_str( + "21515833961121090836372549249533024521086358011705741464637998484545867651639", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "14927367861301244473471972423309886137778884679987552642033604382283866452496", + ) + .unwrap(), + Fp::from_str( + "2234445563530370737450306710054836240206601416475326119659514396778117223266", + ) + .unwrap(), + Fp::from_str( + "7320243644229576716977260850447281726825012665363843781486677636270396893993", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "3648709991731171089283528046834957511485556517829186374119154464357293936649", + ) + .unwrap(), + Fp::from_str( + "8431771340355225974834111903440269408389608143822442356390982704086190421477", + ) + .unwrap(), + Fp::from_str( + "850707530944871742730572985296454989068982533587676445405227892149682249460", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "12770330711881918473357359607046064327068283622530828696564875920105463050911", + ) + .unwrap(), + Fp::from_str( + "18309438953837613260356082733742847603856465956586832431344593061564231402706", + ) + .unwrap(), + Fp::from_str( + "7558769148417330857693721923631008074203529622851527561622447590521963662006", + ) + .unwrap(), + ], + vec![ + Fp::from_str( + "3817411164684956972519434296876841710387488302709712662164166985673595410943", + ) + .unwrap(), + Fp::from_str( + "14529667139169881031943219432572774573785882712681678612418006263967413092263", + ) + .unwrap(), + Fp::from_str( + 
"5591068111468958432601398487692727664831870363727501013569466045034798523918", + ) + .unwrap(), + ], + ], + } +} + +pub fn static_params() -> &'static ArithmeticSpongeParams { + static PARAMS: Lazy> = Lazy::new(params); + &PARAMS +} + +/// Constants used by the IVC circuit used by the folding scheme +/// This is meant to be only used for the IVC circuit, with the BN254 curve +/// It has not been tested with/for other curves. +// This must be moved into mina_poseidon later. +#[derive(Clone)] +pub struct PlonkSpongeConstantsIVC {} + +impl SpongeConstants for PlonkSpongeConstantsIVC { + const SPONGE_CAPACITY: usize = 1; + const SPONGE_WIDTH: usize = 3; + const SPONGE_RATE: usize = 2; + const PERM_ROUNDS_FULL: usize = 55; + const PERM_ROUNDS_PARTIAL: usize = 0; + const PERM_HALF_ROUNDS_FULL: usize = 0; + const PERM_SBOX: u32 = 7; + const PERM_FULL_MDS: bool = true; + const PERM_INITIAL_ARK: bool = false; +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::{static_params, PlonkSpongeConstantsIVC}; + use ark_bn254::Fr as Fp; + use kimchi::o1_utils::FieldHelpers; + use mina_poseidon::poseidon::{ArithmeticSponge as Poseidon, Sponge as _}; + + // Regression tests. Test vectors have been generated using the same + // code, commit a1ad3f17b12baff62bc7fab6c0c4bacb704518d5 + // This does not mean that the implementation is correct. + #[test] + fn test_poseidon() { + let mut hash = Poseidon::::new(static_params()); + let input: [Fp; 3] = [ + Fp::from_str("1").unwrap(), + Fp::from_str("1").unwrap(), + Fp::from_str("1").unwrap(), + ]; + let exp_output_str = [ + "62838cdb9e91f1e85b450c1c45ccd220da17896083479d4879d520dc51ca390c", + "1dcb5390c0c61fc6c461043e5b1ae385faedb3d5a87c95f488bb851faca61f0a", + "fc0388532e2e38ec1e1402fbcf016a7954cf8fb6b6ba5739c7666aea7af42c29", + ]; + let exp_output = exp_output_str.map(|x| Fp::from_hex(x).unwrap()); + hash.absorb(&input); + assert_eq!(hash.state, exp_output); + } +} diff --git a/ivc/src/prover.rs b/ivc/src/prover.rs new file mode 100644 index 0000000000..8e318ad7a7 --- /dev/null +++ b/ivc/src/prover.rs @@ -0,0 +1,500 @@ +#![allow(clippy::type_complexity)] +#![allow(clippy::boxed_local)] + +use crate::{ + expr_eval::SimpleEvalEnv, + plonkish_lang::{PlonkishChallenge, PlonkishInstance, PlonkishWitness}, +}; +use ark_ff::{Field, One, Zero}; +use ark_poly::{ + univariate::DensePolynomial, EvaluationDomain, Evaluations, Polynomial, + Radix2EvaluationDomain as R2D, +}; +use folding::{ + eval_leaf::EvalLeaf, + instance_witness::{ExtendedWitness, RelaxedInstance, RelaxedWitness}, + Alphas, FoldingCompatibleExpr, FoldingConfig, +}; +use kimchi::{ + self, + circuits::{ + domains::EvaluationDomains, + expr::{ColumnEvaluations, ExprError}, + }, + curve::KimchiCurve, + groupmap::GroupMap, + plonk_sponge::FrSponge, + proof::PointEvaluations, +}; +use kimchi_msm::{columns::Column as GenericColumn, witness::Witness}; +use mina_poseidon::{sponge::ScalarChallenge, FqSponge}; +use o1_utils::ExtendedDensePolynomial; +use poly_commitment::{ + commitment::{absorb_commitment, CommitmentCurve, PolyComm}, + kzg::{KZGProof, PairingSRS}, + utils::DensePolynomialOrEvaluations, + OpenProof, SRS, +}; +use rand::{CryptoRng, RngCore}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use std::collections::BTreeMap; +use thiserror::Error; + +/// Errors that can arise when creating a proof +#[derive(Error, Debug, Clone)] +pub enum ProverError { + #[error("the proof could not be constructed: {0}")] + Generic(&'static str), + + #[error("the provided (witness) constraints 
+    ConstraintNotSatisfied(String),
+
+    #[error("the provided (witness) constraint has degree {0} > allowed {1}; expr: {2}")]
+    ConstraintDegreeTooHigh(u64, u64, String),
+}
+
+pub type Pairing = kimchi_msm::BN254;
+/// The curve we commit into
+pub type G = kimchi_msm::BN254G1Affine;
+/// Scalar field of the curve.
+pub type Fp = kimchi_msm::Fp;
+/// The base field of the curve
+/// Used to encode the polynomial commitments
+pub type Fq = ark_bn254::Fq;
+
+#[derive(Debug, Clone)]
+// TODO: should public input and fixed selector evaluations be here?
+pub struct ProofEvaluations<
+    const N_WIT: usize,
+    const N_REL: usize,
+    const N_DSEL: usize,
+    const N_FSEL: usize,
+    F,
+> {
+    /// Witness evaluations, including public inputs
+    pub witness_evals: Witness<N_WIT, PointEvaluations<F>>,
+    /// Evaluations of fixed selectors.
+    pub fixed_selectors_evals: Box<[PointEvaluations<F>; N_FSEL]>,
+    pub error_vec: PointEvaluations<F>,
+    /// Evaluation of (1 - ζ^n) (t_0(X) + ζ^n t_1(X) + ...) at ζω, i.e. ft(ζω).
+    pub ft_eval1: F,
+}
+
+/// The trait ColumnEvaluations is used by the verifier.
+/// It will return the evaluation of the corresponding column at the
+/// evaluation points coined by the verifier during the protocol.
+impl<
+        const N_WIT: usize,
+        const N_REL: usize,
+        const N_DSEL: usize,
+        const N_FSEL: usize,
+        F: Clone,
+    > ColumnEvaluations<F> for ProofEvaluations<N_WIT, N_REL, N_DSEL, N_FSEL, F>
+{
+    type Column = kimchi_msm::columns::Column;
+
+    fn evaluate(&self, col: Self::Column) -> Result<PointEvaluations<F>, ExprError<Self::Column>> {
+        // TODO: substitute when non-literal generic constants are available
+        assert!(N_WIT == N_REL + N_DSEL);
+        let res = match col {
+            Self::Column::Relation(i) => {
+                assert!(i < N_REL, "Index out of bounds");
+                self.witness_evals[i].clone()
+            }
+            Self::Column::DynamicSelector(i) => {
+                assert!(i < N_DSEL, "Index out of bounds");
+                self.witness_evals[N_REL + i].clone()
+            }
+            Self::Column::FixedSelector(i) => {
+                assert!(i < N_FSEL, "Index out of bounds");
+                self.fixed_selectors_evals[i].clone()
+            }
+            _ => panic!("lookup columns not supported"),
+        };
+        Ok(res)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct ProofCommitments<const N_WIT: usize, G: KimchiCurve> {
+    /// Commitments to the N columns of the circuits, also called the 'witnesses'.
+    /// If some columns are considered as public inputs, they are counted in the witness.
+    pub witness_comms: Witness<N_WIT, PolyComm<G>>,
+    /// Commitment to the quotient polynomial.
+    /// The value contains the chunked polynomials.
+    pub t_comm: PolyComm<G>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Proof<
+    const N_WIT: usize,
+    const N_REL: usize,
+    const N_DSEL: usize,
+    const N_FSEL: usize,
+    G: KimchiCurve,
+    OpeningProof: OpenProof<G>,
+> {
+    pub proof_comms: ProofCommitments<N_WIT, G>,
+    pub proof_evals: ProofEvaluations<N_WIT, N_REL, N_DSEL, N_FSEL, G::ScalarField>,
+    pub opening_proof: OpeningProof,
+
+    // Unsure whether this is necessary.
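+    // These values are carried over from the folded instance so that the
+    // verifier can re-evaluate the combined folding expression at the
+    // evaluation points: `verify` reads `proof.alphas`, `proof.challenges`
+    // and `proof.u` to compute `ft_eval0`.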
+    pub alphas: Alphas<G::ScalarField>,
+    pub challenges: [G::ScalarField; 3],
+    pub u: G::ScalarField,
+}
+
+pub fn prove<
+    EFqSponge: Clone + FqSponge<Fq, G, Fp>,
+    EFrSponge: FrSponge<Fp>,
+    FC: FoldingConfig<Column = GenericColumn, Curve = G, Challenge = PlonkishChallenge>,
+    RNG,
+    const N_WIT: usize,
+    const N_WIT_QUAD: usize, // witness columns + quad columns
+    const N_REL: usize,
+    const N_DSEL: usize,
+    const N_FSEL: usize,
+    const N_ALPHAS: usize,
+>(
+    domain: EvaluationDomains<Fp>,
+    srs: &PairingSRS<Pairing>,
+    combined_expr: &FoldingCompatibleExpr<FC>,
+    folded_instance: RelaxedInstance<G, PlonkishInstance<G, N_WIT, 3, N_ALPHAS>>,
+    folded_witness: RelaxedWitness<G, PlonkishWitness<N_WIT, N_FSEL, Fp>>,
+    rng: &mut RNG,
+) -> Result<Proof<N_WIT_QUAD, N_REL, N_DSEL, N_FSEL, G, KZGProof<Pairing>>, ProverError>
+where
+    RNG: RngCore + CryptoRng,
+{
+    assert_eq!(
+        folded_witness.extended_witness.extended.values().len(),
+        N_WIT_QUAD - N_WIT
+    );
+    assert!(N_WIT == N_REL + N_DSEL);
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Setting up the protocol
+    ////////////////////////////////////////////////////////////////////////////
+
+    let group_map = <G as CommitmentCurve>::Map::setup();
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Round 1: Creating and absorbing column commitments
+    ////////////////////////////////////////////////////////////////////////////
+
+    let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params());
+
+    let fixed_selectors_evals_d1: Box<[Evaluations<Fp, R2D<Fp>>; N_FSEL]> =
+        folded_witness.extended_witness.witness.fixed_selectors.cols;
+
+    let fixed_selectors_polys: Box<[DensePolynomial<Fp>; N_FSEL]> =
+        o1_utils::array::vec_to_boxed_array(
+            fixed_selectors_evals_d1
+                .clone()
+                .into_par_iter()
+                .map(|evals| evals.interpolate())
+                .collect(),
+        );
+
+    let fixed_selectors_comms: Box<[PolyComm<G>; N_FSEL]> = {
+        let comm = |poly: &DensePolynomial<Fp>| srs.commit_non_hiding(poly, 1);
+        o1_utils::array::vec_to_boxed_array(
+            fixed_selectors_polys
+                .as_ref()
+                .into_par_iter()
+                .map(comm)
+                .collect(),
+        )
+    };
+
+    // Do not use parallelism
+    (fixed_selectors_comms)
+        .into_iter()
+        .for_each(|comm| absorb_commitment(&mut fq_sponge, &comm));
+
+    let witness_main: Witness<N_WIT, Evaluations<Fp, R2D<Fp>>> =
+        folded_witness.extended_witness.witness.witness;
+    let witness_ext: BTreeMap<usize, Evaluations<Fp, R2D<Fp>>> =
+        folded_witness.extended_witness.extended;
+
+    // Join main + ext
+    let witness_evals_d1: Witness<N_WIT_QUAD, Evaluations<Fp, R2D<Fp>>> = {
+        let mut acc = witness_main.cols.to_vec();
+        acc.extend(witness_ext.values().cloned());
+        acc.try_into().unwrap()
+    };
+
+    let witness_polys: Witness<N_WIT_QUAD, DensePolynomial<Fp>> = {
+        witness_evals_d1
+            .into_par_iter()
+            .map(|e| e.interpolate())
+            .collect::<Vec<_>>()
+            .try_into()
+            .unwrap()
+    };
+
+    let witness_comms: Witness<N_WIT_QUAD, PolyComm<G>> = {
+        let blinders = PolyComm {
+            chunks: vec![Fp::one()],
+        };
+        let comm = {
+            |poly: &DensePolynomial<Fp>| {
+                // In case the column polynomial is all zeroes, we want to mask the commitment
+                let comm = srs.commit_custom(poly, 1, &blinders).unwrap();
+                comm.commitment
+            }
+        };
+        (&witness_polys)
+            .into_par_iter()
+            .map(comm)
+            .collect::<Witness<N_WIT_QUAD, PolyComm<G>>>()
+    };
+
+    // Do not use parallelism
+    (&witness_comms)
+        .into_iter()
+        .for_each(|comm| absorb_commitment(&mut fq_sponge, comm));
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Round 2: Creating and committing to the quotient polynomial
+    ////////////////////////////////////////////////////////////////////////////
+
+    let (_, endo_r) = G::endos();
+
+    let quotient_poly = {
+        let evaluation_domain = domain.d4;
+
+        let enlarge_to_domain_generic =
+            |evaluations: &Evaluations<Fp, R2D<Fp>>, new_domain: R2D<Fp>| {
+                assert!(evaluations.domain() == domain.d1);
+                evaluations
+                    .interpolate_by_ref()
+                    .evaluate_over_domain_by_ref(new_domain)
+            };
+
+        let enlarge_to_domain = |evaluations: &Evaluations<Fp, R2D<Fp>>| {
+            enlarge_to_domain_generic(evaluations, evaluation_domain)
+        };
+
+        let simple_eval_env: SimpleEvalEnv<G, N_WIT, N_FSEL> = {
+            let ext_witness = ExtendedWitness {
+                witness: PlonkishWitness {
+                    witness: (&witness_main)
+                        .into_par_iter()
+                        .map(enlarge_to_domain)
+                        .collect(),
+                    fixed_selectors: (&fixed_selectors_evals_d1.to_vec())
+                        .into_par_iter()
+                        .map(enlarge_to_domain)
+                        .collect(),
+                    phantom: std::marker::PhantomData,
+                },
+                extended: (&witness_ext)
+                    .into_par_iter()
+                    .map(|(ix, evals)| (*ix, enlarge_to_domain(evals)))
+                    .collect(),
+            };
+
+            SimpleEvalEnv {
+                ext_witness,
+                alphas: folded_instance.extended_instance.instance.alphas.clone(),
+                challenges: folded_instance.extended_instance.instance.challenges,
+                error_vec: enlarge_to_domain(&folded_witness.error_vec),
+                u: folded_instance.u,
+            }
+        };
+
+        {
+            let eval_leaf = simple_eval_env.eval_naive_fcompat(combined_expr);
+
+            let evaluations_big = match eval_leaf {
+                EvalLeaf::Result(evaluations) => evaluations,
+                EvalLeaf::Col(evaluations) => evaluations.to_vec().clone(),
+                _ => panic!("eval_leaf is not Result"),
+            };
+
+            let interpolated =
+                Evaluations::from_vec_and_domain(evaluations_big, evaluation_domain).interpolate();
+            if interpolated.is_zero() {
+                println!("Interpolated expression is zero");
+            }
+
+            let (quotient, remainder) = interpolated
+                .divide_by_vanishing_poly(domain.d1)
+                .unwrap_or_else(|| panic!("ERROR: Cannot divide by vanishing polynomial"));
+            if !remainder.is_zero() {
+                panic!("ERROR: Remainder is not zero for joint folding expression");
+            }
+
+            quotient
+        }
+    };
+
+    // We interpolate over d4, so the number of chunks should be 3.
+    let num_chunks: usize = 3;
+
+    //~ 1. Commit to the quotient polynomial $t$.
+    let t_comm = srs.commit_non_hiding(&quotient_poly, num_chunks);
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Round 3: Evaluations at ζ and ζω
+    ////////////////////////////////////////////////////////////////////////////
+
+    //~ 1. Absorb the commitment of the quotient polynomial with the Fq-Sponge.
+    absorb_commitment(&mut fq_sponge, &t_comm);
+
+    //~ 1. Sample ζ with the Fq-Sponge.
+    let zeta_chal = ScalarChallenge(fq_sponge.challenge());
+
+    let zeta = zeta_chal.to_field(endo_r);
+
+    let omega = domain.d1.group_gen;
+    // We will also evaluate at ζω, as lookups require going to the next row.
+    let zeta_omega = zeta * omega;
+
+    let eval_at_challenge = |p: &DensePolynomial<_>| PointEvaluations {
+        zeta: p.evaluate(&zeta),
+        zeta_omega: p.evaluate(&zeta_omega),
+    };
+
+    // Evaluate the polynomials at ζ and ζω -- Columns
+    let witness_point_evals: Witness<N_WIT_QUAD, PointEvaluations<Fp>> = {
+        (&witness_polys)
+            .into_par_iter()
+            .map(eval_at_challenge)
+            .collect::<Witness<N_WIT_QUAD, PointEvaluations<Fp>>>()
+    };
+
+    let fixed_selectors_point_evals: Box<[PointEvaluations<_>; N_FSEL]> = {
+        o1_utils::array::vec_to_boxed_array(
+            fixed_selectors_polys
+                .as_ref()
+                .into_par_iter()
+                .map(eval_at_challenge)
+                .collect::<_>(),
+        )
+    };
+
+    let error_vec_point_eval = eval_at_challenge(&folded_witness.error_vec.interpolate());
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Round 4: Opening proof w/o linearization polynomial
+    ////////////////////////////////////////////////////////////////////////////
+
+    // Fiat-Shamir: absorbing evaluations
+    let fq_sponge_before_evaluations = fq_sponge.clone();
+    let mut fr_sponge = EFrSponge::new(G::sponge_params());
+    fr_sponge.absorb(&fq_sponge.digest());
+
+    for PointEvaluations { zeta, zeta_omega } in (&witness_point_evals).into_iter() {
+        fr_sponge.absorb(zeta);
+        fr_sponge.absorb(zeta_omega);
+    }
+
+    for PointEvaluations { zeta, zeta_omega } in fixed_selectors_point_evals.as_ref().iter() {
+        fr_sponge.absorb(zeta);
+        fr_sponge.absorb(zeta_omega);
+    }
+
+    // Compute ft(X) = (1 - ζ^n) (t_0(X) + ζ^n t_1(X) + ... + ζ^{kn} t_{k}(X))
+    // where \sum_i t_i(X) X^{i n} = t(X), and t(X) is the quotient polynomial.
+    // At the end, we get the (partial) evaluation of the constraint polynomial
+    // in ζ.
+    //
+    // Note: both (ζ^n - 1) and (1 - ζ^n) (and C * (1 - ζ^n)) are
+    // vanishing polynomials, but we have to be consistent with respect
+    // to just one everywhere.
+    let ft: DensePolynomial<Fp> = {
+        let evaluation_point_to_domain_size = zeta.pow([domain.d1.size]);
+        // Compute \sum_i t_i(X) ζ^{i n}
+        // First we split t in t_i, and we reduce to degree (n - 1) after using `linearize`
+        let t_chunked: DensePolynomial<Fp> = quotient_poly
+            .to_chunked_polynomial(num_chunks, domain.d1.size as usize)
+            .linearize(evaluation_point_to_domain_size);
+
+        // -Z_H(ζ) = (1 - ζ^n)
+        let minus_vanishing_poly_at_zeta: Fp = -domain.d1.vanishing_polynomial().evaluate(&zeta);
+        // Multiply the polynomial \sum_i t_i(X) ζ^{i n} by -Z_H(ζ)
+        // (the evaluation in ζ of the vanishing polynomial)
+        t_chunked.scale(minus_vanishing_poly_at_zeta)
+    };
+
+    // We only evaluate ft at ζω, as the verifier can compute the
+    // evaluation at ζ from the independent evaluations of the witness
+    // columns at ζ, because ft(X) is the constraint polynomial, built from
+    // the public constraints.
+    // We evaluate at ζω because the lookup argument requires computing
+    // \phi(Xω) - \phi(X).
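+    // Only ft(ζω) is sent in the proof: the verifier recomputes the
+    // evaluation at ζ itself (`ft_eval0` in verifier.rs) from the column
+    // evaluations at ζ and the combined expression.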
+    let ft_eval1 = ft.evaluate(&zeta_omega);
+
+    // Absorb ft(ζω)
+    fr_sponge.absorb(&ft_eval1);
+
+    let v_chal = fr_sponge.challenge();
+    let v = v_chal.to_field(endo_r);
+    let u_chal = fr_sponge.challenge();
+    let u = u_chal.to_field(endo_r);
+
+    let coefficients_form = DensePolynomialOrEvaluations::DensePolynomial;
+    let non_hiding = |n_chunks| PolyComm {
+        chunks: vec![Fp::zero(); n_chunks],
+    };
+    let hiding = |n_chunks| PolyComm {
+        chunks: vec![Fp::one(); n_chunks],
+    };
+
+    // Gather all the polynomials to open for the opening proof
+    let mut polynomials_to_open: Vec<_> = vec![];
+
+    polynomials_to_open.extend(
+        (&witness_polys)
+            .into_par_iter()
+            .map(|poly| (coefficients_form(poly), hiding(1)))
+            .collect::<Vec<_>>(),
+    );
+
+    // @volhovm: I'm not sure we need to prove the opening of the fixed
+    // selectors in the commitment.
+    polynomials_to_open.extend(
+        fixed_selectors_polys
+            .as_ref()
+            .into_par_iter()
+            .map(|poly| (coefficients_form(poly), non_hiding(1)))
+            .collect::<Vec<_>>(),
+    );
+
+    polynomials_to_open.push((coefficients_form(&ft), non_hiding(1)));
+
+    let opening_proof = OpenProof::open::<_, _, R2D<Fp>>(
+        srs,
+        &group_map,
+        polynomials_to_open.as_slice(),
+        &[zeta, zeta_omega],
+        v,
+        u,
+        fq_sponge_before_evaluations,
+        rng,
+    );
+
+    let proof_evals: ProofEvaluations<N_WIT_QUAD, N_REL, N_DSEL, N_FSEL, Fp> = {
+        ProofEvaluations {
+            witness_evals: witness_point_evals,
+            fixed_selectors_evals: fixed_selectors_point_evals,
+            error_vec: error_vec_point_eval,
+            ft_eval1,
+        }
+    };
+
+    Ok(Proof {
+        proof_comms: ProofCommitments {
+            witness_comms,
+            t_comm,
+        },
+        proof_evals,
+        opening_proof,
+        alphas: folded_instance.extended_instance.instance.alphas,
+        challenges: folded_instance.extended_instance.instance.challenges,
+        u: folded_instance.u,
+    })
+}
diff --git a/ivc/src/verifier.rs b/ivc/src/verifier.rs
new file mode 100644
index 0000000000..acec3e5f6f
--- /dev/null
+++ b/ivc/src/verifier.rs
@@ -0,0 +1,250 @@
+use crate::{
+    expr_eval::GenericEvalEnv,
+    plonkish_lang::{PlonkishChallenge, PlonkishWitnessGeneric},
+    prover::Proof,
+};
+use ark_ff::Field;
+use ark_poly::{
+    univariate::DensePolynomial, EvaluationDomain, Evaluations, Polynomial,
+    Radix2EvaluationDomain as R2D,
+};
+use folding::{
+    eval_leaf::EvalLeaf, instance_witness::ExtendedWitness, FoldingCompatibleExpr, FoldingConfig,
+};
+use kimchi::{
+    self, circuits::domains::EvaluationDomains, curve::KimchiCurve, groupmap::GroupMap,
+    plonk_sponge::FrSponge, proof::PointEvaluations,
+};
+use kimchi_msm::columns::Column as GenericColumn;
+use mina_poseidon::{sponge::ScalarChallenge, FqSponge};
+use poly_commitment::{
+    commitment::{
+        absorb_commitment, combined_inner_product, BatchEvaluationProof, CommitmentCurve,
+        Evaluation, PolyComm,
+    },
+    kzg::{KZGProof, PairingSRS},
+    OpenProof, SRS,
+};
+use rand::thread_rng;
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+
+pub type Pairing = kimchi_msm::BN254;
+/// The curve we commit into
+pub type G = kimchi_msm::BN254G1Affine;
+/// Scalar field of the curve.
+pub type Fp = kimchi_msm::Fp;
+/// The base field of the curve
+/// Used to encode the polynomial commitments
+pub type Fq = ark_bn254::Fq;
+
+pub fn verify<
+    EFqSponge: Clone + FqSponge<Fq, G, Fp>,
+    EFrSponge: FrSponge<Fp>,
+    FC: FoldingConfig<Column = GenericColumn, Curve = G, Challenge = PlonkishChallenge>,
+    const N_WIT: usize,
+    const N_REL: usize,
+    const N_DSEL: usize,
+    const N_FSEL: usize,
+    const NPUB: usize,
+>(
+    domain: EvaluationDomains<Fp>,
+    srs: &PairingSRS<Pairing>,
+    combined_expr: &FoldingCompatibleExpr<FC>,
+    fixed_selectors: Box<[Evaluations<Fp, R2D<Fp>>; N_FSEL]>,
+    proof: &Proof<N_WIT, N_REL, N_DSEL, N_FSEL, G, KZGProof<Pairing>>,
+) -> bool {
+    assert!(N_WIT == N_REL + N_DSEL);
+
+    let Proof {
+        proof_comms,
+        proof_evals,
+        opening_proof,
+        ..
+    } = proof;
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Re-evaluating public inputs
+    ////////////////////////////////////////////////////////////////////////////
+
+    let fixed_selectors_evals_d1: Box<[Evaluations<Fp, R2D<Fp>>; N_FSEL]> = fixed_selectors;
+
+    let fixed_selectors_polys: Box<[DensePolynomial<Fp>; N_FSEL]> = {
+        o1_utils::array::vec_to_boxed_array(
+            fixed_selectors_evals_d1
+                .into_par_iter()
+                .map(|evals| evals.interpolate())
+                .collect(),
+        )
+    };
+
+    let fixed_selectors_comms: Box<[PolyComm<G>; N_FSEL]> = {
+        let comm = |poly: &DensePolynomial<Fp>| srs.commit_non_hiding(poly, 1);
+        o1_utils::array::vec_to_boxed_array(
+            fixed_selectors_polys
+                .as_ref()
+                .into_par_iter()
+                .map(comm)
+                .collect(),
+        )
+    };
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Absorbing all the commitments to the columns
+    ////////////////////////////////////////////////////////////////////////////
+
+    let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params());
+
+    fixed_selectors_comms
+        .as_ref()
+        .iter()
+        .chain(&proof_comms.witness_comms)
+        .for_each(|comm| absorb_commitment(&mut fq_sponge, comm));
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Quotient polynomial
+    ////////////////////////////////////////////////////////////////////////////
+
+    absorb_commitment(&mut fq_sponge, &proof_comms.t_comm);
+
+    // -- Preparing for the opening proof verification
+    let zeta_chal = ScalarChallenge(fq_sponge.challenge());
+    let (_, endo_r) = G::endos();
+    let zeta: Fp = zeta_chal.to_field(endo_r);
+    let omega = domain.d1.group_gen;
+    let zeta_omega = zeta * omega;
+
+    let mut coms_and_evaluations: Vec<Evaluation<G>> = vec![];
+
+    coms_and_evaluations.extend(
+        (&proof_comms.witness_comms)
+            .into_iter()
+            .zip(&proof_evals.witness_evals)
+            .map(|(commitment, point_eval)| Evaluation {
+                commitment: commitment.clone(),
+                evaluations: vec![vec![point_eval.zeta], vec![point_eval.zeta_omega]],
+            }),
+    );
+
+    coms_and_evaluations.extend(
+        (fixed_selectors_comms)
+            .into_iter()
+            .zip(proof_evals.fixed_selectors_evals.iter())
+            .map(|(commitment, point_eval)| Evaluation {
+                commitment,
+                evaluations: vec![vec![point_eval.zeta], vec![point_eval.zeta_omega]],
+            }),
+    );
+
+    // -- Absorb all coms_and_evaluations
+    let fq_sponge_before_coms_and_evaluations = fq_sponge.clone();
+    let mut fr_sponge = EFrSponge::new(G::sponge_params());
+    fr_sponge.absorb(&fq_sponge.digest());
+
+    for PointEvaluations { zeta, zeta_omega } in (&proof_evals.witness_evals).into_iter() {
+        fr_sponge.absorb(zeta);
+        fr_sponge.absorb(zeta_omega);
+    }
+
+    for PointEvaluations { zeta, zeta_omega } in proof_evals.fixed_selectors_evals.as_ref().iter() {
+        fr_sponge.absorb(zeta);
+        fr_sponge.absorb(zeta_omega);
+    }
+
+    // Compute the commitment to ft(X):
+    //   [ft(X)] = (1 - ζ^n) ([t_0(X)] + ζ^n [t_1(X)] + ... + ζ^{kn} [t_{k}(X)])
+    let ft_comm = {
+        let evaluation_point_to_domain_size = zeta.pow([domain.d1.size]);
+        let chunked_t_comm = proof_comms
+            .t_comm
+            .chunk_commitment(evaluation_point_to_domain_size);
+
+        // (1 - ζ^n)
+        let minus_vanishing_poly_at_zeta: Fp = -domain.d1.vanishing_polynomial().evaluate(&zeta);
+        chunked_t_comm.scale(minus_vanishing_poly_at_zeta)
+    };
+
+    let ft_eval0 = {
+        let witness_evals_vecs = (&proof_evals.witness_evals)
+            .into_iter()
+            .map(|x| vec![x.zeta])
+            .collect::<Vec<_>>()
+            .try_into()
+            .unwrap();
+        let fixed_selectors_evals_vecs = proof_evals
+            .fixed_selectors_evals
+            .into_iter()
+            .map(|x| vec![x.zeta])
+            .collect::<Vec<_>>()
+            .try_into()
+            .unwrap();
+        let error_vec = vec![proof_evals.error_vec.zeta];
+
+        let alphas = proof.alphas.clone();
+        let challenges = proof.challenges;
+        let u = proof.u;
+
+        let eval_env: GenericEvalEnv<G, N_WIT, N_FSEL, Vec<Fp>> = {
+            let ext_witness = ExtendedWitness {
+                witness: PlonkishWitnessGeneric {
+                    witness: witness_evals_vecs,
+                    fixed_selectors: fixed_selectors_evals_vecs,
+                    phantom: std::marker::PhantomData,
+                },
+                extended: Default::default(),
+            };
+
+            GenericEvalEnv {
+                ext_witness,
+                alphas,
+                challenges,
+                error_vec,
+                u,
+            }
+        };
+
+        let eval_res: Vec<_> = match eval_env.eval_naive_fcompat(combined_expr) {
+            EvalLeaf::Result(x) => x,
+            EvalLeaf::Col(x) => x.to_vec(),
+            _ => panic!("eval_leaf is not Result"),
+        };
+
+        // Note the minus! The ft polynomial at ζ (ft_eval0) is the negated
+        // evaluation of the expression.
+        -eval_res[0]
+    };
+
+    coms_and_evaluations.push(Evaluation {
+        commitment: ft_comm,
+        evaluations: vec![vec![ft_eval0], vec![proof_evals.ft_eval1]],
+    });
+
+    fr_sponge.absorb(&proof_evals.ft_eval1);
+    // -- End absorbing all coms_and_evaluations
+
+    let v_chal = fr_sponge.challenge();
+    let v = v_chal.to_field(endo_r);
+    let u_chal = fr_sponge.challenge();
+    let u = u_chal.to_field(endo_r);
+
+    let combined_inner_product = {
+        let es: Vec<_> = coms_and_evaluations
+            .iter()
+            .map(|Evaluation { evaluations, .. }| evaluations.clone())
+            .collect();
+
+        combined_inner_product(&v, &u, es.as_slice())
+    };
+
+    let batch = BatchEvaluationProof {
+        sponge: fq_sponge_before_coms_and_evaluations,
+        evaluations: coms_and_evaluations,
+        evaluation_points: vec![zeta, zeta_omega],
+        polyscale: v,
+        evalscale: u,
+        opening: opening_proof,
+        combined_inner_product,
+    };
+
+    let group_map = <G as CommitmentCurve>::Map::setup();
+    OpenProof::verify(srs, &group_map, &mut [batch], &mut thread_rng())
+}
diff --git a/ivc/tests/folding_ivc.rs b/ivc/tests/folding_ivc.rs
new file mode 100644
index 0000000000..78e1b485c6
--- /dev/null
+++ b/ivc/tests/folding_ivc.rs
@@ -0,0 +1,141 @@
+use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
+use folding::{
+    instance_witness::Foldable, Alphas, FoldingCompatibleExpr, FoldingConfig, FoldingEnv,
+    FoldingScheme, Instance, Side, Witness,
+};
+use ivc::ivc::{constraints::constrain_ivc, lookups::IVCLookupTable, N_ADDITIONAL_WIT_COL_QUAD};
+use kimchi::circuits::{berkeley_columns::BerkeleyChallengeTerm, gate::CurrOrNext};
+use kimchi_msm::{circuit_design::ConstraintBuilderEnv, columns::Column, Ff1};
+use poly_commitment::{kzg::PairingSRS, SRS as _};
+
+#[test]
+fn test_regression_additional_columns_reduction_to_degree_2() {
+    pub type Fp = ark_bn254::Fr;
+    pub type Curve = ark_bn254::G1Affine;
+    pub type Pairing = ark_bn254::Bn254;
+
+    // ---- Folding fake structures ----
+    #[derive(Hash, Clone, Debug, PartialEq, Eq)]
+    struct TestConfig;
+
+    #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+    pub enum Challenge {
+        Beta,
+        Gamma,
+        JointCombiner,
+    }
+
+    impl From<BerkeleyChallengeTerm> for Challenge {
+        fn from(chal: BerkeleyChallengeTerm) -> Self {
+            match chal {
+                BerkeleyChallengeTerm::Beta => Challenge::Beta,
+                BerkeleyChallengeTerm::Gamma => Challenge::Gamma,
+                BerkeleyChallengeTerm::JointCombiner => Challenge::JointCombiner,
+                BerkeleyChallengeTerm::Alpha => panic!("Alpha not allowed in folding expressions"),
+            }
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct TestInstance;
+
+    impl Foldable<Fp> for TestInstance {
+        fn combine(_a: Self, _b: Self, _challenge: Fp) -> Self {
+            unimplemented!()
+        }
+    }
+
+    impl Instance<Curve> for TestInstance {
+        fn get_alphas(&self) -> &Alphas<Fp> {
+            todo!()
+        }
+
+        fn to_absorb(&self) -> (Vec<Fp>, Vec<Curve>) {
+            todo!()
+        }
+
+        fn get_blinder(&self) -> Fp {
+            todo!()
+        }
+    }
+
+    #[derive(Clone)]
+    struct TestWitness;
+
+    impl Foldable<Fp> for TestWitness {
+        fn combine(_a: Self, _b: Self, _challenge: Fp) -> Self {
+            unimplemented!()
+        }
+    }
+
+    impl Witness<Curve> for TestWitness {}
+
+    impl FoldingConfig for TestConfig {
+        type Column = Column;
+
+        type Selector = ();
+
+        type Challenge = Challenge;
+
+        type Curve = Curve;
+
+        type Srs = PairingSRS<Pairing>;
+
+        type Instance = TestInstance;
+
+        type Witness = TestWitness;
+
+        type Structure = ();
+
+        type Env = Env;
+    }
+
+    struct Env;
+
+    impl FoldingEnv<Fp, TestInstance, TestWitness, Column, Challenge, ()> for Env {
+        type Structure = ();
+
+        fn new(
+            _structure: &Self::Structure,
+            _instances: [&TestInstance; 2],
+            _witnesses: [&TestWitness; 2],
+        ) -> Self {
+            todo!()
+        }
+
+        fn col(&self, _col: Column, _curr_or_next: CurrOrNext, _side: Side) -> &[Fp] {
+            todo!()
+        }
+
+        fn challenge(&self, _challenge: Challenge, _side: Side) -> Fp {
+            todo!()
+        }
+
+        fn selector(&self, _s: &(), _side: Side) -> &[Fp] {
+            todo!()
+        }
+    }
+
+    // ---- End of folding fake structures ----
+
+    let mut constraint_env = ConstraintBuilderEnv::<Fp, IVCLookupTable<Ff1>>::create();
+    constrain_ivc::<Ff1, _>(&mut constraint_env);
+    let constraints = constraint_env.get_relation_constraints();
+
+    let domain = Radix2EvaluationDomain::<Fp>::new(2).unwrap();
+    let srs = PairingSRS::create(2);
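+    // Precompute and cache the Lagrange basis for this domain in the SRS
+    // before handing it to the folding scheme.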
+    srs.get_lagrange_basis(domain);
+
+    let folding_compat_expressions: Vec<FoldingCompatibleExpr<TestConfig>> = constraints
+        .into_iter()
+        .map(FoldingCompatibleExpr::from)
+        .collect();
+    let (scheme, _) =
+        FoldingScheme::<TestConfig>::new(folding_compat_expressions, &srs, domain, &());
+    assert_eq!(
+        scheme.get_number_of_additional_columns(),
+        N_ADDITIONAL_WIT_COL_QUAD,
+        "Expected {N_ADDITIONAL_WIT_COL_QUAD}, got {}",
+        scheme.get_number_of_additional_columns(),
+    );
+}
diff --git a/ivc/tests/simple.rs b/ivc/tests/simple.rs
new file mode 100644
index 0000000000..c651e33652
--- /dev/null
+++ b/ivc/tests/simple.rs
@@ -0,0 +1,1046 @@
+//! This is a simple example of how to use the folding crate with the IVC crate
+//! to fold a simple circuit. The circuit consists of a single constraint of
+//! degree 3 over 3 columns (A * A * B - C = 0).
+
+use ark_ec::AffineRepr;
+use ark_ff::{One, PrimeField, UniformRand, Zero};
+use ark_poly::{Evaluations, Radix2EvaluationDomain as R2D};
+use folding::{
+    eval_leaf::EvalLeaf, expressions::FoldingColumnTrait, instance_witness::ExtendedWitness,
+    standard_config::StandardConfig, Alphas, FoldingCompatibleExpr, FoldingOutput, FoldingScheme,
+};
+use ivc::{
+    self,
+    expr_eval::{GenericVecStructure, SimpleEvalEnv},
+    ivc::{
+        columns::{IVCColumn, N_BLOCKS, N_FSEL_IVC},
+        constraints::constrain_ivc,
+        interpreter::{build_fixed_selectors, ivc_circuit, ivc_circuit_base_case},
+        lookups::IVCLookupTable,
+        N_ADDITIONAL_WIT_COL_QUAD as N_COL_QUAD_IVC, N_ALPHAS as N_ALPHAS_IVC,
+    },
+    plonkish_lang::{PlonkishChallenge, PlonkishInstance, PlonkishWitness},
+    poseidon_8_56_5_3_2::bn254::PoseidonBN254Parameters,
+};
+use kimchi::{
+    circuits::{domains::EvaluationDomains, expr::Variable, gate::CurrOrNext},
+    curve::KimchiCurve,
+};
+use kimchi_msm::{
+    circuit_design::{
+        ColAccessCap, ColWriteCap, ConstraintBuilderEnv, HybridCopyCap, WitnessBuilderEnv,
+    },
+    columns::{Column, ColumnIndexer},
+    expr::E,
+    logup::LogupWitness,
+    lookups::DummyLookupTable,
+    proof::ProofInputs,
+    witness::Witness as GenericWitness,
+    BN254G1Affine, Fp,
+};
+use mina_poseidon::{
+    constants::PlonkSpongeConstantsKimchi,
+    sponge::{DefaultFqSponge, DefaultFrSponge},
+    FqSponge,
+};
+use poly_commitment::{kzg::PairingSRS, PolyComm, SRS as _};
+use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
+use std::{collections::BTreeMap, ops::Index};
+use strum::EnumCount;
+use strum_macros::{EnumCount as EnumCountMacro, EnumIter};
+
+/// The base field of the curve
+/// Used to encode the polynomial commitments
+pub type Fq = ark_bn254::Fq;
+/// The curve we commit into
+pub type Curve = BN254G1Affine;
+pub type Pairing = ark_bn254::Bn254;
+
+pub type BaseSponge = DefaultFqSponge<ark_bn254::g1::Config, SpongeParams>;
+pub type ScalarSponge = DefaultFrSponge<Fp, SpongeParams>;
+pub type SpongeParams = PlonkSpongeConstantsKimchi;
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, EnumIter, EnumCountMacro, Hash)]
+pub enum AdditionColumn {
+    A,
+    B,
+    C,
+}
+
+impl ColumnIndexer for AdditionColumn {
+    const N_COL: usize = 3;
+
+    fn to_column(self) -> Column {
+        match self {
+            AdditionColumn::A => Column::Relation(0),
+            AdditionColumn::B => Column::Relation(1),
+            AdditionColumn::C => Column::Relation(2),
+        }
+    }
+}
+
+/// Constrain A * A * B - C = 0
+pub fn interpreter_simple_add<
+    F: PrimeField,
+    Env: ColAccessCap<F, AdditionColumn> + HybridCopyCap<F, AdditionColumn>,
+>(
+    env: &mut Env,
+) {
+    let a = env.read_column(AdditionColumn::A);
+    let b = env.read_column(AdditionColumn::B);
+    let c = env.read_column(AdditionColumn::C);
+    let eq = a.clone() * a.clone() * b.clone() - c;
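+    // This is the single relation constraint of the application circuit:
+    // A * A * B - C = 0. It is of degree 3, which is why one extra
+    // quadraticization column is accounted for later in this file (see
+    // N_COL_QUAD).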
+    env.assert_zero(eq);
+}
+
+// Ignoring this test for now.
+// When run with the code coverage, it takes hours, and crashes the CI, even
+// though it takes less than 10 minutes to run on a desktop without test
+// coverage.
+// The code isn't used, and we don't plan to use it in the short term, so it's
+// not a big deal.
+// Also, the code wasn't in a good state.
+#[test]
+#[ignore]
+pub fn heavy_test_simple_add() {
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    // FIXME: this has to be 1 << 15. The 16 is temporary, since we
+    // are at 35783 rows right now, but can only allow 32768. Light
+    // future optimisations will get us back to 15.
+    let domain_size: usize = 1 << 16;
+    let domain = EvaluationDomains::<Fp>::create(domain_size).unwrap();
+
+    let mut fq_sponge: BaseSponge = FqSponge::new(Curve::other_curve_sponge_params());
+
+    let srs = kimchi_msm::precomputed_srs::get_bn254_srs(domain);
+
+    // Total number of fixed selectors in the circuit for APP + IVC.
+    // There is no fixed selector in the APP circuit.
+    const N_FSEL_TOTAL: usize = N_FSEL_IVC;
+
+    // Total number of witness columns in IVC. The blocks are public selectors.
+    const N_WIT_IVC: usize = <IVCColumn as ColumnIndexer>::N_COL - N_FSEL_IVC;
+
+    // Number of witness columns in the circuit.
+    // It consists of the columns of the inner circuit and the columns for the
+    // IVC circuit.
+    const N_COL_TOTAL: usize = AdditionColumn::COUNT + N_WIT_IVC;
+    // One extra quad column in APP
+    const N_COL_QUAD: usize = N_COL_QUAD_IVC + 1;
+    const N_COL_TOTAL_QUAD: usize = N_COL_TOTAL + N_COL_QUAD;
+
+    // Our application circuit has a single constraint.
+    const N_ALPHAS_APP: usize = 1;
+    // Total number of challenges required by the circuit APP + IVC.
+    // The number of challenges required by the IVC is defined by the library.
+    // Therefore, we only need to add the challenges required by the specific
+    // application.
+    const N_ALPHAS: usize = N_ALPHAS_IVC + N_ALPHAS_APP;
+    // Number of extra quad constraints happens to be the same as
+    // extra quad columns.
+    const N_ALPHAS_QUAD: usize = N_ALPHAS + N_COL_QUAD;
+
+    // There are a few more challenges beyond the alphas, though.
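+    // These are the `PlonkishChallenge` values: beta, gamma and the joint
+    // combiner.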
+ // Not used at the moment as IVC circuit only handles alphas + const _N_CHALS: usize = N_ALPHAS + PlonkishChallenge::COUNT; + + println!("N_FSEL_TOTAL: {N_FSEL_TOTAL}"); + println!("N_COL_TOTAL: {N_COL_TOTAL}"); + println!("N_COL_TOTAL_QUAD: {N_COL_TOTAL_QUAD}"); + println!("N_ALPHAS: {N_ALPHAS}"); + println!("N_ALPHAS_QUAD: {N_ALPHAS_QUAD}"); + + // ---- Defining the folding configuration ---- + // FoldingConfig + impl FoldingColumnTrait for AdditionColumn { + fn is_witness(&self) -> bool { + true + } + } + + type AppWitnessBuilderEnv = WitnessBuilderEnv< + Fp, + AdditionColumn, + { AdditionColumn::COUNT }, + { AdditionColumn::COUNT }, + 0, + 0, + DummyLookupTable, + >; + + impl Index + for PlonkishWitness + { + type Output = [Fp]; + + fn index(&self, index: AdditionColumn) -> &Self::Output { + match index { + AdditionColumn::A => &self.witness.cols[0].evals, + AdditionColumn::B => &self.witness.cols[1].evals, + AdditionColumn::C => &self.witness.cols[2].evals, + } + } + } + + type Config< + const N_COL_TOTAL: usize, + const N_CHALS: usize, + const N_FSEL: usize, + const N_ALPHAS: usize, + > = StandardConfig< + Curve, + Column, + PlonkishChallenge, + PlonkishInstance, // TODO check if it's quad or not + PlonkishWitness, + PairingSRS, + (), + GenericVecStructure, + >; + type MainTestConfig = Config; + + //////////////////////////////////////////////////////////////////////////// + // Fixed Selectors + //////////////////////////////////////////////////////////////////////////// + + // ---- Start build the witness environment for the IVC + // Start building the constants of the circuit. + // For the IVC, we have all the "block selectors" - which depends on the + // number of columns of the circuit - and the poseidon round constants. + let ivc_fixed_selectors: Vec> = + build_fixed_selectors::(domain_size).to_vec(); + + // Sanity check on the domain size, can be removed later + assert_eq!(ivc_fixed_selectors.len(), N_FSEL_TOTAL); + ivc_fixed_selectors.iter().for_each(|s| { + assert_eq!(s.len(), domain_size); + }); + + let ivc_fixed_selectors_evals_d1: Vec>> = (&ivc_fixed_selectors) + .into_par_iter() + .map(|w| Evaluations::from_vec_and_domain(w.to_vec(), domain.d1)) + .collect(); + + let structure = GenericVecStructure( + ivc_fixed_selectors_evals_d1 + .iter() + .map(|x| x.evals.clone()) + .collect(), + ); + + //////////////////////////////////////////////////////////////////////////// + // Constraints + //////////////////////////////////////////////////////////////////////////// + + let app_constraints: Vec> = { + let mut constraint_env = ConstraintBuilderEnv::::create(); + interpreter_simple_add::(&mut constraint_env); + // Don't use lookups for now + constraint_env.get_relation_constraints() + }; + + let ivc_constraints: Vec> = { + let mut ivc_constraint_env = ConstraintBuilderEnv::>::create(); + constrain_ivc::(&mut ivc_constraint_env); + ivc_constraint_env.get_relation_constraints() + }; + + // Sanity check: we want to verify that we only have maximum degree 5 + // constraints. + // FIXME: we only want degree 2 (+1 for the selector in IVC). 
+    // FIXME: we do have degree 5 as the fold iteration and the public selectors
+    // add one
+    assert_eq!(
+        app_constraints
+            .iter()
+            .chain(ivc_constraints.iter())
+            .map(|c| c.degree(1, 0))
+            .max(),
+        Some(4)
+    );
+
+    // Make the constraints folding compatible
+    let app_compat_constraints: Vec<FoldingCompatibleExpr<MainTestConfig>> = app_constraints
+        .into_iter()
+        .map(|x| FoldingCompatibleExpr::from(x.clone()))
+        .collect();
+
+    let ivc_compat_constraints: Vec<FoldingCompatibleExpr<MainTestConfig>> = {
+        let ivc_compat_constraints: Vec<FoldingCompatibleExpr<MainTestConfig>> = ivc_constraints
+            .into_iter()
+            .map(|x| FoldingCompatibleExpr::from(x.clone()))
+            .collect();
+
+        // IVC column expressions should be shifted to the right to accommodate
+        // the app witness.
+        let ivc_mapper = &(|Variable { col, row }| {
+            let rel_offset: usize = AdditionColumn::COUNT;
+            let fsel_offset: usize = 0;
+            let dsel_offset: usize = 0;
+            use kimchi_msm::columns::Column::*;
+            let new_col = match col {
+                Relation(i) => Relation(i + rel_offset),
+                FixedSelector(i) => FixedSelector(i + fsel_offset),
+                DynamicSelector(i) => DynamicSelector(i + dsel_offset),
+                c @ LookupPartialSum(_) => c,
+                c @ LookupMultiplicity(_) => c,
+                c @ LookupFixedTable(_) => c,
+                c @ LookupAggregation => c,
+            };
+            Variable { col: new_col, row }
+        });
+
+        ivc_compat_constraints
+            .into_iter()
+            .map(|e| e.map_variable(ivc_mapper))
+            .collect()
+    };
+
+    let folding_compat_constraints: Vec<FoldingCompatibleExpr<MainTestConfig>> =
+        app_compat_constraints
+            .into_iter()
+            .chain(ivc_compat_constraints)
+            .collect();
+
+    // We have as many alphas as constraints
+    assert!(
+        folding_compat_constraints.len() == N_ALPHAS,
+        "Folding compat constraints: expected {N_ALPHAS:?} got {}",
+        folding_compat_constraints.len()
+    );
+
+    // `real_folding_compat_constraint` is the actual folded constraint. It
+    // cannot be mapped back to expressions over plain `Fp`: it involves the
+    // homogenizer `u` and the alpha powers {alpha^i}, and it is the expression
+    // that must be passed to `prover(..)`.
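+    // `FoldingScheme::new` also performs quadraticization: expressions of
+    // degree greater than 2 are reduced by introducing extra witness columns,
+    // which is where the additional columns counted below come from.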
+ let (folding_scheme, real_folding_compat_constraint) = FoldingScheme::::new( + folding_compat_constraints.clone(), + &srs, + domain.d1, + &structure, + ); + + let additional_columns = folding_scheme.get_number_of_additional_columns(); + + assert_eq!( + additional_columns, N_COL_QUAD, + "Expected {N_COL_QUAD} additional quad columns, got {additional_columns}" + ); + + //////////////////////////////////////////////////////////////////////////// + // Witness step 1 (APP + IVC) + //////////////////////////////////////////////////////////////////////////// + + println!("Building witness step 1"); + + type IVCWitnessBuilderEnvRaw = + WitnessBuilderEnv; + type LT = IVCLookupTable; + + let mut ivc_witness_env_0 = IVCWitnessBuilderEnvRaw::::create(); + + let mut app_witness_one: AppWitnessBuilderEnv = WitnessBuilderEnv::create(); + + let empty_logups: BTreeMap> = BTreeMap::new(); + + // Witness one + for _i in 0..domain_size { + let a: Fp = Fp::rand(&mut rng); + let b: Fp = Fp::rand(&mut rng); + app_witness_one.write_column(AdditionColumn::A, &a); + app_witness_one.write_column(AdditionColumn::B, &b); + app_witness_one.write_column(AdditionColumn::C, &(a * a * b)); + interpreter_simple_add(&mut app_witness_one); + app_witness_one.next_row(); + } + + let proof_inputs_one = ProofInputs { + evaluations: app_witness_one.get_relation_witness(domain_size), + logups: empty_logups.clone(), + }; + assert!(proof_inputs_one.evaluations.len() == 3); + + ivc_witness_env_0.set_fixed_selectors(ivc_fixed_selectors.clone()); + ivc_circuit_base_case::( + &mut ivc_witness_env_0, + domain_size, + ); + let ivc_proof_inputs_0 = ProofInputs { + evaluations: ivc_witness_env_0.get_relation_witness(domain_size), + logups: empty_logups.clone(), + }; + assert!(ivc_proof_inputs_0.evaluations.len() == N_WIT_IVC); + for i in 0..10 { + assert!( + ivc_proof_inputs_0.evaluations[0][i] == Fp::zero(), + "Iteration column row #{i:?} must be zero" + ); + } + + // FIXME this merely concatenates two witnesses. Most likely, we + // want to intersperse them in a smarter way later. Our witness is + // Relation || dynamic. 
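+    // Concretely: the 3 application columns [A, B, C] followed by the
+    // N_WIT_IVC IVC witness columns, giving N_COL_TOTAL columns in total.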
+ let joint_witness_one: Vec<_> = proof_inputs_one + .evaluations + .into_iter() + .chain(ivc_proof_inputs_0.evaluations.clone()) + .collect(); + + assert!(joint_witness_one.len() == N_COL_TOTAL); + + let folding_witness_one = PlonkishWitness { + witness: joint_witness_one + .into_par_iter() + .map(|w| Evaluations::from_vec_and_domain(w.to_vec(), domain.d1)) + .collect(), + fixed_selectors: ivc_fixed_selectors_evals_d1.clone().try_into().unwrap(), + phantom: std::marker::PhantomData, + }; + + let folding_instance_one = PlonkishInstance::from_witness( + &folding_witness_one.witness, + &mut fq_sponge, + &srs.full_srs, + domain.d1, + ); + + //////////////////////////////////////////////////////////////////////////// + // Witness step 2 + //////////////////////////////////////////////////////////////////////////// + + println!("Building witness step 2"); + + let mut app_witness_two: AppWitnessBuilderEnv = WitnessBuilderEnv::create(); + + // Witness two + for _i in 0..domain_size { + let a: Fp = Fp::rand(&mut rng); + let b: Fp = Fp::rand(&mut rng); + app_witness_two.write_column(AdditionColumn::A, &a); + app_witness_two.write_column(AdditionColumn::B, &b); + app_witness_two.write_column(AdditionColumn::C, &(a * a * b)); + interpreter_simple_add(&mut app_witness_two); + app_witness_two.next_row(); + } + + let proof_inputs_two = ProofInputs { + evaluations: app_witness_two.get_relation_witness(domain_size), + logups: empty_logups.clone(), + }; + + // IVC for the second witness is the same as for the first one, + // since they're both height 0. + let joint_witness_two: Vec<_> = proof_inputs_two + .evaluations + .into_iter() + .chain(ivc_proof_inputs_0.evaluations) + .collect(); + + let folding_witness_two_evals: GenericWitness>> = + joint_witness_two + .into_par_iter() + .map(|w| Evaluations::from_vec_and_domain(w.to_vec(), domain.d1)) + .collect(); + let folding_witness_two = PlonkishWitness { + witness: folding_witness_two_evals.clone(), + fixed_selectors: ivc_fixed_selectors_evals_d1.clone().try_into().unwrap(), + phantom: std::marker::PhantomData, + }; + + let folding_instance_two = PlonkishInstance::from_witness( + &folding_witness_two.witness, + &mut fq_sponge, + &srs.full_srs, + domain.d1, + ); + + //////////////////////////////////////////////////////////////////////////// + // Folding 1 + //////////////////////////////////////////////////////////////////////////// + + println!("Folding 1"); + + // To start, we only fold two instances. + // We must call fold_instance_witness_pair in a nested call `fold`. + // Something like: + // ``` + // instances.tail().fold(instances.head(), |acc, two| { + // let folding_output = folding_scheme.fold_instance_witness_pair(acc, two, &mut fq_sponge); + // // Compute IVC + // // ... + // // We return the folding instance, which will be used in the next + // // iteration, which is the folded value `acc` with `two`. + // folding_instance + // }); + // ``` + println!("fold_instance_witness_pair"); + let folding_output_one = folding_scheme.fold_instance_witness_pair( + (folding_instance_one, folding_witness_one), + (folding_instance_two, folding_witness_two), + &mut fq_sponge, + ); + println!("Folding 1 succeeded"); + + let FoldingOutput { + folded_instance: folded_instance_one, + // Should not be required for the IVC circuit as it is encoding the + // verifier. + folded_witness: folded_witness_one, + .. 
+    } = folding_output_one;
+
+    // The polynomial describing the computation is not linear, therefore the
+    // error (cross) terms are expected to be non-zero.
+    assert_ne!(folding_output_one.t_0.get_first_chunk(), Curve::zero());
+    assert_ne!(folding_output_one.t_1.get_first_chunk(), Curve::zero());
+
+    // Sanity check that the u values are the same. The u value is there to
+    // homogenize the polynomial describing the NP relation.
+    assert_eq!(
+        folding_output_one.relaxed_extended_left_instance.u,
+        folding_output_one.relaxed_extended_right_instance.u
+    );
+
+    // 1. Get all the commitments from the left instance.
+    // We also want to get the potential additional columns.
+    let mut comms_left: Vec<Curve> = Vec::with_capacity(N_COL_TOTAL_QUAD);
+    comms_left.extend(
+        folding_output_one
+            .relaxed_extended_left_instance
+            .extended_instance
+            .instance
+            .commitments,
+    );
+
+    // Additional columns introduced by quadraticization
+    {
+        let extended = folding_output_one
+            .relaxed_extended_left_instance
+            .extended_instance
+            .extended;
+        let extended_comms: Vec<_> = extended.iter().map(|x| x.get_first_chunk()).collect();
+        comms_left.extend(extended_comms.clone());
+        extended_comms.iter().enumerate().for_each(|(i, x)| {
+            assert_ne!(
+                x,
+                &Curve::zero(),
+                "Left extended commitment number {i:?} is zero"
+            );
+        });
+    }
+    assert_eq!(comms_left.len(), N_COL_TOTAL_QUAD);
+    // Checking they are all not zero.
+    comms_left.iter().enumerate().for_each(|(i, c)| {
+        assert_ne!(c, &Curve::zero(), "Left commitment number {i:?} is zero");
+    });
+
+    // IVC is expecting the coordinates.
+    let comms_left: [(Fq, Fq); N_COL_TOTAL_QUAD] =
+        std::array::from_fn(|i| (comms_left[i].x, comms_left[i].y));
+
+    // 2. Get all the commitments from the right instance.
+    // We also want to get the potential additional columns.
+    let mut comms_right = Vec::with_capacity(N_COL_TOTAL_QUAD);
+    comms_right.extend(
+        folding_output_one
+            .relaxed_extended_right_instance
+            .extended_instance
+            .instance
+            .commitments,
+    );
+    {
+        let extended = folding_output_one
+            .relaxed_extended_right_instance
+            .extended_instance
+            .extended;
+        comms_right.extend(extended.iter().map(|x| x.get_first_chunk()));
+    }
+    assert_eq!(comms_right.len(), N_COL_TOTAL_QUAD);
+    // Checking they are all not zero.
+    comms_right.iter().enumerate().for_each(|(i, c)| {
+        assert_ne!(c, &Curve::zero(), "Right commitment number {i:?} is zero");
+    });
+
+    // IVC is expecting the coordinates.
+    let comms_right: [(Fq, Fq); N_COL_TOTAL_QUAD] =
+        std::array::from_fn(|i| (comms_right[i].x, comms_right[i].y));
+
+    // 3. Get all the commitments from the folded instance.
+    // We also want to get the potential additional columns.
+    let mut comms_out = Vec::with_capacity(AdditionColumn::N_COL + additional_columns);
+    comms_out.extend(folded_instance_one.extended_instance.instance.commitments);
+    {
+        let extended = folded_instance_one.extended_instance.extended.clone();
+        comms_out.extend(extended.iter().map(|x| x.get_first_chunk()));
+    }
+    // Checking they are all not zero.
+    comms_out.iter().for_each(|c| {
+        assert_ne!(c, &Curve::zero());
+    });
+
+    // IVC is expecting the coordinates.
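+    // Commitments are affine points over the base field; the IVC circuit
+    // absorbs their (x, y) coordinates, which is why each point is unpacked
+    // into a pair of `Fq` elements here.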
+ let comms_out: [(Fq, Fq); N_COL_TOTAL_QUAD] = + std::array::from_fn(|i| (comms_out[i].x, comms_out[i].y)); + + // FIXME: Should be handled in folding + let left_error_term = srs + .full_srs + .mask_custom( + folding_output_one + .relaxed_extended_left_instance + .error_commitment, + &PolyComm::new(vec![Fp::one()]), + ) + .unwrap() + .commitment; + + // FIXME: Should be handled in folding + let right_error_term = srs + .full_srs + .mask_custom( + folding_output_one + .relaxed_extended_right_instance + .error_commitment, + &PolyComm::new(vec![Fp::one()]), + ) + .unwrap() + .commitment; + + let error_terms = [ + left_error_term.get_first_chunk(), + right_error_term.get_first_chunk(), + folded_instance_one.error_commitment.get_first_chunk(), + ]; + error_terms.iter().for_each(|c| { + assert_ne!(c, &Curve::zero()); + }); + + let error_terms: [(Fq, Fq); 3] = std::array::from_fn(|i| (error_terms[i].x, error_terms[i].y)); + + let t_terms = [ + folding_output_one.t_0.get_first_chunk(), + folding_output_one.t_1.get_first_chunk(), + ]; + t_terms.iter().for_each(|c| { + assert_ne!(c, &Curve::zero()); + }); + let t_terms: [(Fq, Fq); 2] = std::array::from_fn(|i| (t_terms[i].x, t_terms[i].y)); + + let u = folding_output_one.relaxed_extended_left_instance.u; + + // FIXME: add columns of the previous IVC circuit in the comms_left, + // comms_right and comms_out. Can be faked. We should have 400 + 3 columns + let all_ivc_comms_left: [(Fq, Fq); N_COL_TOTAL_QUAD] = std::array::from_fn(|i| { + if i < IVCColumn::N_COL { + comms_left[0] + } else { + comms_left[i - IVCColumn::N_COL] + } + }); + let all_ivc_comms_right: [(Fq, Fq); N_COL_TOTAL_QUAD] = std::array::from_fn(|i| { + if i < IVCColumn::N_COL { + comms_right[0] + } else { + comms_right[i - IVCColumn::N_COL] + } + }); + let all_ivc_comms_out: [(Fq, Fq); N_COL_TOTAL_QUAD] = std::array::from_fn(|i| { + if i < IVCColumn::N_COL { + comms_out[0] + } else { + comms_out[i - IVCColumn::N_COL] + } + }); + + let mut ivc_witness_env_1 = IVCWitnessBuilderEnvRaw::::create(); + ivc_witness_env_1.set_fixed_selectors(ivc_fixed_selectors.clone()); + + let alphas: Vec<_> = folded_instance_one + .extended_instance + .instance + .alphas + .clone() + .powers(); + assert!( + alphas.len() == N_ALPHAS_QUAD, + "Alphas length mismatch: expected {N_ALPHAS_QUAD} got {:?}", + alphas.len() + ); + + ivc_circuit::( + &mut ivc_witness_env_1, + 1, + Box::new(all_ivc_comms_left), + Box::new(all_ivc_comms_right), + Box::new(all_ivc_comms_out), + error_terms, + t_terms, + u, + o1_utils::array::vec_to_boxed_array(alphas), + &PoseidonBN254Parameters, + domain_size, + ); + + let ivc_proof_inputs_1 = ProofInputs { + evaluations: ivc_witness_env_1.get_relation_witness(domain_size), + logups: empty_logups.clone(), + }; + assert!(ivc_proof_inputs_1.evaluations.len() == N_WIT_IVC); + + //////////////////////////////////////////////////////////////////////////// + // Witness step 3 + //////////////////////////////////////////////////////////////////////////// + + println!("Witness step 3"); + + let mut app_witness_three: WitnessBuilderEnv = + WitnessBuilderEnv::create(); + + // Witness three + for _i in 0..domain_size { + let a: Fp = Fp::rand(&mut rng); + let b: Fp = Fp::rand(&mut rng); + app_witness_three.write_column(AdditionColumn::A, &a); + app_witness_three.write_column(AdditionColumn::B, &b); + app_witness_three.write_column(AdditionColumn::C, &(a * a * b)); + interpreter_simple_add(&mut app_witness_three); + app_witness_three.next_row(); + } + + let proof_inputs_three = ProofInputs { + 
evaluations: app_witness_three.get_relation_witness(domain_size), + logups: empty_logups.clone(), + }; + + // Here we concatenate with ivc_proof_inputs 1, inductive case + let joint_witness_three: Vec<_> = (proof_inputs_three.evaluations) + .into_iter() + .chain(ivc_proof_inputs_1.evaluations) + .collect(); + + assert!(joint_witness_three.len() == N_COL_TOTAL); + + let folding_witness_three_evals: Vec>> = (&joint_witness_three) + .into_par_iter() + .map(|w| Evaluations::from_vec_and_domain(w.to_vec(), domain.d1)) + .collect(); + let folding_witness_three = PlonkishWitness { + witness: folding_witness_three_evals.clone().try_into().unwrap(), + fixed_selectors: ivc_fixed_selectors_evals_d1.clone().try_into().unwrap(), + phantom: std::marker::PhantomData, + }; + + let mut fq_sponge_before_instance_three = fq_sponge.clone(); + + let folding_instance_three = PlonkishInstance::from_witness( + &folding_witness_three.witness, + &mut fq_sponge, + &srs.full_srs, + domain.d1, + ); + + //////////////////////////////////////////////////////////////////////////// + // Folding 2 + //////////////////////////////////////////////////////////////////////////// + + println!("Folding two"); + + let mut fq_sponge_before_last_fold = fq_sponge.clone(); + + let folding_output_two = folding_scheme.fold_instance_witness_pair( + (folded_instance_one.clone(), folded_witness_one.clone()), + ( + folding_instance_three.clone(), + folding_witness_three.clone(), + ), + &mut fq_sponge, + ); + + let folded_instance_two = folding_output_two.folded_instance; + let folded_witness_two = folding_output_two.folded_witness; + + //////////////////////////////////////////////////////////////////////////// + // Testing folding exprs validity for last fold + //////////////////////////////////////////////////////////////////////////// + + let enlarge_to_domain_generic = |evaluations: &Evaluations>, + new_domain: R2D| { + assert!(evaluations.domain() == domain.d1); + evaluations + .interpolate_by_ref() + .evaluate_over_domain(new_domain) + }; + + { + println!("Testing individual expressions validity; creating evaluations"); + + let simple_eval_env: SimpleEvalEnv = { + let enlarge_to_domain = |evaluations: &Evaluations>| { + enlarge_to_domain_generic(evaluations, domain.d8) + }; + + let alpha = fq_sponge.challenge(); + let alphas = Alphas::new_sized(alpha, N_ALPHAS); + assert!( + alphas.clone().powers().len() == N_ALPHAS, + "Expected N_ALPHAS = {N_ALPHAS:?}, got {}", + alphas.clone().powers().len() + ); + + let beta = fq_sponge.challenge(); + let gamma = fq_sponge.challenge(); + let joint_combiner = fq_sponge.challenge(); + let challenges = [beta, gamma, joint_combiner]; + + SimpleEvalEnv { + ext_witness: ExtendedWitness { + witness: PlonkishWitness { + witness: (&folding_witness_three_evals) + .into_par_iter() + .map(enlarge_to_domain) + .collect(), + fixed_selectors: (&ivc_fixed_selectors_evals_d1) + .into_par_iter() + .map(enlarge_to_domain) + .collect(), + phantom: std::marker::PhantomData, + }, + extended: BTreeMap::new(), // No extended columns at this point + }, + alphas, + challenges, + error_vec: Evaluations::from_vec_and_domain(vec![], domain.d1), + u: Fp::zero(), + } + }; + + { + let target_expressions: Vec> = + folding_compat_constraints.clone(); + + for (expr_i, expr) in target_expressions.iter().enumerate() { + let eval_leaf = simple_eval_env.eval_naive_fcompat(expr); + + let evaluations_d8 = match eval_leaf { + EvalLeaf::Result(evaluations_d8) => evaluations_d8, + EvalLeaf::Col(evaluations_d8) => evaluations_d8.to_vec(), + 
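+                    // Any other evaluation shape cannot be interpolated below.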
+                    _ => panic!("eval_leaf is not Result"),
+                };
+
+                let interpolated =
+                    Evaluations::from_vec_and_domain(evaluations_d8.clone(), domain.d8)
+                        .interpolate();
+                if !interpolated.is_zero() {
+                    let (_, remainder) = interpolated
+                        .divide_by_vanishing_poly(domain.d1)
+                        .unwrap_or_else(|| panic!("Cannot divide by vanishing polynomial"));
+                    if !remainder.is_zero() {
+                        panic!(
+                            "Remainder is not zero for expression #{expr_i}: {}",
+                            expr.to_string()
+                        );
+                    }
+                }
+            }
+
+            println!("All folding_compat_constraints for APP+(nontrivial) IVC satisfy FoldingExps");
+        }
+    }
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Testing folding exprs validity with quadraticization
+    ////////////////////////////////////////////////////////////////////////////
+
+    {
+        println!("Testing joint folding expression validity /with quadraticization/; creating evaluations");
+
+        // We can evaluate on d1, and then if the interpolated
+        // polynomial is 0, the expression holds. This is fast to do,
+        // and it effectively checks if the expressions hold.
+        //
+        // However this is not enough for computing the quotient, since
+        // folding expressions have degree greater than 1 (2 or 3 here). So
+        // when this variable is set to domain.d8, all the evaluations will
+        // happen over d8, and the quotient polynomial computation becomes
+        // possible. But this is 8 times slower.
+        let evaluation_domain = domain.d1;
+
+        let enlarge_to_domain = |evaluations: &Evaluations<Fp, R2D<Fp>>| {
+            enlarge_to_domain_generic(evaluations, evaluation_domain)
+        };
+
+        let simple_eval_env: SimpleEvalEnv = {
+            let ext_witness = ExtendedWitness {
+                witness: PlonkishWitness {
+                    witness: (&folded_witness_two.extended_witness.witness.witness)
+                        .into_par_iter()
+                        .map(enlarge_to_domain)
+                        .collect(),
+                    fixed_selectors: (&ivc_fixed_selectors_evals_d1)
+                        .into_par_iter()
+                        .map(enlarge_to_domain)
+                        .collect(),
+                    phantom: std::marker::PhantomData,
+                },
+                extended: folded_witness_two
+                    .extended_witness
+                    .extended
+                    .iter()
+                    .map(|(ix, evals)| (*ix, enlarge_to_domain(evals)))
+                    .collect(),
+            };
+
+            SimpleEvalEnv {
+                ext_witness,
+                alphas: folded_instance_two
+                    .extended_instance
+                    .instance
+                    .alphas
+                    .clone(),
+                challenges: folded_instance_two.extended_instance.instance.challenges,
+                error_vec: enlarge_to_domain(&folded_witness_two.error_vec),
+                u: folded_instance_two.u,
+            }
+        };
+
+        {
+            let expr: FoldingCompatibleExpr<MainTestConfig> =
+                real_folding_compat_constraint.clone();
+
+            let eval_leaf = simple_eval_env.eval_naive_fcompat(&expr);
+
+            let evaluations_big = match eval_leaf {
+                EvalLeaf::Result(evaluations) => evaluations,
+                EvalLeaf::Col(evaluations) => evaluations.to_vec(),
+                _ => panic!("eval_leaf is not Result"),
+            };
+
+            let interpolated =
+                Evaluations::from_vec_and_domain(evaluations_big, evaluation_domain).interpolate();
+            if !interpolated.is_zero() {
+                let (_, remainder) = interpolated
+                    .divide_by_vanishing_poly(domain.d1)
+                    .unwrap_or_else(|| panic!("ERROR: Cannot divide by vanishing polynomial"));
+                if !remainder.is_zero() {
+                    panic!(
+                        "ERROR: Remainder is not zero for joint expression: {}",
+                        expr.to_string()
+                    );
+                } else {
+                    println!("Interpolated expression is divisible by vanishing poly d1");
+                }
+            } else {
+                println!("Interpolated expression is zero");
+            }
+        }
+    }
+
+    ////////////////////////////////////////////////////////////////////////////
+    // SNARKing everything
+    ////////////////////////////////////////////////////////////////////////////
+
+    // Quad columns become regular witness columns
+    let real_folding_compat_constraint_quad_merged: FoldingCompatibleExpr<MainTestConfig> = {
+        let noquad_mapper = &(|quad_index: usize| {
+            let col = kimchi_msm::columns::Column::Relation(N_COL_TOTAL + quad_index);
+            Variable {
+                col,
+                row: CurrOrNext::Curr,
+            }
+        });
+
+        real_folding_compat_constraint
+            .clone()
+            .flatten_quad_columns(noquad_mapper)
+    };
+
+    println!("Creating a proof");
+
+    let proof = ivc::prover::prove::<
+        BaseSponge,
+        ScalarSponge,
+        MainTestConfig,
+        _,
+        N_COL_TOTAL,
+        N_COL_TOTAL_QUAD,
+        N_COL_TOTAL,
+        0,
+        N_FSEL_TOTAL,
+        N_ALPHAS_QUAD,
+    >(
+        domain,
+        &srs,
+        &real_folding_compat_constraint,
+        folded_instance_two.clone(),
+        folded_witness_two.clone(),
+        &mut rng,
+    )
+    .unwrap();
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Verifying:
+    // below is everything one needs to do to verify the whole computation
+    ////////////////////////////////////////////////////////////////////////////
+
+    println!("Verifying a proof");
+
+    let fixed_selectors_verifier = folded_witness_two
+        .extended_witness
+        .witness
+        .fixed_selectors
+        .cols;
+
+    // Check that the last SNARK is correct
+    let verifies = ivc::verifier::verify::<
+        BaseSponge,
+        ScalarSponge,
+        MainTestConfig,
+        N_COL_TOTAL_QUAD,
+        N_COL_TOTAL_QUAD,
+        0,
+        N_FSEL_TOTAL,
+        0,
+    >(
+        domain,
+        &srs,
+        &real_folding_compat_constraint_quad_merged,
+        fixed_selectors_verifier,
+        &proof,
+    );
+
+    assert!(verifies, "The proof does not verify");
+
+    // Check that the last fold is correct
+
+    println!("Checking last fold was done correctly");
+
+    {
+        assert!(
+            folded_instance_two
+                == folding_scheme.fold_instance_pair(
+                    folding_output_two.relaxed_extended_left_instance,
+                    folding_output_two.relaxed_extended_right_instance,
+                    [folding_output_two.t_0, folding_output_two.t_1],
+                    &mut fq_sponge_before_last_fold,
+                ),
+            "Last fold must (natively) verify"
+        );
+    }
+
+    // We have to check that:
+    // 1. `folding_instance_three` is relaxed (E = 0, u = 1)
+    // 2. u.x = Hash(n, z0, zn, U)
+
+    // We don't yet do (2) because we don't support public input yet.
+    //
+    // And (1) we achieve automatically because
+    // `folding_instance_three` is a `PlonkishInstance` and not
+    // `RelaxedInstance`. We only have to check that its `alphas` are
+    // powers (and not arbitrary elements):
+
+    // Check that `folding_instance_three` was relaxed.
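+    // `verify_from_witness` replays the instance creation against the sponge
+    // state saved just before `folding_instance_three` was built, so it should
+    // fail if the commitments or the alpha powers were tampered with.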
+ assert_eq!( + Ok(()), + folding_instance_three.verify_from_witness(&mut fq_sponge_before_instance_three) + ); +} diff --git a/kimchi/Cargo.toml b/kimchi/Cargo.toml index bf3ab6c136..847ff162f9 100644 --- a/kimchi/Cargo.toml +++ b/kimchi/Cargo.toml @@ -25,6 +25,7 @@ num-derive.workspace = true num-integer.workspace = true num-traits.workspace = true itertools.workspace = true +log.workspace = true rand = { workspace = true, features = ["std_rng"] } rand_core.workspace = true rayon.workspace = true @@ -37,11 +38,14 @@ hex.workspace = true strum.workspace = true strum_macros.workspace = true +ocaml = { workspace = true, optional = true } +ocaml-gen = { workspace = true, optional = true } -# TODO: audit this -disjoint-set = "0.0.2" +wasm-bindgen = { workspace = true, optional = true } +internal-tracing.workspace = true +# Internal dependencies turshi.workspace = true poly-commitment.workspace = true groupmap.workspace = true @@ -49,12 +53,6 @@ mina-curves.workspace = true o1-utils.workspace = true mina-poseidon.workspace = true -ocaml = { workspace = true, optional = true } -ocaml-gen = { workspace = true, optional = true } - -wasm-bindgen = { workspace = true, optional = true } - -internal-tracing.workspace = true [dev-dependencies] proptest.workspace = true diff --git a/kimchi/README.md b/kimchi/README.md index 76a71006cc..078376193f 100644 --- a/kimchi/README.md +++ b/kimchi/README.md @@ -1,24 +1,32 @@ # Kimchi -Kimchi is based on [plonk](https://eprint.iacr.org/2019/953.pdf), a zk-SNARK protocol. +Kimchi is based on [PlonK](https://eprint.iacr.org/2019/953.pdf), a zk-SNARK +protocol. ## Benchmarks -To bench kimchi, we have two types of benchmark engines. +To bench kimchi, we have two types of benchmark engines. -[Criterion](https://bheisler.github.io/criterion.rs/) is used to benchmark time: +[Criterion](https://bheisler.github.io/criterion.rs/) is used to benchmark time. +If you have installed +[cargo-criterion](https://github.com/bheisler/cargo-criterion) and +[gnuplot](http://www.gnuplot.info), you can run the benchmarks via ```console $ cargo criterion -p kimchi --bench proof_criterion ``` -The result will appear in `target/criterion/single\ proof/report/index.html` and look like this: +The result will appear in `target/criterion/single\ proof/report/index.html` and +look like this: ![criterion kimchi](https://i.imgur.com/OGqiuHD.png) -Note that it only does 10 passes. To have more accurate statistics, remove the `.sample_size(10)` line from the [bench](benches/proof_criterion.rs). +Note that it only does 10 passes. To have more accurate statistics, remove the +`.sample_size(10)` line from the [bench](benches/proof_criterion.rs). -The other benchmark uses [iai](https://github.com/bheisler/iai) to perform precise one-shot benchmarking. This is useful in CI, for example, where typical benchmarks are affected by the load of the host running CI. +The other benchmark uses [iai](https://github.com/bheisler/iai) to perform +precise one-shot benchmarking. This is useful in CI, for example, where typical +benchmarks are affected by the load of the host running CI. ```console $ cargo bench -p kimchi --bench proof_iai @@ -52,5 +60,6 @@ To obtain a flamegraph: ``` the [binary](src/bin/flamegraph.rs) will run forever, so you have to C-c to exit and produce the `flamegraph.svg` file. -Note: lots of good advice on system performance in the [flamegraph repo](https://github.com/flamegraph-rs/flamegraph#systems-performance-work-guided-by-flamegraphs). 
+Note: lots of good advice on system performance in the [flamegraph +repo](https://github.com/flamegraph-rs/flamegraph#systems-performance-work-guided-by-flamegraphs). diff --git a/kimchi/benches/proof_criterion.rs b/kimchi/benches/proof_criterion.rs index 8a1923e48f..80d68564b1 100644 --- a/kimchi/benches/proof_criterion.rs +++ b/kimchi/benches/proof_criterion.rs @@ -5,26 +5,38 @@ pub fn bench_proof_creation(c: &mut Criterion) { let mut group = c.benchmark_group("Proof creation"); group.sample_size(10).sampling_mode(SamplingMode::Flat); // for slow benchmarks - let ctx = BenchmarkCtx::new(10); - group.bench_function( - format!("proof creation (SRS size 2^{})", ctx.srs_size()), - |b| b.iter(|| black_box(ctx.create_proof())), - ); + for size in [10, 14] { + let ctx = BenchmarkCtx::new(size); - let ctx = BenchmarkCtx::new(14); - group.bench_function( - format!("proof creation (SRS size 2^{})", ctx.srs_size()), - |b| b.iter(|| black_box(ctx.create_proof())), - ); - - let proof_and_public = ctx.create_proof(); + group.bench_function( + format!( + "proof creation (SRS size 2^{{{}}}, {} gates)", + ctx.srs_size(), + ctx.num_gates + ), + |b| b.iter(|| black_box(ctx.create_proof())), + ); + } +} +pub fn bench_proof_verification(c: &mut Criterion) { + let mut group = c.benchmark_group("Proof verification"); group.sample_size(100).sampling_mode(SamplingMode::Auto); - group.bench_function( - format!("proof verification (SRS size 2^{})", ctx.srs_size()), - |b| b.iter(|| ctx.batch_verification(black_box(&vec![proof_and_public.clone()]))), - ); + + for size in [10, 14] { + let ctx = BenchmarkCtx::new(size); + let proof_and_public = ctx.create_proof(); + + group.bench_function( + format!( + "proof verification (SRS size 2^{{{}}}, {} gates)", + ctx.srs_size(), + ctx.num_gates + ), + |b| b.iter(|| ctx.batch_verification(black_box(&vec![proof_and_public.clone()]))), + ); + } } -criterion_group!(benches, bench_proof_creation); +criterion_group!(benches, bench_proof_creation, bench_proof_verification); criterion_main!(benches); diff --git a/kimchi/snarky-deriver/Cargo.toml b/kimchi/snarky-deriver/Cargo.toml new file mode 100644 index 0000000000..8e981c9b8a --- /dev/null +++ b/kimchi/snarky-deriver/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "snarky-deriver" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" +description = "The inner library of kimchi, for derive macros" +repository = "https://github.com/o1-labs/proof-systems" + +[lib] +proc-macro = true + +[dependencies] +convert_case.workspace = true +itertools.workspace = true +proc-macro2.workspace = true +quote.workspace = true +syn.workspace = true diff --git a/kimchi/snarky-deriver/README.md b/kimchi/snarky-deriver/README.md new file mode 100644 index 0000000000..3903df7379 --- /dev/null +++ b/kimchi/snarky-deriver/README.md @@ -0,0 +1,4 @@ +# Snarky deriver + +This crate contains procedural macros used by snarky. +In Rust, procedural macros must be defined in a separate crate, this is why this code is split from kimchi. diff --git a/kimchi/snarky-deriver/src/lib.rs b/kimchi/snarky-deriver/src/lib.rs new file mode 100644 index 0000000000..e31dc67b0f --- /dev/null +++ b/kimchi/snarky-deriver/src/lib.rs @@ -0,0 +1,573 @@ +//! **This crate is not meant to be imported directly by users**. +//! You should import [kimchi](https://crates.io/crates/kimchi) instead. +//! +//! snarky-deriver adds a number of derives to make snarky easier to use. +//! 
Refer to the [snarky](https://o1-labs.github.io/proof-systems/rustdoc/kimchi/snarky/index.html) documentation. + +extern crate proc_macro; +use std::collections::HashSet; + +use itertools::Itertools; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::{format_ident, quote}; +use syn::{ + punctuated::Punctuated, Fields, Lit, MetaNameValue, PredicateType, TraitBound, + TraitBoundModifier, Type, TypeParamBound, WherePredicate, +}; + +/// The [SnarkyType] derive macro. +/// It generates implementations of \[`kimchi::snarky::SnarkyType`\], +/// as long as your structure's fields implement that type as well. +/// It works very similarly to \[`serde`\]. +/// +/// For example: +/// +/// ```ignore +/// #[derive(kimchi::SnarkyType)] +/// struct MyType where F: PrimeField { +/// // ... +/// } +/// ``` +/// +/// You can specify your snarky type implementation to only apply to a single +/// field (e.g. `Fp`), or to apply to a field that has a different name than +/// `F`: +/// +/// ```text +/// #[derive(kimchi::SnarkyType)] +/// #[snarky(field = "G::ScalarField")] +/// struct MyType where G: KimchiCurve { +/// ``` +/// +/// You can skip a field in the serializer: +/// +/// ```text +/// #[derive(kimchi::SnarkyType)] +/// struct MyType where F: PrimeField { +/// #[snarky(skip)] +/// field_to_skip: NotASnarkyType, +/// ``` +/// +/// Finally, you can specify a custom check function, +/// as well as a custom auxiliary function and type: +/// +/// ```text +/// #[derive(kimchi::SnarkyType)] +/// #[snarky(check_fn = "my_check_fn")] +/// #[snarky(auxiliary_fn = "my_auxiliary_fn")] +/// #[snarky(auxiliary_type = "MyAuxiliaryType")] +/// struct MyType where F: PrimeField { +/// ``` +/// +/// By default, a tuple will be use to represent the out-of-circuit type. +/// You can specify it yourself by using the `value` helper attribute: +/// +/// ```text +/// #[derive(kimchi::SnarkyType)] +/// #[snarky(value = "MyOutOfCircuitType")] +/// struct MyType where F: PrimeField { +/// ``` +/// +/// and implement the \[`CircuitAndValue`\] trait on your \[`SnarkyType`\]. +/// +#[proc_macro_derive(SnarkyType, attributes(snarky))] +pub fn derive_snarky_type(item: TokenStream) -> TokenStream { + // The strategy is the following: + // + // - we want to produce an implementation for the given struct + // - with the exception that we want to enforce the SnarkyType bound on every generic type + + // parse struct + let item_struct: syn::ItemStruct = + syn::parse(item).expect("only structs are supported with `SnarkyType` at the moment"); + + // parse macro attributes into our cfg struct + // warning: parsing macro attributes is hairy + #[derive(Default)] + struct HelperAttributes { + field: Option, + check_fn: Option, + auxiliary_fn: Option, + auxiliary_type: Option, + value: Option, + } + let mut helper_attributes = HelperAttributes::default(); + + let malformed_snarky_helper = + "snarky helper malformed. 
It should look like `#[snarky(key = value)]`"; + for attr in &item_struct.attrs { + if let Ok(syn::Meta::List(meta)) = attr.parse_meta() { + // we only care about `#[snarky(...)]` + if !meta.path.is_ident("snarky") { + continue; + } + + for meta_inner in meta.nested { + match meta_inner { + syn::NestedMeta::Meta(syn::Meta::NameValue(MetaNameValue { + path, + eq_token: _, + lit, + })) => { + let value = match lit { + Lit::Str(lit) => lit.value(), + _ => panic!("{malformed_snarky_helper}"), + }; + if path.is_ident("field") { + helper_attributes.field = Some(value); + } else if path.is_ident("check_fn") { + helper_attributes.check_fn = Some(value); + panic!("`check_fn` is not supported yet. Post an issue with this error to request it (https://github.com/o1-labs/proof-systems/issues)"); + } else if path.is_ident("auxiliary_fn") { + helper_attributes.auxiliary_fn = Some(value); + panic!("`auxiliary_fn` is not supported yet. Post an issue with this error to request it (https://github.com/o1-labs/proof-systems/issues)"); + } else if path.is_ident("auxiliary_type") { + helper_attributes.auxiliary_type = Some(value); + panic!("`auxiliary_type` is not supported yet. Post an issue with this error to request it (https://github.com/o1-labs/proof-systems/issues)"); + } else if path.is_ident("value") { + helper_attributes.value = Some(value); + } else { + panic!("{malformed_snarky_helper}"); + } + } + x => panic!("{x:?} {malformed_snarky_helper}"), + }; + } + } + } + + // enforce that custom auxiliary type and custom auxiliary fn come hand in hand + match ( + helper_attributes.auxiliary_type.is_some(), + helper_attributes.auxiliary_fn.is_some(), + ) { + (true, false) | (false, true) => { + panic!("`auxiliary_type` and `auxiliary_fn` have to be specified together") + } + _ => (), + }; + + // enforce that we have a type parameter `F: PrimeField` + // this is needed as we use the same type parameter in the implementation of SnarkyType + // (i.e. 
`impl SnarkyType for MyType`) + // + // Note: if you use `#[snarky(field = "MyField")]`, then we don't enforce this + + let has_f_primefield = |generic: &WherePredicate| { + if let WherePredicate::Type(t) = generic { + // check that the type parameter is `F` + if !matches!(&t.bounded_ty, Type::Path(p) if p.path.is_ident("F")) { + return false; + } + // check that it is bound by `PrimeField` + t.bounds.iter().any(|bound| { + if let TypeParamBound::Trait(trait_bound) = bound { + trait_bound.path.segments.last().unwrap().ident == "PrimeField" + } else { + false + } + }) + } else { + false + } + }; + + let missing_field_err = + "SnarkyType requires the type to have a type parameter `F: PrimeField, or to specify the field used via `#[snarky(field = \"...\")]`"; + + let has_f_primefield = item_struct + .generics + .where_clause + .as_ref() + .map(|w| w.predicates.iter().any(has_f_primefield)) + .unwrap_or(false); + + let impl_field = match (has_f_primefield, helper_attributes.field.as_ref()) { + (true, None) => "F", + (false, Some(field)) => field, + (false, None) => panic!("{missing_field_err}"), + (true, Some(_)) => panic!("you cannot specify a field via `#[snarky(field = \"...\")] if the type already has a `F: PrimeField`"), + }; + let impl_field_path: syn::Path = syn::parse_str(impl_field).unwrap(); + + // collect field names and types + let fields = &item_struct.fields; + let (field_names, field_types) = match fields { + Fields::Named(fields) => fields + .named + .iter() + .map(|f| (f.ident.as_ref().unwrap(), &f.ty)) + .unzip(), + Fields::Unnamed(_fields) => panic!("tuple structs are not supported yet"), + Fields::Unit => (vec![], vec![]), + }; + + // enforce that we have at least 1 field + if field_names.is_empty() { + panic!("to use `#[derive(SnarkyType)]` your struct must at least have one field"); + } + + // this deriver is used by both external users, and the kimchi crate itself, + // so we need to change the path to kimchi depending on the context + let lib_path = if std::env::var("CARGO_PKG_NAME").unwrap() == "kimchi" { + "crate" + } else { + "::kimchi" + }; + let snarky_type_path_str = format!("{lib_path}::SnarkyType<{impl_field}>"); + let snarky_type_path: syn::Path = syn::parse_str(&snarky_type_path_str).unwrap(); + + // + // We define every block of the SnarkyType implementation individually. + // + + let auxiliary = if field_names.len() > 1 { + quote! { + type Auxiliary = ( + #( <#field_types as #snarky_type_path>::Auxiliary ),* + ); + } + } else { + quote! { + type Auxiliary = ( + #( <#field_types as #snarky_type_path>::Auxiliary )* + ); + } + }; + + let out_of_circuit = if let Some(value) = &helper_attributes.value { + let typ: syn::Path = syn::parse_str(value).expect("could not parse your value"); + quote! { + type OutOfCircuit = #typ; + } + } else if field_names.len() > 1 { + quote! { + type OutOfCircuit = ( + #( <#field_types as #snarky_type_path>::OutOfCircuit ),* + ); + } + } else { + quote! { + type OutOfCircuit = ( + #( <#field_types as #snarky_type_path>::OutOfCircuit )* + ); + } + }; + + let size_in_field_elements = quote! 
{ + const SIZE_IN_FIELD_ELEMENTS: usize = #( <#field_types as #snarky_type_path>::SIZE_IN_FIELD_ELEMENTS )+*; + }; + + let to_cvars = if field_names.len() > 1 { + // `let (cvars_i, aux_i) = field_i.to_cvars();` + let mut to_cvars_calls = vec![]; + for (idx, field) in field_names.iter().enumerate() { + let idx = syn::Index::from(idx); + let cvar_i = format_ident!("cvars_{}", idx); + let aux_i = format_ident!("aux_{}", idx); + to_cvars_calls.push(quote! { + let (#cvar_i, #aux_i) = self.#field.to_cvars(); + }); + } + + // `aux_i, ...` + let aux = (0..field_names.len()) + .map(|idx| format_ident!("aux_{}", syn::Index::from(idx))) + .collect_vec(); + + quote! { + fn to_cvars(&self) -> (Vec>, Self::Auxiliary) { + let mut cvars = Vec::with_capacity(Self::SIZE_IN_FIELD_ELEMENTS); + + #( #to_cvars_calls );* + + let aux = ( + #( #aux ),* + ); + + (cvars, aux) + } + } + } else { + quote! { + fn to_cvars(&self) -> (Vec>, Self::Auxiliary) { + #( self.#field_names.to_cvars() )* + } + } + }; + + let from_cvars_unsafe = if field_names.len() > 1 { + // ``` + // let end = offset + T_i::SIZE_IN_FIELD_ELEMENTS; + // let cvars_i = &cvars[offset..end]; + // offset = end; + // let aux_i = aux.i; + // ``` + let mut cvars_and_aux = vec![]; + for (idx, field_ty) in field_types.iter().enumerate() { + let idx = syn::Index::from(idx); + let cvar_i = format_ident!("cvars_{}", idx); + let aux_i = format_ident!("aux_{}", idx); + + cvars_and_aux.push(quote! { + let end = offset + <#field_ty as #snarky_type_path>::SIZE_IN_FIELD_ELEMENTS; + let #cvar_i = &cvars[offset..end]; + let offset = end; + let #aux_i = aux.#idx; + }); + } + + // ` T_i::from_cvars_unsafe(cvars_i, aux_i)` + let mut from_cvars_unsafe = Vec::with_capacity(field_types.len()); + for (idx, field_ty) in field_types.iter().enumerate() { + let idx = syn::Index::from(idx); + let cvar_i = format_ident!("cvars_{}", idx); + let aux_i = format_ident!("aux_{}", idx); + + from_cvars_unsafe.push(quote! { + <#field_ty as #snarky_type_path>::from_cvars_unsafe(#cvar_i.to_vec(), #aux_i) + }); + } + + quote! { + fn from_cvars_unsafe(cvars: Vec>, aux: Self::Auxiliary) -> Self { + assert_eq!(cvars.len(), Self::SIZE_IN_FIELD_ELEMENTS); + + let mut offset = 0; + #( #cvars_and_aux );* + + Self { + #( #field_names: #from_cvars_unsafe ),* + } + } + } + } else { + quote! { + fn from_cvars_unsafe(cvars: Vec>, aux: Self::Auxiliary) -> Self { + assert_eq!(cvars.len(), Self::SIZE_IN_FIELD_ELEMENTS); + + Self { + #( #field_names: <#field_types as #snarky_type_path>::from_cvars_unsafe(cvars, aux) ),* + } + } + } + }; + + let check = { + quote! { + fn check(&self, cs: &mut RunState<#impl_field_path>) -> SnarkyResult<()> { + #( self.#field_names.check(cs)? );* + Ok(()) + } + } + }; + + // constraint_system_auxiliary + let constraint_system_auxiliary = if field_names.len() > 1 { + quote! { + fn constraint_system_auxiliary() -> Self::Auxiliary { + ( + #( <#field_types as #snarky_type_path>::constraint_system_auxiliary() ),* + ) + } + } + } else { + quote! { + fn constraint_system_auxiliary() -> Self::Auxiliary { + ( + #( <#field_types as #snarky_type_path>::constraint_system_auxiliary() )* + ) + } + } + }; + + // value_to_field_elements + let value_to_field_elements = if helper_attributes.value.is_some() { + let value_trait = format!("{lib_path}::CircuitAndValue<{impl_field}>"); + let value_trait: syn::Path = syn::parse_str(&value_trait).unwrap(); + quote! 
{ + fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec<#impl_field_path>, Self::Auxiliary) { + ::from_value(value) + } + } + } else if field_names.len() > 1 { + // `let (fields_i, aux_i) = T_i::value_to_field_elements(&self.i);` + let mut value_to_field_elements_calls = Vec::with_capacity(field_types.len()); + for (idx, field_ty) in field_types.iter().enumerate() { + let idx = syn::Index::from(idx); + let fields_name = format_ident!("fields_{}", idx); + let aux_i = format_ident!("aux_{}", idx); + + value_to_field_elements_calls.push(quote! { + let (#fields_name, #aux_i) = <#field_ty as #snarky_type_path>::value_to_field_elements(&value.#idx); + }); + } + + // `[fields_i, ...].concat()` + let fields_i = (0..field_types.len()) + .map(|idx| format_ident!("fields_{}", syn::Index::from(idx))) + .collect_vec(); + + // `aux_i, ...` + let aux_i = (0..field_types.len()) + .map(|idx| format_ident!("aux_{}", syn::Index::from(idx))) + .collect_vec(); + + quote! { + fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec<#impl_field_path>, Self::Auxiliary) { + #( #value_to_field_elements_calls );* + + let fields = [ #( #fields_i ),* ].concat(); + + let aux = ( + #( #aux_i ),* + ); + + (fields, aux) + } + } + } else { + quote! { + fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec<#impl_field_path>, Self::Auxiliary) { + + #( <#field_types as #snarky_type_path>::value_to_field_elements(value) )* + } + } + }; + + // value_of_field_elements + let value_of_field_elements = if helper_attributes.value.is_some() { + let value_trait = format!("{lib_path}::CircuitAndValue<{impl_field}>"); + let value_trait: syn::Path = syn::parse_str(&value_trait).unwrap(); + + quote! { + fn value_of_field_elements(fields: Vec<#impl_field_path>, aux: Self::Auxiliary) -> Self::OutOfCircuit { + assert_eq!(fields.len(), Self::SIZE_IN_FIELD_ELEMENTS); + ::to_value(fields, aux) + } + } + } else if field_types.len() > 1 { + // ``` + // let end = offset + T_i::SIZE_IN_FIELD_ELEMENTS; + // let cvars_i = &cvars[offset..end]; + // offset = end; + // let aux_i = aux.i; + // ``` + let mut fields_and_aux = vec![]; + for (idx, field_ty) in field_types.iter().enumerate() { + let idx = syn::Index::from(idx); + let fields_i = format_ident!("fields_{}", idx); + let aux_i = format_ident!("aux_{}", idx); + + fields_and_aux.push(quote! { + let end = offset + <#field_ty as #snarky_type_path>::SIZE_IN_FIELD_ELEMENTS; + let #fields_i = &fields[offset..end]; + let offset = end; + let #aux_i = aux.#idx; + }); + } + + // ` T_i::value_of_field_elements(fields_i, aux_i)` + let mut value_of_field_elements = Vec::with_capacity(field_types.len()); + for (idx, field_ty) in field_types.iter().enumerate() { + let idx = syn::Index::from(idx); + let fields_i = format_ident!("fields_{}", idx); + let aux_i = format_ident!("aux_{}", idx); + + value_of_field_elements.push(quote! { + <#field_ty as #snarky_type_path>::value_of_field_elements(#fields_i.to_vec(), #aux_i) + }); + } + + quote! { + fn value_of_field_elements(fields: Vec<#impl_field_path>, aux: Self::Auxiliary) -> Self::OutOfCircuit { + assert_eq!(fields.len(), Self::SIZE_IN_FIELD_ELEMENTS); + + let mut offset = 0; + #( #fields_and_aux );* + + ( + #( #value_of_field_elements ),* + ) + } + } + } else { + quote! 
{ + fn value_of_field_elements(fields: Vec<#impl_field_path>, aux: Self::Auxiliary) -> Self::OutOfCircuit { + assert_eq!(fields.len(), Self::SIZE_IN_FIELD_ELEMENTS); + + #( <#field_types as #snarky_type_path>::value_of_field_elements(fields, aux) )* + } + } + }; + + // + // Implementation + // + + let (impl_generics, ty_generics, _where_clause) = item_struct.generics.split_for_impl(); + + // add SnarkyType bounds to all the field types + let mut extended_generics = item_struct.generics.clone(); + extended_generics.make_where_clause(); + let mut extended_where_clause = extended_generics.where_clause.unwrap(); + + let path: syn::Path = + syn::parse_str(&snarky_type_path_str).expect("snarky_type_path_str is not a valid?"); + let impl_snarky_type = TraitBound { + paren_token: None, + modifier: TraitBoundModifier::None, + lifetimes: None, + path, + }; + + let mut types_to_bind = HashSet::new(); + for field_type in field_types { + match field_type { + Type::Path(_path) => types_to_bind.insert(field_type.clone()), + Type::Array(array) => types_to_bind.insert(*array.elem.clone()), + x => panic!("non-path as bound types are not supported yet. To request this feature, please copy the entire error to a Github issue (https://github.com/o1-labs/proof-systems): {x:?}"), + }; + } + + let mut bounds = Punctuated::::new(); + bounds.push(TypeParamBound::Trait(impl_snarky_type.clone())); + for bounded_ty in types_to_bind { + extended_where_clause + .predicates + .push(WherePredicate::Type(PredicateType { + lifetimes: None, + bounded_ty, + colon_token: syn::token::Colon { + spans: [Span::call_site()], + }, + bounds: bounds.clone(), + })); + } + + // generate implementations for OCamlDesc and OCamlBinding + let name = &item_struct.ident; + let gen = quote! { + impl #impl_generics #snarky_type_path for #name #ty_generics #extended_where_clause { + #auxiliary + + #out_of_circuit + + #size_in_field_elements + + #to_cvars + + #from_cvars_unsafe + + #check + + #constraint_system_auxiliary + + #value_to_field_elements + + #value_of_field_elements + } + }; + gen.into() +} diff --git a/kimchi/src/bench.rs b/kimchi/src/bench.rs index 8f1139ab65..ccbf1ea5bf 100644 --- a/kimchi/src/bench.rs +++ b/kimchi/src/bench.rs @@ -7,7 +7,7 @@ use mina_poseidon::{ sponge::{DefaultFqSponge, DefaultFrSponge}, }; use o1_utils::math; -use poly_commitment::{commitment::CommitmentCurve, evaluation_proof::OpeningProof}; +use poly_commitment::{commitment::CommitmentCurve, ipa::OpeningProof, SRS as _}; use crate::{ circuits::{ @@ -26,7 +26,7 @@ type BaseSponge = DefaultFqSponge; type ScalarSponge = DefaultFrSponge; pub struct BenchmarkCtx { - num_gates: usize, + pub num_gates: usize, group_map: BWParameters, index: ProverIndex>, verifier_index: VerifierIndex>, @@ -34,10 +34,11 @@ pub struct BenchmarkCtx { impl BenchmarkCtx { pub fn srs_size(&self) -> usize { - math::ceil_log2(self.index.srs.max_degree()) + math::ceil_log2(self.index.srs.max_poly_size()) } - /// This will create a context that allows for benchmarks of `num_gates` gates (multiplication gates). + /// This will create a context that allows for benchmarks of `num_gates` + /// gates (multiplication gates). pub fn new(srs_size_log2: u32) -> Self { // there's some overhead that we need to remove (e.g. 
zk rows) @@ -115,30 +116,3 @@ impl BenchmarkCtx { .unwrap(); } } - -#[cfg(test)] -mod tests { - use std::time::Instant; - - use super::*; - - #[test] - fn test_bench() { - // context created in 21.2235 ms - let start = Instant::now(); - let srs_size = 4; - let ctx = BenchmarkCtx::new(srs_size); - println!("testing bench code for SRS of size {srs_size}"); - println!("context created in {}s", start.elapsed().as_secs()); - - // proof created in 7.1227 ms - let start = Instant::now(); - let (proof, public_input) = ctx.create_proof(); - println!("proof created in {}s", start.elapsed().as_secs()); - - // proof verified in 1.710 ms - let start = Instant::now(); - ctx.batch_verification(&vec![(proof, public_input)]); - println!("proof verified in {}", start.elapsed().as_secs()); - } -} diff --git a/kimchi/src/circuits/argument.rs b/kimchi/src/circuits/argument.rs index 0b2519d5c2..92daa0f4df 100644 --- a/kimchi/src/circuits/argument.rs +++ b/kimchi/src/circuits/argument.rs @@ -3,6 +3,7 @@ //! Both the permutation and the plookup arguments fit this type. //! Gates can be seen as filtered arguments, //! which apply only in some points (rows) of the domain. +//! For more info, read book/src/kimchi/arguments.md use std::marker::PhantomData; @@ -10,8 +11,10 @@ use crate::{alphas::Alphas, circuits::expr::prologue::*}; use ark_ff::{Field, PrimeField}; use serde::{Deserialize, Serialize}; +//TODO use generic challenge use super::{ - expr::{constraints::ExprOps, Cache, ConstantExpr, Constants}, + berkeley_columns::{BerkeleyChallengeTerm, BerkeleyChallenges}, + expr::{constraints::ExprOps, Cache, ConstantExpr, ConstantTerm, Constants}, gate::{CurrOrNext, GateType}, polynomial::COLUMNS, }; @@ -32,7 +35,7 @@ pub enum ArgumentType { /// The argument environment is used to specify how the argument's constraints are /// represented when they are built. If the environment is created without ArgumentData -/// and with F = `Expr`, then the constraints are built as Expr expressions (e.g. for +/// and with `F = Expr`, then the constraints are built as Expr expressions (e.g. for /// use with the prover/verifier). 
On the other hand, if the environment is /// created with ArgumentData and F = Field or F = PrimeField, then the constraints /// are built as expressions of real field elements and can be evaluated directly on @@ -52,15 +55,21 @@ impl Default for ArgumentEnv { } } -impl> ArgumentEnv { +impl> ArgumentEnv { /// Initialize the environment for creating constraints of real field elements that can be /// evaluated directly over the witness without the prover/verifier - pub fn create(witness: ArgumentWitness, coeffs: Vec, constants: Constants) -> Self { + pub fn create( + witness: ArgumentWitness, + coeffs: Vec, + constants: Constants, + challenges: BerkeleyChallenges, + ) -> Self { ArgumentEnv { data: Some(ArgumentData { witness, coeffs, constants, + challenges, }), phantom_data: PhantomData, } @@ -81,24 +90,57 @@ impl> ArgumentEnv { T::witness(Next, col, self.data.as_ref()) } + /// Witness cells in current row in an interval [from, to) + pub fn witness_curr_chunk(&self, from: usize, to: usize) -> Vec { + let mut chunk = Vec::with_capacity(to - from); + for i in from..to { + chunk.push(self.witness_curr(i)); + } + chunk + } + + /// Witness cells in next row in an interval [from, to) + pub fn witness_next_chunk(&self, from: usize, to: usize) -> Vec { + let mut chunk = Vec::with_capacity(to - from); + for i in from..to { + chunk.push(self.witness_next(i)); + } + chunk + } + /// Coefficient value at index idx pub fn coeff(&self, idx: usize) -> T { T::coeff(idx, self.data.as_ref()) } + /// Chunk of consecutive coefficients in an interval [from, to) + pub fn coeff_chunk(&self, from: usize, to: usize) -> Vec { + let mut chunk = Vec::with_capacity(to - from); + for i in from..to { + chunk.push(self.coeff(i)); + } + chunk + } + /// Constant value (see [ConstantExpr] for supported constants) - pub fn constant(&self, expr: ConstantExpr) -> T { + pub fn constant(&self, expr: ConstantExpr) -> T { T::constant(expr, self.data.as_ref()) } /// Helper to access endomorphism coefficient constant pub fn endo_coefficient(&self) -> T { - T::constant(ConstantExpr::::EndoCoefficient, self.data.as_ref()) + T::constant( + ConstantExpr::from(ConstantTerm::EndoCoefficient), + self.data.as_ref(), + ) } /// Helper to access maximum distance separable matrix constant at row, col pub fn mds(&self, row: usize, col: usize) -> T { - T::constant(ConstantExpr::::Mds { row, col }, self.data.as_ref()) + T::constant( + ConstantExpr::from(ConstantTerm::Mds { row, col }), + self.data.as_ref(), + ) } } @@ -110,6 +152,7 @@ pub struct ArgumentData { pub coeffs: Vec, /// Constants pub constants: Constants, + pub challenges: BerkeleyChallenges, } /// Witness data for a argument @@ -142,7 +185,10 @@ pub trait Argument { const CONSTRAINTS: u32; /// Constraints for this argument - fn constraint_checks>(env: &ArgumentEnv, cache: &mut Cache) -> Vec; + fn constraint_checks>( + env: &ArgumentEnv, + cache: &mut Cache, + ) -> Vec; /// Returns the set of constraints required to prove this argument. fn constraints(cache: &mut Cache) -> Vec> { diff --git a/kimchi/src/circuits/berkeley_columns.rs b/kimchi/src/circuits/berkeley_columns.rs new file mode 100644 index 0000000000..2bf4fc87fb --- /dev/null +++ b/kimchi/src/circuits/berkeley_columns.rs @@ -0,0 +1,370 @@ +//! This module defines the particular form of the expressions used in the Mina +//! Berkeley hardfork. You can find more information in [this blog +//! article](https://www.o1labs.org/blog/reintroducing-kimchi). +//! 
This module is also a good starting point if you want to implement your own +//! variant of Kimchi using the expression framework. +//! +//! The module uses the generic expression framework defined in the +//! [crate::circuits::expr] module. +//! The expressions define the polynomials that can be used to describe the +//! constraints. +//! It starts by defining the different challenges used by the PLONK IOP in +//! [BerkeleyChallengeTerm] and [BerkeleyChallenges]. +//! It then defines the [Column] type which represents the different variables +//! the polynomials are defined over. +//! +//! Two "environments" are then defined: one for the lookup argument +//! [LookupEnvironment], and one for the main argument [Environment], which +//! contains the former. +//! The trait [ColumnEnvironment] is then defined to provide the necessary +//! primitives used to evaluate the quotient polynomial. + +use crate::{ + circuits::{ + domains::{Domain, EvaluationDomains}, + expr::{ + CacheId, ColumnEnvironment, ColumnEvaluations, ConstantExpr, ConstantTerm, Constants, + Expr, ExprError, FormattedOutput, + }, + gate::{CurrOrNext, GateType}, + lookup::{index::LookupSelectors, lookups::LookupPattern}, + wires::COLUMNS, + }, + proof::{PointEvaluations, ProofEvaluations}, +}; +use ark_ff::FftField; +use ark_poly::{Evaluations, Radix2EvaluationDomain as D}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// The challenge terms used in Berkeley. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum BerkeleyChallengeTerm { + /// Used to combine constraints + Alpha, + /// The first challenge used in the permutation argument + Beta, + /// The second challenge used in the permutation argument + Gamma, + /// A challenge used to combine the columns of a lookup table + JointCombiner, +} + +impl std::fmt::Display for BerkeleyChallengeTerm { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use BerkeleyChallengeTerm::*; + let str = match self { + Alpha => "alpha".to_string(), + Beta => "beta".to_string(), + Gamma => "gamma".to_string(), + JointCombiner => "joint_combiner".to_string(), + }; + write!(f, "{}", str) + } +} + +impl<'a> super::expr::AlphaChallengeTerm<'a> for BerkeleyChallengeTerm { + const ALPHA: Self = Self::Alpha; +} + +pub struct BerkeleyChallenges { + /// The challenge α from the PLONK IOP. + pub alpha: F, + /// The challenge β from the PLONK IOP. + pub beta: F, + /// The challenge γ from the PLONK IOP. + pub gamma: F, + /// The challenge joint_combiner which is used to combine joint lookup + /// tables. + pub joint_combiner: F, +} + +impl std::ops::Index for BerkeleyChallenges { + type Output = F; + + fn index(&self, challenge_term: BerkeleyChallengeTerm) -> &Self::Output { + match challenge_term { + BerkeleyChallengeTerm::Alpha => &self.alpha, + BerkeleyChallengeTerm::Beta => &self.beta, + BerkeleyChallengeTerm::Gamma => &self.gamma, + BerkeleyChallengeTerm::JointCombiner => &self.joint_combiner, + } + } +} + +/// A type representing the variables involved in the constraints of the +/// Berkeley hardfork. +/// +/// In Berkeley, the constraints are defined over the following variables: +/// - The [COLUMNS] witness columns. +/// - The permutation polynomial, Z. +/// - The public coefficients, `Coefficients`, which can be used for public +/// values. For instance, they are used for the Poseidon round constants. +/// - ...
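The `Index` implementation above is what lets generic expression code look a challenge up by term. A self-contained sketch of the same pattern, with simplified copies of the two types and `u64` standing in for the field:

```rust
use std::ops::Index;

// Simplified copies of the challenge types from this hunk, for illustration only.
enum ChallengeTerm { Alpha, Beta, Gamma, JointCombiner }

struct Challenges<F> { alpha: F, beta: F, gamma: F, joint_combiner: F }

impl<F> Index<ChallengeTerm> for Challenges<F> {
    type Output = F;
    fn index(&self, t: ChallengeTerm) -> &F {
        match t {
            ChallengeTerm::Alpha => &self.alpha,
            ChallengeTerm::Beta => &self.beta,
            ChallengeTerm::Gamma => &self.gamma,
            ChallengeTerm::JointCombiner => &self.joint_combiner,
        }
    }
}

fn main() {
    let chals = Challenges { alpha: 2u64, beta: 3, gamma: 5, joint_combiner: 7 };
    // Evaluators can stay generic over the challenge type and still resolve
    // a concrete value by term, e.g. when hitting a `Challenge` token.
    assert_eq!(chals[ChallengeTerm::Beta], 3);
    assert_eq!(chals[ChallengeTerm::Alpha] * chals[ChallengeTerm::Gamma], 10);
}
```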
+#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Column { + Witness(usize), + Z, + LookupSorted(usize), + LookupAggreg, + LookupTable, + LookupKindIndex(LookupPattern), + LookupRuntimeSelector, + LookupRuntimeTable, + Index(GateType), + Coefficient(usize), + Permutation(usize), +} + +impl FormattedOutput for Column { + fn is_alpha(&self) -> bool { + // FIXME. Unused at the moment + unimplemented!() + } + + fn ocaml(&self, _cache: &mut HashMap) -> String { + // FIXME. Unused at the moment + unimplemented!() + } + + fn latex(&self, _cache: &mut HashMap) -> String { + match self { + Column::Witness(i) => format!("w_{{{i}}}"), + Column::Z => "Z".to_string(), + Column::LookupSorted(i) => format!("s_{{{i}}}"), + Column::LookupAggreg => "a".to_string(), + Column::LookupTable => "t".to_string(), + Column::LookupKindIndex(i) => format!("k_{{{i:?}}}"), + Column::LookupRuntimeSelector => "rts".to_string(), + Column::LookupRuntimeTable => "rt".to_string(), + Column::Index(gate) => { + format!("{gate:?}") + } + Column::Coefficient(i) => format!("c_{{{i}}}"), + Column::Permutation(i) => format!("sigma_{{{i}}}"), + } + } + + fn text(&self, _cache: &mut HashMap) -> String { + match self { + Column::Witness(i) => format!("w[{i}]"), + Column::Z => "Z".to_string(), + Column::LookupSorted(i) => format!("s[{i}]"), + Column::LookupAggreg => "a".to_string(), + Column::LookupTable => "t".to_string(), + Column::LookupKindIndex(i) => format!("k[{i:?}]"), + Column::LookupRuntimeSelector => "rts".to_string(), + Column::LookupRuntimeTable => "rt".to_string(), + Column::Index(gate) => { + format!("{gate:?}") + } + Column::Coefficient(i) => format!("c[{i}]"), + Column::Permutation(i) => format!("sigma_[{i}]"), + } + } +} + +impl ColumnEvaluations for ProofEvaluations> { + type Column = Column; + fn evaluate(&self, col: Self::Column) -> Result, ExprError> { + use Column::*; + match col { + Witness(i) => Ok(self.w[i]), + Z => Ok(self.z), + LookupSorted(i) => self.lookup_sorted[i].ok_or(ExprError::MissingIndexEvaluation(col)), + LookupAggreg => self + .lookup_aggregation + .ok_or(ExprError::MissingIndexEvaluation(col)), + LookupTable => self + .lookup_table + .ok_or(ExprError::MissingIndexEvaluation(col)), + LookupRuntimeTable => self + .runtime_lookup_table + .ok_or(ExprError::MissingIndexEvaluation(col)), + Index(GateType::Poseidon) => Ok(self.poseidon_selector), + Index(GateType::Generic) => Ok(self.generic_selector), + Index(GateType::CompleteAdd) => Ok(self.complete_add_selector), + Index(GateType::VarBaseMul) => Ok(self.mul_selector), + Index(GateType::EndoMul) => Ok(self.emul_selector), + Index(GateType::EndoMulScalar) => Ok(self.endomul_scalar_selector), + Index(GateType::RangeCheck0) => self + .range_check0_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + Index(GateType::RangeCheck1) => self + .range_check1_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + Index(GateType::ForeignFieldAdd) => self + .foreign_field_add_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + Index(GateType::ForeignFieldMul) => self + .foreign_field_mul_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + Index(GateType::Xor16) => self + .xor_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + Index(GateType::Rot64) => self + .rot_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + Permutation(i) => Ok(self.s[i]), + Coefficient(i) => Ok(self.coefficients[i]), + LookupKindIndex(LookupPattern::Xor) => self + .xor_lookup_selector + 
.ok_or(ExprError::MissingIndexEvaluation(col)), + LookupKindIndex(LookupPattern::Lookup) => self + .lookup_gate_lookup_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + LookupKindIndex(LookupPattern::RangeCheck) => self + .range_check_lookup_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + LookupKindIndex(LookupPattern::ForeignFieldMul) => self + .foreign_field_mul_lookup_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + LookupRuntimeSelector => self + .runtime_lookup_table_selector + .ok_or(ExprError::MissingIndexEvaluation(col)), + Index(_) => Err(ExprError::MissingIndexEvaluation(col)), + } + } +} + +impl<'a, F: FftField> ColumnEnvironment<'a, F, BerkeleyChallengeTerm, BerkeleyChallenges> + for Environment<'a, F> +{ + type Column = Column; + + fn get_column(&self, col: &Self::Column) -> Option<&'a Evaluations>> { + use Column::*; + let lookup = self.lookup.as_ref(); + match col { + Witness(i) => Some(&self.witness[*i]), + Coefficient(i) => Some(&self.coefficient[*i]), + Z => Some(self.z), + LookupKindIndex(i) => lookup.and_then(|l| l.selectors[*i].as_ref()), + LookupSorted(i) => lookup.map(|l| &l.sorted[*i]), + LookupAggreg => lookup.map(|l| l.aggreg), + LookupTable => lookup.map(|l| l.table), + LookupRuntimeSelector => lookup.and_then(|l| l.runtime_selector), + LookupRuntimeTable => lookup.and_then(|l| l.runtime_table), + Index(t) => match self.index.get(t) { + None => None, + Some(e) => Some(e), + }, + Permutation(_) => None, + } + } + + fn get_domain(&self, d: Domain) -> D { + match d { + Domain::D1 => self.domain.d1, + Domain::D2 => self.domain.d2, + Domain::D4 => self.domain.d4, + Domain::D8 => self.domain.d8, + } + } + + fn column_domain(&self, col: &Self::Column) -> Domain { + match *col { + Self::Column::Index(GateType::Generic) => Domain::D4, + Self::Column::Index(GateType::CompleteAdd) => Domain::D4, + _ => Domain::D8, + } + } + + fn get_constants(&self) -> &Constants { + &self.constants + } + + fn get_challenges(&self) -> &BerkeleyChallenges { + &self.challenges + } + + fn vanishes_on_zero_knowledge_and_previous_rows(&self) -> &'a Evaluations> { + self.vanishes_on_zero_knowledge_and_previous_rows + } + + fn l0_1(&self) -> F { + self.l0_1 + } +} + +/// The polynomials specific to the lookup argument. +/// +/// All are evaluations over the D8 domain +pub struct LookupEnvironment<'a, F: FftField> { + /// The sorted lookup table polynomials. + pub sorted: &'a Vec>>, + /// The lookup aggregation polynomials. + pub aggreg: &'a Evaluations>, + /// The lookup-type selector polynomials. + pub selectors: &'a LookupSelectors>>, + /// The evaluations of the combined lookup table polynomial. + pub table: &'a Evaluations>, + /// The evaluations of the optional runtime selector polynomial. + pub runtime_selector: Option<&'a Evaluations>>, + /// The evaluations of the optional runtime table. + pub runtime_table: Option<&'a Evaluations>>, +} + +/// The collection of polynomials (all in evaluation form) and constants +/// required to evaluate an expression as a polynomial. +/// +/// All are evaluations. +pub struct Environment<'a, F: FftField> { + /// The witness column polynomials + pub witness: &'a [Evaluations>; COLUMNS], + /// The coefficient column polynomials + pub coefficient: &'a [Evaluations>; COLUMNS], + /// The polynomial that vanishes on the zero-knowledge rows and the row before. + pub vanishes_on_zero_knowledge_and_previous_rows: &'a Evaluations>, + /// The permutation aggregation polynomial. 
+ pub z: &'a Evaluations>, + /// The index selector polynomials. + pub index: HashMap>>, + /// The value `prod_{j != 1} (1 - omega^j)`, used for efficiently + /// computing the evaluations of the unnormalized Lagrange basis polynomials. + pub l0_1: F, + /// Constant values required + pub constants: Constants, + /// Challenges from the IOP. + pub challenges: BerkeleyChallenges, + /// The domains used in the PLONK argument. + pub domain: EvaluationDomains, + /// Lookup specific polynomials + pub lookup: Option>, +} + +// +// Helpers +// + +/// An alias for the intended usage of the expression type in constructing constraints. +pub type E = Expr, Column>; + +/// Convenience function to create a constant as [Expr]. +pub fn constant(x: F) -> E { + ConstantTerm::Literal(x).into() +} + +/// Helper function to quickly create an expression for a witness. +pub fn witness(i: usize, row: CurrOrNext) -> E { + E::::cell(Column::Witness(i), row) +} + +/// Same as [witness] but for the current row. +pub fn witness_curr(i: usize) -> E { + witness(i, CurrOrNext::Curr) +} + +/// Same as [witness] but for the next row. +pub fn witness_next(i: usize) -> E { + witness(i, CurrOrNext::Next) +} + +/// Handy function to quickly create an expression for a gate. +pub fn index(g: GateType) -> E { + E::::cell(Column::Index(g), CurrOrNext::Curr) +} + +pub fn coeff(i: usize) -> E { + E::::cell(Column::Coefficient(i), CurrOrNext::Curr) +} diff --git a/kimchi/src/circuits/constraints.rs b/kimchi/src/circuits/constraints.rs index f8f2f20555..b53c925f20 100644 --- a/kimchi/src/circuits/constraints.rs +++ b/kimchi/src/circuits/constraints.rs @@ -299,6 +299,20 @@ impl ConstraintSystem { .set(precomputations) .expect("Precomputation has been set before"); } + + /// test helpers + pub fn for_testing(gates: Vec>) -> Self { + let public = 0; + // not sure if there's a smarter way instead of the double unwrap, but should be fine in the test + ConstraintSystem::::create(gates) + .public(public) + .build() + .unwrap() + } + + pub fn fp_for_testing(gates: Vec>) -> Self { + Self::for_testing(gates) + } } impl, OpeningProof: OpenProof> @@ -342,7 +356,7 @@ impl, OpeningProof: OpenProof> } // for public gates, only the left wire is toggled - if row < self.cs.public && gate.coeffs[0] != F::one() { + if row < self.cs.public && gate.coeffs.get(0) != Some(&F::one()) { return Err(GateError::IncorrectPublic(row)); } @@ -639,6 +653,22 @@ impl ConstraintSystem { } } +/// The default number of chunks in a circuit is one (< 2^16 rows) +pub const NUM_CHUNKS_BY_DEFAULT: usize = 1; + +/// The number of rows required for zero knowledge in circuits with one single chunk +pub const ZK_ROWS_BY_DEFAULT: u64 = 3; + +/// This function computes a strict lower bound on the number of rows required +/// for zero knowledge in circuits with `num_chunks` chunks. This means that one +/// needs at least 1 more row than the result of this function to achieve zero +/// knowledge. +/// Example: +/// for 1 chunk, this function returns 2, but at least 3 rows are needed +/// Note: +/// the number of zero knowledge rows is usually computed across the codebase +/// as the formula `(16 * num_chunks + 5) / 7`, which is precisely the formula +/// in this function plus one.
pub fn zk_rows_strict_lower_bound(num_chunks: usize) -> usize { (2 * (PERMUTS + 1) * num_chunks - 2) / PERMUTS } @@ -931,55 +961,3 @@ impl Builder { Ok(constraints) } } - -#[cfg(test)] -pub mod tests { - use super::*; - use mina_curves::pasta::Fp; - - impl ConstraintSystem { - pub fn for_testing(gates: Vec>) -> Self { - let public = 0; - // not sure if theres a smarter way instead of the double unwrap, but should be fine in the test - ConstraintSystem::::create(gates) - .public(public) - .build() - .unwrap() - } - } - - impl ConstraintSystem { - pub fn fp_for_testing(gates: Vec>) -> Self { - //let fp_sponge_params = mina_poseidon::pasta::fp_kimchi::params(); - Self::for_testing(gates) - } - } - - #[test] - pub fn test_domains_computation_with_runtime_tables() { - let dummy_gate = CircuitGate { - typ: GateType::Generic, - wires: [Wire::new(0, 0); PERMUTS], - coeffs: vec![Fp::zero()], - }; - // inputs + expected output - let data = [((10, 10), 128), ((0, 0), 8), ((5, 100), 512)]; - for ((number_of_rt_cfgs, size), expected_domain_size) in data.into_iter() { - let builder = ConstraintSystem::create(vec![dummy_gate.clone(), dummy_gate.clone()]); - let table_ids: Vec = (0..number_of_rt_cfgs).collect(); - let rt_cfgs: Vec> = table_ids - .into_iter() - .map(|table_id| { - let indexes: Vec = (0..size).collect(); - let first_column: Vec = indexes.into_iter().map(Fp::from).collect(); - RuntimeTableCfg { - id: table_id, - first_column, - } - }) - .collect(); - let res = builder.runtime(Some(rt_cfgs)).build().unwrap(); - assert_eq!(res.domain.d1.size, expected_domain_size) - } - } -} diff --git a/kimchi/src/circuits/domains.rs b/kimchi/src/circuits/domains.rs index 89251bea5f..114e2a5226 100644 --- a/kimchi/src/circuits/domains.rs +++ b/kimchi/src/circuits/domains.rs @@ -1,53 +1,65 @@ +//! This module describes the evaluation domains that can be used by the +//! polynomials. + use ark_ff::FftField; -use ark_poly::{EvaluationDomain, Radix2EvaluationDomain as Domain}; +use ark_poly::{EvaluationDomain, Radix2EvaluationDomain}; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use crate::error::DomainCreationError; +/// The different multiplicative domain sizes that can be used by the polynomials. +/// We support up to 8 times the size of the original domain for now. +#[derive(Clone, Copy, Debug, PartialEq, FromPrimitive, ToPrimitive)] +pub enum Domain { + D1 = 1, + D2 = 2, + D4 = 4, + D8 = 8, +} + #[serde_as] #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct EvaluationDomains { #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub d1: Domain, // size n + pub d1: Radix2EvaluationDomain, // size n #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub d2: Domain, // size 2n + pub d2: Radix2EvaluationDomain, // size 2n #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub d4: Domain, // size 4n + pub d4: Radix2EvaluationDomain, // size 4n #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub d8: Domain, // size 8n + pub d8: Radix2EvaluationDomain, // size 8n } impl EvaluationDomains { - /// Creates 4 evaluation domains `d1` (of size `n`), `d2` (of size `2n`), `d4` (of size `4n`), - /// and `d8` (of size `8n`). If generator of `d8` is `g`, the generator - /// of `d4` is `g^2`, the generator of `d2` is `g^4`, and the generator of `d1` is `g^8`. + /// Creates 4 evaluation domains `d1` (of size `n`), `d2` (of size `2n`), + /// `d4` (of size `4n`), and `d8` (of size `8n`).
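A quick numeric check of the `zk_rows_strict_lower_bound` relationship documented above, assuming kimchi's `PERMUTS = 7` (the number of permuted columns); since `2 * (PERMUTS + 1) = 16`, integer division gives `(16n + 5) / 7 = (16n - 2) / 7 + 1` exactly:

```rust
const PERMUTS: usize = 7;

fn zk_rows_strict_lower_bound(num_chunks: usize) -> usize {
    (2 * (PERMUTS + 1) * num_chunks - 2) / PERMUTS
}

fn main() {
    for num_chunks in 1..=4 {
        let bound = zk_rows_strict_lower_bound(num_chunks);
        // The zk-row count used across the codebase...
        let zk_rows = (16 * num_chunks + 5) / 7;
        // ...is exactly the strict lower bound plus one, e.g. 3 = 2 + 1 for
        // a single chunk.
        assert_eq!(zk_rows, bound + 1);
    }
}
```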
If generator of `d8` is + /// `g`, the generator of `d4` is `g^2`, the generator of `d2` is `g^4`, and + /// the generator of `d1` is `g^8`. pub fn create(n: usize) -> Result { - let n = Domain::::compute_size_of_domain(n) + let n = Radix2EvaluationDomain::::compute_size_of_domain(n) .ok_or(DomainCreationError::DomainSizeFailed(n))?; - let d1 = Domain::::new(n).ok_or(DomainCreationError::DomainConstructionFailed( - "d1".to_string(), - n, - ))?; + let d1 = Radix2EvaluationDomain::::new(n).ok_or( + DomainCreationError::DomainConstructionFailed("d1".to_string(), n), + )?; // we also create domains of larger sizes // to efficiently operate on polynomials in evaluation form. - // (in evaluation form, the domain needs to grow as the degree of a polynomial grows) - let d2 = Domain::::new(2 * n).ok_or(DomainCreationError::DomainConstructionFailed( - "d2".to_string(), - 2 * n, - ))?; - let d4 = Domain::::new(4 * n).ok_or(DomainCreationError::DomainConstructionFailed( - "d4".to_string(), - 4 * n, - ))?; - let d8 = Domain::::new(8 * n).ok_or(DomainCreationError::DomainConstructionFailed( - "d8".to_string(), - 8 * n, - ))?; + // (in evaluation form, the domain needs to grow as the degree of a + // polynomial grows) + let d2 = Radix2EvaluationDomain::::new(2 * n).ok_or( + DomainCreationError::DomainConstructionFailed("d2".to_string(), 2 * n), + )?; + let d4 = Radix2EvaluationDomain::::new(4 * n).ok_or( + DomainCreationError::DomainConstructionFailed("d4".to_string(), 4 * n), + )?; + let d8 = Radix2EvaluationDomain::::new(8 * n).ok_or( + DomainCreationError::DomainConstructionFailed("d8".to_string(), 8 * n), + )?; - // ensure the relationship between the three domains in case the library's behavior changes + // ensure the relationship between the three domains in case the + // library's behavior changes assert_eq!(d2.group_gen.square(), d1.group_gen); assert_eq!(d4.group_gen.square(), d2.group_gen); assert_eq!(d8.group_gen.square(), d4.group_gen); @@ -55,24 +67,3 @@ impl EvaluationDomains { Ok(EvaluationDomains { d1, d2, d4, d8 }) } } - -#[cfg(test)] -mod tests { - use super::*; - use ark_ff::Field; - use mina_curves::pasta::Fp; - - #[test] - #[ignore] // TODO(mimoo): wait for fix upstream (https://github.com/arkworks-rs/algebra/pull/307) - fn test_create_domain() { - if let Ok(d) = EvaluationDomains::::create(usize::MAX) { - assert!(d.d4.group_gen.pow([4]) == d.d1.group_gen); - assert!(d.d8.group_gen.pow([2]) == d.d4.group_gen); - println!("d8 = {:?}", d.d8.group_gen); - println!("d8^2 = {:?}", d.d8.group_gen.pow([2])); - println!("d4 = {:?}", d.d4.group_gen); - println!("d4 = {:?}", d.d4.group_gen.pow([4])); - println!("d1 = {:?}", d.d1.group_gen); - } - } -} diff --git a/kimchi/src/circuits/expr.rs b/kimchi/src/circuits/expr.rs index 7158bb896a..8235109145 100644 --- a/kimchi/src/circuits/expr.rs +++ b/kimchi/src/circuits/expr.rs @@ -1,16 +1,16 @@ use crate::{ circuits::{ + berkeley_columns, + berkeley_columns::BerkeleyChallengeTerm, constraints::FeatureFlags, - domains::EvaluationDomains, - gate::{CurrOrNext, GateType}, - lookup::{ - index::LookupSelectors, - lookups::{LookupPattern, LookupPatterns}, + domains::Domain, + gate::CurrOrNext, + lookup::lookups::{LookupPattern, LookupPatterns}, + polynomials::{ + foreign_field_common::KimchiForeignElement, permutation::eval_vanishes_on_last_n_rows, }, - polynomials::permutation::eval_vanishes_on_last_n_rows, - wires::COLUMNS, }, - proof::{PointEvaluations, ProofEvaluations}, + proof::PointEvaluations, }; use ark_ff::{FftField, Field, One, PrimeField, 
Zero}; use ark_poly::{ @@ -24,8 +24,9 @@ use std::{ cmp::Ordering, collections::{HashMap, HashSet}, fmt, + fmt::{Debug, Display}, iter::FromIterator, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}, + ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub}, }; use thiserror::Error; use CurrOrNext::{Curr, Next}; @@ -33,7 +34,7 @@ use CurrOrNext::{Curr, Next}; use self::constraints::ExprOps; #[derive(Debug, Error)] -pub enum ExprError { +pub enum ExprError { #[error("Empty stack")] EmptyStack, @@ -47,23 +48,23 @@ pub enum ExprError { MissingIndexEvaluation(Column), #[error("Linearization failed (too many unevaluated columns: {0:?}")] - FailedLinearization(Vec), + FailedLinearization(Vec>), #[error("runtime table not available")] MissingRuntime, } +/// The challenge term that contains an alpha. +/// It is used to make a random linear combination of constraints. +pub trait AlphaChallengeTerm<'a>: + Copy + Clone + Debug + PartialEq + Eq + Serialize + Deserialize<'a> + Display +{ + const ALPHA: Self; +} + /// The collection of constants required to evaluate an `Expr`. +#[derive(Clone)] pub struct Constants { - /// The challenge alpha from the PLONK IOP. - pub alpha: F, - /// The challenge beta from the PLONK IOP. - pub beta: F, - /// The challenge gamma from the PLONK IOP. - pub gamma: F, - /// The challenge joint_combiner which is used to combine - /// joint lookup tables. - pub joint_combiner: Option, /// The endomorphism coefficient pub endo_coefficient: F, /// The MDS matrix @@ -72,71 +73,41 @@ pub struct Constants { pub zk_rows: u64, } -/// The polynomials specific to the lookup argument. -/// -/// All are evaluations over the D8 domain -pub struct LookupEnvironment<'a, F: FftField> { - /// The sorted lookup table polynomials. - pub sorted: &'a Vec>>, - /// The lookup aggregation polynomials. - pub aggreg: &'a Evaluations>, - /// The lookup-type selector polynomials. - pub selectors: &'a LookupSelectors>>, - /// The evaluations of the combined lookup table polynomial. - pub table: &'a Evaluations>, - /// The evaluations of the optional runtime selector polynomial. - pub runtime_selector: Option<&'a Evaluations>>, - /// The evaluations of the optional runtime table. - pub runtime_table: Option<&'a Evaluations>>, -} - -/// The collection of polynomials (all in evaluation form) and constants -/// required to evaluate an expression as a polynomial. -/// -/// All are evaluations. -pub struct Environment<'a, F: FftField> { - /// The witness column polynomials - pub witness: &'a [Evaluations>; COLUMNS], - /// The coefficient column polynomials - pub coefficient: &'a [Evaluations>; COLUMNS], - /// The polynomial that vanishes on the zero-knowledge rows and the row before. - pub vanishes_on_zero_knowledge_and_previous_rows: &'a Evaluations>, - /// The permutation aggregation polynomial. - pub z: &'a Evaluations>, - /// The index selector polynomials. - pub index: HashMap>>, - /// The value `prod_{j != 1} (1 - omega^j)`, used for efficiently +pub trait ColumnEnvironment< + 'a, + F: FftField, + ChallengeTerm, + Challenges: Index, +> +{ + /// The generic type of column the environment can use. + /// In other words, with the multi-variate polynomial analogy, it is the + /// variables the multi-variate polynomials are defined upon. + /// i.e. for a polynomial `P(X, Y, Z)`, the type will represent the variables + /// `X`, `Y` and `Z`. + type Column; + + /// Return the evaluation of the given column, over the domain.
+ fn get_column(&self, col: &Self::Column) -> Option<&'a Evaluations>>; + + /// Defines the domain over which the column is evaluated + fn column_domain(&self, col: &Self::Column) -> Domain; + + fn get_domain(&self, d: Domain) -> D; + + /// Return the constant parameters that the expression might use. + /// For instance, it can be the matrix used by the linear layer in the + /// permutation. + fn get_constants(&self) -> &Constants; + + /// Return the challenges, coined by the verifier. + fn get_challenges(&self) -> &Challenges; + + fn vanishes_on_zero_knowledge_and_previous_rows(&self) -> &'a Evaluations>; + + /// Return the value `prod_{j != 1} (1 - omega^j)`, used for efficiently /// computing the evaluations of the unnormalized Lagrange basis polynomials. - pub l0_1: F, - /// Constant values required - pub constants: Constants, - /// The domains used in the PLONK argument. - pub domain: EvaluationDomains, - /// Lookup specific polynomials - pub lookup: Option>, -} - -impl<'a, F: FftField> Environment<'a, F> { - fn get_column(&self, col: &Column) -> Option<&'a Evaluations>> { - use Column::*; - let lookup = self.lookup.as_ref(); - match col { - Witness(i) => Some(&self.witness[*i]), - Coefficient(i) => Some(&self.coefficient[*i]), - Z => Some(self.z), - LookupKindIndex(i) => lookup.and_then(|l| l.selectors[*i].as_ref()), - LookupSorted(i) => lookup.map(|l| &l.sorted[*i]), - LookupAggreg => lookup.map(|l| l.aggreg), - LookupTable => lookup.map(|l| l.table), - LookupRuntimeSelector => lookup.and_then(|l| l.runtime_selector), - LookupRuntimeTable => lookup.and_then(|l| l.runtime_table), - Index(t) => match self.index.get(t) { - None => None, - Some(e) => Some(e), - }, - Permutation(_) => None, - } - } + fn l0_1(&self) -> F; } // In this file, we define... @@ -158,7 +129,7 @@ pub fn l0_1(d: D) -> F { } // Compute the ith unnormalized lagrange basis -fn unnormalized_lagrange_basis(domain: &D, i: i32, pt: &F) -> F { +pub fn unnormalized_lagrange_basis(domain: &D, i: i32, pt: &F) -> F { let omega_i = if i < 0 { domain.group_gen.pow([-i as u64]).inverse().unwrap() } else { @@ -168,188 +139,366 @@ fn unnormalized_lagrange_basis(domain: &D, i: i32, pt: &F) -> F } #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -/// A type representing one of the polynomials involved in the PLONK IOP. -pub enum Column { - Witness(usize), - Z, - LookupSorted(usize), - LookupAggreg, - LookupTable, - LookupKindIndex(LookupPattern), - LookupRuntimeSelector, - LookupRuntimeTable, - Index(GateType), - Coefficient(usize), - Permutation(usize), -} - -impl Column { - fn domain(&self) -> Domain { +/// A type representing a variable which can appear in a constraint. It specifies a column +/// and a relative position (Curr or Next) +pub struct Variable { + /// The column of this variable + pub col: Column, + /// The row (Curr or Next) of this variable + pub row: CurrOrNext, +} + +/// Define the constant terms an expression can use. +/// It can be any constant term (`Literal`), a matrix (`Mds` - used, for +/// instance, by the linear layer of the Poseidon permutation), or endomorphism +/// coefficients (`EndoCoefficient` - used as an optimisation). +/// As for `ChallengeTerm`, it was initially used to implement the PLONK +/// IOP, with the Poseidon custom gate. However, the terms have no built-in +/// semantics in the expression framework. +/// TODO: we should generalize the expression type over challenges and constants.
+/// See +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum ConstantTerm { + EndoCoefficient, + Mds { row: usize, col: usize }, + Literal(F), +} + +pub trait Literal: Sized + Clone { + type F; + + fn literal(x: Self::F) -> Self; + + fn to_literal(self) -> Result; + + fn to_literal_ref(&self) -> Option<&Self::F>; + + /// Obtains the representation of some constants as a literal. + /// This is useful before converting Kimchi expressions with constants + /// to folding compatible expressions. + fn as_literal(&self, constants: &Constants) -> Self; +} + +impl Literal for F { + type F = F; + + fn literal(x: Self::F) -> Self { + x + } + + fn to_literal(self) -> Result { + Ok(self) + } + + fn to_literal_ref(&self) -> Option<&Self::F> { + Some(self) + } + + fn as_literal(&self, _constants: &Constants) -> Self { + *self + } +} + +impl Literal for ConstantTerm { + type F = F; + fn literal(x: Self::F) -> Self { + ConstantTerm::Literal(x) + } + fn to_literal(self) -> Result { match self { - Column::Index(GateType::Generic) => Domain::D4, - Column::Index(GateType::CompleteAdd) => Domain::D4, - _ => Domain::D8, + ConstantTerm::Literal(x) => Ok(x), + x => Err(x), } } - - fn latex(&self) -> String { + fn to_literal_ref(&self) -> Option<&Self::F> { + match self { + ConstantTerm::Literal(x) => Some(x), + _ => None, + } + } + fn as_literal(&self, constants: &Constants) -> Self { match self { - Column::Witness(i) => format!("w_{{{i}}}"), - Column::Z => "Z".to_string(), - Column::LookupSorted(i) => format!("s_{{{i}}}"), - Column::LookupAggreg => "a".to_string(), - Column::LookupTable => "t".to_string(), - Column::LookupKindIndex(i) => format!("k_{{{i:?}}}"), - Column::LookupRuntimeSelector => "rts".to_string(), - Column::LookupRuntimeTable => "rt".to_string(), - Column::Index(gate) => { - format!("{gate:?}") + ConstantTerm::EndoCoefficient => { + ConstantTerm::Literal(constants.endo_coefficient.clone()) } - Column::Coefficient(i) => format!("c_{{{i}}}"), - Column::Permutation(i) => format!("sigma_{{{i}}}"), + ConstantTerm::Mds { row, col } => { + ConstantTerm::Literal(constants.mds[*row][*col].clone()) + } + ConstantTerm::Literal(_) => self.clone(), } } +} - fn text(&self) -> String { +#[derive(Clone, Debug, PartialEq)] +pub enum ConstantExprInner { + Challenge(ChallengeTerm), + Constant(ConstantTerm), +} + +impl<'a, F: Clone, ChallengeTerm: AlphaChallengeTerm<'a>> Literal + for ConstantExprInner +{ + type F = F; + fn literal(x: Self::F) -> Self { + Self::Constant(ConstantTerm::literal(x)) + } + fn to_literal(self) -> Result { match self { - Column::Witness(i) => format!("w[{i}]"), - Column::Z => "Z".to_string(), - Column::LookupSorted(i) => format!("s[{i}]"), - Column::LookupAggreg => "a".to_string(), - Column::LookupTable => "t".to_string(), - Column::LookupKindIndex(i) => format!("k[{i:?}]"), - Column::LookupRuntimeSelector => "rts".to_string(), - Column::LookupRuntimeTable => "rt".to_string(), - Column::Index(gate) => { - format!("{gate:?}") - } - Column::Coefficient(i) => format!("c[{i}]"), - Column::Permutation(i) => format!("sigma_[{i}]"), + Self::Constant(x) => match x.to_literal() { + Ok(x) => Ok(x), + Err(x) => Err(Self::Constant(x)), + }, + x => Err(x), + } + } + fn to_literal_ref(&self) -> Option<&Self::F> { + match self { + Self::Constant(x) => x.to_literal_ref(), + _ => None, + } + } + fn as_literal(&self, constants: &Constants) -> Self { + match self { + Self::Constant(x) => Self::Constant(x.as_literal(constants)), + Self::Challenge(_) => self.clone(), } } } 
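The `Literal` trait above gives every layer of the expression tower a uniform way to inject and recover raw field elements. A minimal, self-contained sketch of the contract, reduced to two trait methods and a `u64` in place of a field element:

```rust
// Reduced version of the `Literal` trait, for illustration only.
trait Literal: Sized + Clone {
    type F;
    fn literal(x: Self::F) -> Self;
    fn to_literal(self) -> Result<Self::F, Self>;
}

#[derive(Clone, Debug, PartialEq)]
enum ConstantTerm<F> {
    EndoCoefficient,
    Literal(F),
}

impl<F: Clone> Literal for ConstantTerm<F> {
    type F = F;
    fn literal(x: F) -> Self {
        ConstantTerm::Literal(x)
    }
    fn to_literal(self) -> Result<F, Self> {
        match self {
            ConstantTerm::Literal(x) => Ok(x),
            x => Err(x),
        }
    }
}

fn main() {
    // Literals round-trip...
    assert_eq!(ConstantTerm::literal(42u64).to_literal(), Ok(42));
    // ...while symbolic constants are handed back unchanged, so the caller
    // can resolve them later against a `Constants` (cf. `as_literal`).
    assert!(ConstantTerm::<u64>::EndoCoefficient.to_literal().is_err());
}
```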
-#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -/// A type representing a variable which can appear in a constraint. It specifies a column -/// and a relative position (Curr or Next) -pub struct Variable { - /// The column of this variable - pub col: Column, - /// The row (Curr of Next) of this variable - pub row: CurrOrNext, +impl<'a, F, ChallengeTerm: AlphaChallengeTerm<'a>> From + for ConstantExprInner +{ + fn from(x: ChallengeTerm) -> Self { + ConstantExprInner::Challenge(x) + } } -impl Variable { - fn ocaml(&self) -> String { - format!("var({:?}, {:?})", self.col, self.row) +impl From> for ConstantExprInner { + fn from(x: ConstantTerm) -> Self { + ConstantExprInner::Constant(x) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Operations { + Atom(T), + Pow(Box, u64), + Add(Box, Box), + Mul(Box, Box), + Sub(Box, Box), + Double(Box), + Square(Box), + Cache(CacheId, Box), + IfFeature(FeatureFlag, Box, Box), +} + +impl From for Operations { + fn from(x: T) -> Self { + Operations::Atom(x) } +} - fn latex(&self) -> String { - let col = self.col.latex(); - match self.row { - Curr => col, - Next => format!("\\tilde{{{col}}}"), +impl Literal for Operations { + type F = T::F; + + fn literal(x: Self::F) -> Self { + Self::Atom(T::literal(x)) + } + + fn to_literal(self) -> Result { + match self { + Self::Atom(x) => match x.to_literal() { + Ok(x) => Ok(x), + Err(x) => Err(Self::Atom(x)), + }, + x => Err(x), } } - fn text(&self) -> String { - let col = self.col.text(); - match self.row { - Curr => format!("Curr({col})"), - Next => format!("Next({col})"), + fn to_literal_ref(&self) -> Option<&Self::F> { + match self { + Self::Atom(x) => x.to_literal_ref(), + _ => None, + } + } + + fn as_literal(&self, constants: &Constants) -> Self { + match self { + Self::Atom(x) => Self::Atom(x.as_literal(constants)), + Self::Pow(x, n) => Self::Pow(Box::new(x.as_literal(constants)), *n), + Self::Add(x, y) => Self::Add( + Box::new(x.as_literal(constants)), + Box::new(y.as_literal(constants)), + ), + Self::Mul(x, y) => Self::Mul( + Box::new(x.as_literal(constants)), + Box::new(y.as_literal(constants)), + ), + Self::Sub(x, y) => Self::Sub( + Box::new(x.as_literal(constants)), + Box::new(y.as_literal(constants)), + ), + Self::Double(x) => Self::Double(Box::new(x.as_literal(constants))), + Self::Square(x) => Self::Square(Box::new(x.as_literal(constants))), + Self::Cache(id, x) => Self::Cache(*id, Box::new(x.as_literal(constants))), + Self::IfFeature(flag, if_true, if_false) => Self::IfFeature( + *flag, + Box::new(if_true.as_literal(constants)), + Box::new(if_false.as_literal(constants)), + ), } } } -#[derive(Clone, Debug, PartialEq)] -/// An arithmetic expression over -/// -/// - the operations *, +, -, ^ -/// - the constants `alpha`, `beta`, `gamma`, `joint_combiner`, and literal field elements. -pub enum ConstantExpr { - // TODO: Factor these out into an enum just for Alpha, Beta, Gamma, JointCombiner - Alpha, - Beta, - Gamma, - JointCombiner, - // TODO: EndoCoefficient and Mds differ from the other 4 base constants in - // that they are known at compile time. This should be extracted out into two - // separate constant expression types. 
- EndoCoefficient, - Mds { row: usize, col: usize }, - Literal(F), - Pow(Box>, u64), - // TODO: I think having separate Add, Sub, Mul constructors is faster than - // having a BinOp constructor :( - Add(Box>, Box>), - Mul(Box>, Box>), - Sub(Box>, Box>), +pub type ConstantExpr = Operations>; + +impl From> for ConstantExpr { + fn from(x: ConstantTerm) -> Self { + ConstantExprInner::from(x).into() + } +} + +impl<'a, F, ChallengeTerm: AlphaChallengeTerm<'a>> From + for ConstantExpr +{ + fn from(x: ChallengeTerm) -> Self { + ConstantExprInner::from(x).into() + } +} + +impl ConstantExprInner { + fn to_polish( + &self, + _cache: &mut HashMap, + res: &mut Vec>, + ) { + match self { + ConstantExprInner::Challenge(chal) => res.push(PolishToken::Challenge(*chal)), + ConstantExprInner::Constant(c) => res.push(PolishToken::Constant(*c)), + } + } } -impl ConstantExpr { - fn to_polish_(&self, res: &mut Vec>) { +impl Operations> { + fn to_polish( + &self, + cache: &mut HashMap, + res: &mut Vec>, + ) { match self { - ConstantExpr::Alpha => res.push(PolishToken::Alpha), - ConstantExpr::Beta => res.push(PolishToken::Beta), - ConstantExpr::Gamma => res.push(PolishToken::Gamma), - ConstantExpr::JointCombiner => res.push(PolishToken::JointCombiner), - ConstantExpr::EndoCoefficient => res.push(PolishToken::EndoCoefficient), - ConstantExpr::Mds { row, col } => res.push(PolishToken::Mds { - row: *row, - col: *col, - }), - ConstantExpr::Add(x, y) => { - x.as_ref().to_polish_(res); - y.as_ref().to_polish_(res); + Operations::Atom(atom) => atom.to_polish(cache, res), + Operations::Add(x, y) => { + x.as_ref().to_polish(cache, res); + y.as_ref().to_polish(cache, res); res.push(PolishToken::Add) } - ConstantExpr::Mul(x, y) => { - x.as_ref().to_polish_(res); - y.as_ref().to_polish_(res); + Operations::Mul(x, y) => { + x.as_ref().to_polish(cache, res); + y.as_ref().to_polish(cache, res); res.push(PolishToken::Mul) } - ConstantExpr::Sub(x, y) => { - x.as_ref().to_polish_(res); - y.as_ref().to_polish_(res); + Operations::Sub(x, y) => { + x.as_ref().to_polish(cache, res); + y.as_ref().to_polish(cache, res); res.push(PolishToken::Sub) } - ConstantExpr::Literal(x) => res.push(PolishToken::Literal(*x)), - ConstantExpr::Pow(x, n) => { - x.to_polish_(res); + Operations::Pow(x, n) => { + x.to_polish(cache, res); res.push(PolishToken::Pow(*n)) } + Operations::Double(x) => { + x.to_polish(cache, res); + res.push(PolishToken::Dup); + res.push(PolishToken::Add); + } + Operations::Square(x) => { + x.to_polish(cache, res); + res.push(PolishToken::Dup); + res.push(PolishToken::Mul); + } + Operations::Cache(id, x) => { + match cache.get(id) { + Some(pos) => + // Already computed and stored this. + { + res.push(PolishToken::Load(*pos)) + } + None => { + // Haven't computed this yet. Compute it, then store it. + x.to_polish(cache, res); + res.push(PolishToken::Store); + cache.insert(*id, cache.len()); + } + } + } + Operations::IfFeature(feature, if_true, if_false) => { + { + // True branch + let tok = PolishToken::SkipIfNot(*feature, 0); + res.push(tok); + let len_before = res.len(); + /* Clone the cache, to make sure we don't try to access cached statements later + when the feature flag is off. 
*/ + let mut cache = cache.clone(); + if_true.to_polish(&mut cache, res); + let len_after = res.len(); + res[len_before - 1] = PolishToken::SkipIfNot(*feature, len_after - len_before); + } + + { + // False branch + let tok = PolishToken::SkipIfNot(*feature, 0); + res.push(tok); + let len_before = res.len(); + /* Clone the cache, to make sure we don't try to access cached statements later + when the feature flag is on. */ + let mut cache = cache.clone(); + if_false.to_polish(&mut cache, res); + let len_after = res.len(); + res[len_before - 1] = PolishToken::SkipIfNot(*feature, len_after - len_before); + } + } } } } -impl ConstantExpr { +impl Operations +where + T::F: Field, +{ /// Exponentiate a constant expression. pub fn pow(self, p: u64) -> Self { if p == 0 { - return Literal(F::one()); + return Self::literal(T::F::one()); } - use ConstantExpr::*; - match self { - Literal(x) => Literal(x.pow([p])), - x => Pow(Box::new(x), p), + match self.to_literal() { + Ok(l) => Self::literal(::pow(&l, [p])), + Err(x) => Self::Pow(Box::new(x), p), } } +} +impl ConstantExpr { /// Evaluate the given constant expression to a field element. - pub fn value(&self, c: &Constants) -> F { - use ConstantExpr::*; + pub fn value(&self, c: &Constants, chals: &dyn Index) -> F { + use ConstantExprInner::*; + use Operations::*; match self { - Alpha => c.alpha, - Beta => c.beta, - Gamma => c.gamma, - JointCombiner => c.joint_combiner.expect("joint lookup was not expected"), - EndoCoefficient => c.endo_coefficient, - Mds { row, col } => c.mds[*row][*col], - Literal(x) => *x, - Pow(x, p) => x.value(c).pow([*p]), - Mul(x, y) => x.value(c) * y.value(c), - Add(x, y) => x.value(c) + y.value(c), - Sub(x, y) => x.value(c) - y.value(c), + Atom(Challenge(challenge_term)) => chals[*challenge_term], + Atom(Constant(ConstantTerm::EndoCoefficient)) => c.endo_coefficient, + Atom(Constant(ConstantTerm::Mds { row, col })) => c.mds[*row][*col], + Atom(Constant(ConstantTerm::Literal(x))) => *x, + Pow(x, p) => x.value(c, chals).pow([*p]), + Mul(x, y) => x.value(c, chals) * y.value(c, chals), + Add(x, y) => x.value(c, chals) + y.value(c, chals), + Sub(x, y) => x.value(c, chals) - y.value(c, chals), + Double(x) => x.value(c, chals).double(), + Square(x) => x.value(c, chals).square(), + Cache(_, x) => { + // TODO: Use cache ID + x.value(c, chals) + } + IfFeature(_flag, _if_true, _if_false) => todo!(), } } } @@ -395,10 +544,6 @@ impl CacheId { fn latex_name(&self) -> String { format!("x_{{{}}}", self.0) } - - fn text_name(&self) -> String { - format!("x[{}]", self.0) - } } impl Cache { @@ -408,32 +553,13 @@ impl Cache { CacheId(id) } - pub fn cache>(&mut self, e: T) -> T { + pub fn cache>(&mut self, e: T) -> T { e.cache(self) } } -/// A binary operation -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum Op2 { - Add, - Mul, - Sub, -} - -impl Op2 { - fn to_polish(&self) -> PolishToken { - use Op2::*; - match self { - Add => PolishToken::Add, - Mul => PolishToken::Mul, - Sub => PolishToken::Sub, - } - } -} - /// The feature flags that can be used to enable or disable parts of constraints. 
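The `value` function above recurses over the whole `Operations` tree, replacing each atom by a concrete field element. A toy version over a pruned tree, with `u64` arithmetic standing in for the field (a sketch, not the actual kimchi API):

```rust
// Pruned, illustrative version of `Operations` and its `value` recursion.
enum Op {
    Literal(u64),
    Add(Box<Op>, Box<Op>),
    Mul(Box<Op>, Box<Op>),
    Pow(Box<Op>, u32),
    Double(Box<Op>),
}

fn value(e: &Op) -> u64 {
    match e {
        Op::Literal(x) => *x,
        Op::Add(x, y) => value(x) + value(y),
        Op::Mul(x, y) => value(x) * value(y),
        Op::Pow(x, n) => value(x).pow(*n),
        Op::Double(x) => 2 * value(x),
    }
}

fn main() {
    // (2 + 3)^2 * 4 = 100: the evaluator simply folds the tree bottom-up.
    let e = Op::Mul(
        Box::new(Op::Pow(
            Box::new(Op::Add(Box::new(Op::Literal(2)), Box::new(Op::Literal(3)))),
            2,
        )),
        Box::new(Op::Literal(4)),
    );
    assert_eq!(value(&e), 100);
    assert_eq!(value(&Op::Double(Box::new(Op::Literal(21)))), 42);
}
```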
-#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] #[cfg_attr( feature = "ocaml_types", derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum) @@ -466,6 +592,16 @@ pub struct RowOffset { pub offset: i32, } +#[derive(Clone, Debug, PartialEq)] +pub enum ExprInner { + Constant(C), + Cell(Variable), + VanishesOnZeroKnowledgeAndPreviousRows, + /// UnnormalizedLagrangeBasis(i) is + /// (x^n - 1) / (x - omega^i) + UnnormalizedLagrangeBasis(RowOffset), +} + /// An multi-variate polynomial over the base ring `C` with /// variables /// @@ -476,35 +612,78 @@ pub struct RowOffset { /// This represents a PLONK "custom constraint", which enforces that /// the corresponding combination of the polynomials corresponding to /// the above variables should vanish on the PLONK domain. -#[derive(Clone, Debug, PartialEq)] -pub enum Expr { - Constant(C), - Cell(Variable), - Double(Box>), - Square(Box>), - BinOp(Op2, Box>, Box>), - VanishesOnZeroKnowledgeAndPreviousRows, - /// UnnormalizedLagrangeBasis(i) is - /// (x^n - 1) / (x - omega^i) - UnnormalizedLagrangeBasis(RowOffset), - Pow(Box>, u64), - Cache(CacheId, Box>), - /// If the feature flag is enabled, return the first expression; otherwise, return the second. - IfFeature(FeatureFlag, Box>, Box>), +pub type Expr = Operations>; + +impl From> + for Expr, Column> +{ + fn from(x: ConstantExpr) -> Self { + Expr::Atom(ExprInner::Constant(x)) + } +} + +impl<'a, F, Column, ChallengeTerm: AlphaChallengeTerm<'a>> From> + for Expr, Column> +{ + fn from(x: ConstantTerm) -> Self { + ConstantExpr::from(x).into() + } +} + +impl<'a, F, Column, ChallengeTerm: AlphaChallengeTerm<'a>> From + for Expr, Column> +{ + fn from(x: ChallengeTerm) -> Self { + ConstantExpr::from(x).into() + } } -impl + PartialEq + Clone> Expr { - fn apply_feature_flags_inner(&self, features: &FeatureFlags) -> (Expr, bool) { - use Expr::*; +impl Literal for ExprInner { + type F = T::F; + + fn literal(x: Self::F) -> Self { + ExprInner::Constant(T::literal(x)) + } + + fn to_literal(self) -> Result { + match self { + ExprInner::Constant(x) => match x.to_literal() { + Ok(x) => Ok(x), + Err(x) => Err(ExprInner::Constant(x)), + }, + x => Err(x), + } + } + + fn to_literal_ref(&self) -> Option<&Self::F> { + match self { + ExprInner::Constant(x) => x.to_literal_ref(), + _ => None, + } + } + + fn as_literal(&self, constants: &Constants) -> Self { + match self { + ExprInner::Constant(x) => ExprInner::Constant(x.as_literal(constants)), + ExprInner::Cell(_) + | ExprInner::VanishesOnZeroKnowledgeAndPreviousRows + | ExprInner::UnnormalizedLagrangeBasis(_) => self.clone(), + } + } +} + +impl Operations +where + T::F: Field, +{ + fn apply_feature_flags_inner(&self, features: &FeatureFlags) -> (Self, bool) { + use Operations::*; match self { - Constant(_) - | Cell(_) - | VanishesOnZeroKnowledgeAndPreviousRows - | UnnormalizedLagrangeBasis(_) => (self.clone(), false), + Atom(_) => (self.clone(), false), Double(c) => { let (c_reduced, reduce_further) = c.apply_feature_flags_inner(features); if reduce_further && c_reduced.is_zero() { - (Expr::zero(), true) + (Self::zero(), true) } else { (Double(Box::new(c_reduced)), false) } @@ -517,62 +696,53 @@ impl + PartialEq + Clone> Expr { (Square(Box::new(c_reduced)), false) } } - BinOp(op, c1, c2) => { + Add(c1, c2) => { let (c1_reduced, reduce_further1) = c1.apply_feature_flags_inner(features); let (c2_reduced, reduce_further2) = c2.apply_feature_flags_inner(features); - match op { - 
Op2::Add => { - if reduce_further1 && c1_reduced.is_zero() { - if reduce_further2 && c2_reduced.is_zero() { - (Expr::zero(), true) - } else { - (c2_reduced, false) - } - } else if reduce_further2 && c2_reduced.is_zero() { - (c1_reduced, false) - } else { - ( - BinOp(Op2::Add, Box::new(c1_reduced), Box::new(c2_reduced)), - false, - ) - } + if reduce_further1 && c1_reduced.is_zero() { + if reduce_further2 && c2_reduced.is_zero() { + (Self::zero(), true) + } else { + (c2_reduced, false) } - Op2::Sub => { - if reduce_further1 && c1_reduced.is_zero() { - if reduce_further2 && c2_reduced.is_zero() { - (Expr::zero(), true) - } else { - (-c2_reduced, false) - } - } else if reduce_further2 && c2_reduced.is_zero() { - (c1_reduced, false) - } else { - ( - BinOp(Op2::Sub, Box::new(c1_reduced), Box::new(c2_reduced)), - false, - ) - } + } else if reduce_further2 && c2_reduced.is_zero() { + (c1_reduced, false) + } else { + (Add(Box::new(c1_reduced), Box::new(c2_reduced)), false) + } + } + Sub(c1, c2) => { + let (c1_reduced, reduce_further1) = c1.apply_feature_flags_inner(features); + let (c2_reduced, reduce_further2) = c2.apply_feature_flags_inner(features); + if reduce_further1 && c1_reduced.is_zero() { + if reduce_further2 && c2_reduced.is_zero() { + (Self::zero(), true) + } else { + (-c2_reduced, false) } - Op2::Mul => { - if reduce_further1 && c1_reduced.is_zero() - || reduce_further2 && c2_reduced.is_zero() - { - (Expr::zero(), true) - } else if reduce_further1 && c1_reduced.is_one() { - if reduce_further2 && c2_reduced.is_one() { - (Expr::one(), true) - } else { - (c2_reduced, false) - } - } else if reduce_further2 && c2_reduced.is_one() { - (c1_reduced, false) - } else { - ( - BinOp(Op2::Mul, Box::new(c1_reduced), Box::new(c2_reduced)), - false, - ) - } + } else if reduce_further2 && c2_reduced.is_zero() { + (c1_reduced, false) + } else { + (Sub(Box::new(c1_reduced), Box::new(c2_reduced)), false) + } + } + Mul(c1, c2) => { + let (c1_reduced, reduce_further1) = c1.apply_feature_flags_inner(features); + let (c2_reduced, reduce_further2) = c2.apply_feature_flags_inner(features); + if reduce_further1 && c1_reduced.is_zero() + || reduce_further2 && c2_reduced.is_zero() + { + (Self::zero(), true) + } else if reduce_further1 && c1_reduced.is_one() { + if reduce_further2 && c2_reduced.is_one() { + (Self::one(), true) + } else { + (c2_reduced, false) } + } else if reduce_further2 && c2_reduced.is_one() { + (c1_reduced, false) + } else { + (Mul(Box::new(c1_reduced), Box::new(c2_reduced)), false) } } Pow(c, power) => { @@ -628,7 +798,7 @@ impl + PartialEq + Clone> Expr { } } } - pub fn apply_feature_flags(&self, features: &FeatureFlags) -> Expr { + pub fn apply_feature_flags(&self, features: &FeatureFlags) -> Self { let (res, _) = self.apply_feature_flags_inner(features); res } @@ -638,18 +808,10 @@ impl + PartialEq + Clone> Expr { /// [reverse Polish notation](https://en.wikipedia.org/wiki/Reverse_Polish_notation) /// expressions, which are vectors of the below tokens. 
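Before the `PolishToken` enum below, it may help to see the stack discipline that `to_polish`/`evaluate` rely on in isolation. A minimal stack machine in the same spirit, over `u64`; note how `Double` compiles to `Dup; Add` and `Square` to `Dup; Mul`, exactly as in `to_polish` above:

```rust
// Toy RPN tokens mirroring a few `PolishToken` variants.
enum Tok {
    Literal(u64),
    Dup,
    Add,
    Mul,
}

fn evaluate(toks: &[Tok]) -> u64 {
    let mut stack = Vec::new();
    for t in toks {
        match t {
            Tok::Literal(x) => stack.push(*x),
            Tok::Dup => stack.push(stack[stack.len() - 1]),
            Tok::Add => {
                let y = stack.pop().unwrap();
                let x = stack.pop().unwrap();
                stack.push(x + y);
            }
            Tok::Mul => {
                let y = stack.pop().unwrap();
                let x = stack.pop().unwrap();
                stack.push(x * y);
            }
        }
    }
    // A well-formed expression leaves exactly one value on the stack.
    assert_eq!(stack.len(), 1);
    stack[0]
}

fn main() {
    // The RPN stream for `double(5) * 3`, which evaluates to 30.
    let rpn = [Tok::Literal(5), Tok::Dup, Tok::Add, Tok::Literal(3), Tok::Mul];
    assert_eq!(evaluate(&rpn), 30);
}
```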
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum PolishToken { - Alpha, - Beta, - Gamma, - JointCombiner, - EndoCoefficient, - Mds { - row: usize, - col: usize, - }, - Literal(F), - Cell(Variable), +pub enum PolishToken { + Constant(ConstantTerm), + Challenge(ChallengeTerm), + Cell(Variable), Dup, Pow(u64), Add, @@ -665,72 +827,17 @@ pub enum PolishToken { SkipIfNot(FeatureFlag, usize), } -impl Variable { - fn evaluate( +pub trait ColumnEvaluations { + type Column; + fn evaluate(&self, col: Self::Column) -> Result, ExprError>; +} + +impl Variable { + fn evaluate>( &self, - evals: &ProofEvaluations>, - ) -> Result { - let point_evaluations = { - use Column::*; - match self.col { - Witness(i) => Ok(evals.w[i]), - Z => Ok(evals.z), - LookupSorted(i) => { - evals.lookup_sorted[i].ok_or(ExprError::MissingIndexEvaluation(self.col)) - } - LookupAggreg => evals - .lookup_aggregation - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - LookupTable => evals - .lookup_table - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - LookupRuntimeTable => evals - .runtime_lookup_table - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Index(GateType::Poseidon) => Ok(evals.poseidon_selector), - Index(GateType::Generic) => Ok(evals.generic_selector), - Index(GateType::CompleteAdd) => Ok(evals.complete_add_selector), - Index(GateType::VarBaseMul) => Ok(evals.mul_selector), - Index(GateType::EndoMul) => Ok(evals.emul_selector), - Index(GateType::EndoMulScalar) => Ok(evals.endomul_scalar_selector), - Index(GateType::RangeCheck0) => evals - .range_check0_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Index(GateType::RangeCheck1) => evals - .range_check1_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Index(GateType::ForeignFieldAdd) => evals - .foreign_field_add_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Index(GateType::ForeignFieldMul) => evals - .foreign_field_mul_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Index(GateType::Xor16) => evals - .xor_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Index(GateType::Rot64) => evals - .rot_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Permutation(i) => Ok(evals.s[i]), - Coefficient(i) => Ok(evals.coefficients[i]), - Column::LookupKindIndex(LookupPattern::Xor) => evals - .xor_lookup_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Column::LookupKindIndex(LookupPattern::Lookup) => evals - .lookup_gate_lookup_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Column::LookupKindIndex(LookupPattern::RangeCheck) => evals - .range_check_lookup_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Column::LookupKindIndex(LookupPattern::ForeignFieldMul) => evals - .foreign_field_mul_lookup_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Column::LookupRuntimeSelector => evals - .runtime_lookup_table_selector - .ok_or(ExprError::MissingIndexEvaluation(self.col)), - Index(_) => Err(ExprError::MissingIndexEvaluation(self.col)), - } - }?; + evals: &Evaluations, + ) -> Result> { + let point_evaluations = evals.evaluate(self.col)?; match self.row { CurrOrNext::Curr => Ok(point_evaluations.zeta), CurrOrNext::Next => Ok(point_evaluations.zeta_omega), @@ -738,15 +845,16 @@ impl Variable { } } -impl PolishToken { +impl PolishToken { /// Evaluate an RPN expression to a field element. 
- pub fn evaluate( - toks: &[PolishToken], + pub fn evaluate>( + toks: &[PolishToken], d: D, pt: F, - evals: &ProofEvaluations>, + evals: &Evaluations, c: &Constants, - ) -> Result { + chals: &dyn Index, + ) -> Result> { let mut stack = vec![]; let mut cache: Vec = vec![]; @@ -757,16 +865,13 @@ impl PolishToken { skip_count -= 1; continue; } + + use ConstantTerm::*; use PolishToken::*; match t { - Alpha => stack.push(c.alpha), - Beta => stack.push(c.beta), - Gamma => stack.push(c.gamma), - JointCombiner => { - stack.push(c.joint_combiner.expect("no joint lookup was expected")) - } - EndoCoefficient => stack.push(c.endo_coefficient), - Mds { row, col } => stack.push(c.mds[*row][*col]), + Challenge(challenge_term) => stack.push(chals[*challenge_term]), + Constant(EndoCoefficient) => stack.push(c.endo_coefficient), + Constant(Mds { row, col }) => stack.push(c.mds[*row][*col]), VanishesOnZeroKnowledgeAndPreviousRows => { stack.push(eval_vanishes_on_last_n_rows(d, c.zk_rows + 1, pt)) } @@ -778,7 +883,7 @@ impl PolishToken { }; stack.push(unnormalized_lagrange_basis(&d, offset, &pt)) } - Literal(x) => stack.push(*x), + Constant(Literal(x)) => stack.push(*x), Dup => stack.push(stack[stack.len() - 1]), Cell(v) => match v.evaluate(evals) { Ok(x) => stack.push(x), @@ -828,10 +933,10 @@ impl PolishToken { } } -impl Expr { +impl Expr { /// Convenience function for constructing cell variables. - pub fn cell(col: Column, row: CurrOrNext) -> Expr { - Expr::Cell(Variable { col, row }) + pub fn cell(col: Column, row: CurrOrNext) -> Expr { + Expr::Atom(ExprInner::Cell(Variable { col, row })) } pub fn double(self) -> Self { @@ -843,21 +948,30 @@ impl Expr { } /// Convenience function for constructing constant expressions. - pub fn constant(c: C) -> Expr { - Expr::Constant(c) - } - - fn degree(&self, d1_size: u64, zk_rows: u64) -> u64 { - use Expr::*; + pub fn constant(c: C) -> Expr { + Expr::Atom(ExprInner::Constant(c)) + } + + /// Return the degree of the expression. + /// The degree of a cell is defined by the first argument `d1_size`, a + /// constant being of degree zero. The degree of the expression is defined + /// recursively using the definition of the degree of a multivariate + /// polynomial. The function can be (and is) used to compute the domain + /// size, hence the name of the first argument `d1_size`. + /// The second parameter `zk_rows` is used to define the degree of the + /// constructor `VanishesOnZeroKnowledgeAndPreviousRows`. 
+ pub fn degree(&self, d1_size: u64, zk_rows: u64) -> u64 { + use ExprInner::*; + use Operations::*; match self { Double(x) => x.degree(d1_size, zk_rows), - Constant(_) => 0, - VanishesOnZeroKnowledgeAndPreviousRows => zk_rows + 1, - UnnormalizedLagrangeBasis(_) => d1_size, - Cell(_) => d1_size, + Atom(Constant(_)) => 0, + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => zk_rows + 1, + Atom(UnnormalizedLagrangeBasis(_)) => d1_size, + Atom(Cell(_)) => d1_size, Square(x) => 2 * x.degree(d1_size, zk_rows), - BinOp(Op2::Mul, x, y) => (*x).degree(d1_size, zk_rows) + (*y).degree(d1_size, zk_rows), - BinOp(Op2::Add, x, y) | BinOp(Op2::Sub, x, y) => { + Mul(x, y) => (*x).degree(d1_size, zk_rows) + (*y).degree(d1_size, zk_rows), + Add(x, y) | Sub(x, y) => { std::cmp::max((*x).degree(d1_size, zk_rows), (*y).degree(d1_size, zk_rows)) } Pow(e, d) => d * e.degree(d1_size, zk_rows), @@ -869,23 +983,18 @@ impl Expr { } } -impl fmt::Display for Expr> +impl<'a, F, Column: FormattedOutput + Debug + Clone, ChallengeTerm> fmt::Display + for Expr, Column> where F: PrimeField, + ChallengeTerm: AlphaChallengeTerm<'a>, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.text_str()) + let cache = &mut HashMap::new(); + write!(f, "{}", self.text(cache)) } } -#[derive(Clone, Copy, Debug, PartialEq, FromPrimitive, ToPrimitive)] -enum Domain { - D1 = 1, - D2 = 2, - D4 = 4, - D8 = 8, -} - #[derive(Clone)] enum EvalResult<'a, F: FftField> { Constant(F), @@ -952,11 +1061,17 @@ pub fn pows(x: F, n: usize) -> Vec { /// = (omega^{q n} omega_8^{r n} - 1) / (omega_8^k - omega^i) /// = ((omega_8^n)^r - 1) / (omega_8^k - omega^i) /// = ((omega_8^n)^r - 1) / (omega^q omega_8^r - omega^i) -fn unnormalized_lagrange_evals( +fn unnormalized_lagrange_evals< + 'a, + F: FftField, + ChallengeTerm, + Challenge: Index, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge>, +>( l0_1: F, i: i32, res_domain: Domain, - env: &Environment, + env: &Environment, ) -> Evaluations> { let k = match res_domain { Domain::D1 => 1, @@ -964,9 +1079,9 @@ fn unnormalized_lagrange_evals( Domain::D4 => 4, Domain::D8 => 8, }; - let res_domain = get_domain(res_domain, env); + let res_domain = env.get_domain(res_domain); - let d1 = env.domain.d1; + let d1 = env.get_domain(Domain::D1); let n = d1.size; // Renormalize negative values to wrap around at domain size let i = if i < 0 { @@ -1020,7 +1135,18 @@ fn unnormalized_lagrange_evals( Evaluations::>::from_vec_and_domain(evals, res_domain) } +/// Implement algebraic methods like `add`, `sub`, `mul`, `square`, etc. to use +/// algebra on the type `EvalResult`. impl<'a, F: FftField> EvalResult<'a, F> { + /// Create an evaluation over the domain `res_domain`. + /// The second parameter, `g`, is a function used to define the + /// evaluations at a given point of the domain. + /// For instance, the second parameter `g` can simply be the identity + /// function over a set of field elements. + /// It can also be used to define polynomials like `x^2` when we only have the + /// value of `x`. It can be used in particular to evaluate an expression (a + /// multi-variate polynomial) when we only have access to the evaluations + /// of the individual variables. fn init_ F>( res_domain: (Domain, D), g: G, @@ -1032,6 +1158,8 @@ impl<'a, F: FftField> EvalResult<'a, F> { ) } + /// Call the internal function `init_` and return the computed evaluation as + /// a value `Evals`.
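The degree rule implemented by `degree` above is easy to check on a toy tree: cells contribute `d1_size`, constants contribute zero, `Mul` adds, `Add`/`Sub` take the maximum, and `Pow` multiplies. A minimal sketch (not the real `Expr` type):

```rust
// Toy expression tree, just to exercise the degree recursion.
enum E {
    Cell,
    Constant,
    Add(Box<E>, Box<E>),
    Mul(Box<E>, Box<E>),
    Pow(Box<E>, u64),
}

fn degree(e: &E, d1_size: u64) -> u64 {
    match e {
        E::Cell => d1_size,
        E::Constant => 0,
        E::Add(x, y) => degree(x, d1_size).max(degree(y, d1_size)),
        E::Mul(x, y) => degree(x, d1_size) + degree(y, d1_size),
        E::Pow(x, d) => d * degree(x, d1_size),
    }
}

fn main() {
    // w0 * w1 + c has degree 2n over a domain of size n, which is why the
    // quotient computation needs evaluations over a larger domain (d2/d4/d8).
    let e = E::Add(
        Box::new(E::Mul(Box::new(E::Cell), Box::new(E::Cell))),
        Box::new(E::Constant),
    );
    assert_eq!(degree(&e, 1), 2);
    // (w0)^3 has degree 3n.
    assert_eq!(degree(&E::Pow(Box::new(E::Cell), 3), 1), 3);
}
```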
fn init F>(res_domain: (Domain, D), g: G) -> Self { Self::Evals { domain: res_domain.0, @@ -1066,7 +1194,12 @@ impl<'a, F: FftField> EvalResult<'a, F> { ) => { let n = res_domain.1.size(); let scale = (domain as usize) / (res_domain.0 as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); let v: Vec<_> = (0..n) .into_par_iter() .map(|i| { @@ -1118,7 +1251,12 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale = (d_sub as usize) / (d as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); evals.evals.par_iter_mut().enumerate().for_each(|(i, e)| { *e += es_sub.evals[(scale * i + (d_sub as usize) * s) % es_sub.evals.len()]; }); @@ -1137,10 +1275,19 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale1 = (d1 as usize) / (res_domain.0 as usize); - assert!(scale1 != 0); + assert!( + scale1 != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); let scale2 = (d2 as usize) / (res_domain.0 as usize); - assert!(scale2 != 0); - + assert!( + scale2 != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); let n = res_domain.1.size(); let v: Vec<_> = (0..n) .into_par_iter() @@ -1179,7 +1326,12 @@ impl<'a, F: FftField> EvalResult<'a, F> { Constant(x), ) => { let scale = (d as usize) / (res_domain.0 as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); EvalResult::init(res_domain, |i| { evals.evals[(scale * i + (d as usize) * s) % evals.evals.len()] - x }) @@ -1193,7 +1345,13 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale = (d as usize) / (res_domain.0 as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); + EvalResult::init(res_domain, |i| { x - evals.evals[(scale * i + (d as usize) * s) % evals.evals.len()] }) @@ -1227,7 +1385,13 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale = (d_sub as usize) / (d as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); + evals.evals.par_iter_mut().enumerate().for_each(|(i, e)| { *e = es_sub.evals[(scale * i + (d_sub as usize) * s) % es_sub.evals.len()] - *e; }); @@ -1245,7 +1409,12 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale = (d_sub as usize) / (d as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); evals.evals.par_iter_mut().enumerate().for_each(|(i, e)| { *e -= es_sub.evals[(scale * i + (d_sub as usize) * s) % es_sub.evals.len()]; }); @@ -1264,9 +1433,19 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale1 = (d1 as usize) / (res_domain.0 as usize); - assert!(scale1 != 0); + assert!( + scale1 != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); let scale2 = (d2 as usize) / (res_domain.0 as usize); - assert!(scale2 != 0); + assert!( + scale2 != 0, + "Check that the 
implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); EvalResult::init(res_domain, |i| { es1.evals[(scale1 * i + (d1 as usize) * s1) % es1.evals.len()] @@ -1305,7 +1484,12 @@ impl<'a, F: FftField> EvalResult<'a, F> { shift: s, } => { let scale = (d as usize) / (res_domain.0 as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); EvalResult::init(res_domain, |i| { evals.evals[(scale * i + (d as usize) * s) % evals.evals.len()].square() }) @@ -1339,7 +1523,12 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale = (d as usize) / (res_domain.0 as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); EvalResult::init(res_domain, |i| { x * evals.evals[(scale * i + (d as usize) * s) % evals.evals.len()] }) @@ -1384,7 +1573,13 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale = (d_sub as usize) / (d as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); + evals.evals.par_iter_mut().enumerate().for_each(|(i, e)| { *e *= es_sub.evals[(scale * i + (d_sub as usize) * s) % es_sub.evals.len()]; }); @@ -1403,10 +1598,20 @@ impl<'a, F: FftField> EvalResult<'a, F> { }, ) => { let scale1 = (d1 as usize) / (res_domain.0 as usize); - assert!(scale1 != 0); + assert!( + scale1 != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); let scale2 = (d2 as usize) / (res_domain.0 as usize); - assert!(scale2 != 0); + assert!( + scale2 != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); EvalResult::init(res_domain, |i| { es1.evals[(scale1 * i + (d1 as usize) * s1) % es1.evals.len()] * es2.evals[(scale2 * i + (d2 as usize) * s2) % es2.evals.len()] @@ -1416,43 +1621,40 @@ impl<'a, F: FftField> EvalResult<'a, F> { } } -fn get_domain(d: Domain, env: &Environment) -> D { - match d { - Domain::D1 => env.domain.d1, - Domain::D2 => env.domain.d2, - Domain::D4 => env.domain.d4, - Domain::D8 => env.domain.d8, - } -} - -impl Expr> { +impl<'a, F: Field, Column: PartialEq + Copy, ChallengeTerm: AlphaChallengeTerm<'a>> + Expr, Column> +{ /// Convenience function for constructing expressions from literal /// field elements. pub fn literal(x: F) -> Self { - Expr::Constant(ConstantExpr::Literal(x)) + ConstantTerm::Literal(x).into() } /// Combines multiple constraints `[c0, ..., cn]` into a single constraint /// `alpha^alpha0 * c0 + alpha^{alpha0 + 1} * c1 + ... + alpha^{alpha0 + n} * cn`. pub fn combine_constraints(alphas: impl Iterator, cs: Vec) -> Self { - let zero = Expr::>::zero(); + let zero = Expr::, Column>::zero(); cs.into_iter() .zip_eq(alphas) - .map(|(c, i)| Expr::Constant(ConstantExpr::Alpha.pow(i as u64)) * c) + .map(|(c, i)| Expr::from(ConstantExpr::pow(ChallengeTerm::ALPHA.into(), i as u64)) * c) .fold(zero, |acc, x| acc + x) } } -impl Expr> { +impl Expr, Column> { /// Compile an expression to an RPN expression. 
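+ /// A hedged sketch of the token stream this produces, with `w0` and `w1`
+ /// standing for hypothetical witness cells:
+ /// ```ignore
+ /// // (w0 + w1) * w0 compiles to the postfix sequence
+ /// // [Cell(w0), Cell(w1), Add, Cell(w0), Mul]
+ /// ```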
- pub fn to_polish(&self) -> Vec> { + pub fn to_polish(&self) -> Vec> { let mut res = vec![]; let mut cache = HashMap::new(); self.to_polish_(&mut cache, &mut res); res } - fn to_polish_(&self, cache: &mut HashMap, res: &mut Vec>) { + fn to_polish_( + &self, + cache: &mut HashMap, + res: &mut Vec>, + ) { match self { Expr::Double(x) => { x.to_polish_(cache, res); @@ -1468,20 +1670,30 @@ impl Expr> { x.to_polish_(cache, res); res.push(PolishToken::Pow(*d)) } - Expr::Constant(c) => { - c.to_polish_(res); + Expr::Atom(ExprInner::Constant(c)) => { + c.to_polish(cache, res); } - Expr::Cell(v) => res.push(PolishToken::Cell(*v)), - Expr::VanishesOnZeroKnowledgeAndPreviousRows => { + Expr::Atom(ExprInner::Cell(v)) => res.push(PolishToken::Cell(*v)), + Expr::Atom(ExprInner::VanishesOnZeroKnowledgeAndPreviousRows) => { res.push(PolishToken::VanishesOnZeroKnowledgeAndPreviousRows); } - Expr::UnnormalizedLagrangeBasis(i) => { + Expr::Atom(ExprInner::UnnormalizedLagrangeBasis(i)) => { res.push(PolishToken::UnnormalizedLagrangeBasis(*i)); } - Expr::BinOp(op, x, y) => { + Expr::Add(x, y) => { + x.to_polish_(cache, res); + y.to_polish_(cache, res); + res.push(PolishToken::Add); + } + Expr::Sub(x, y) => { x.to_polish_(cache, res); y.to_polish_(cache, res); - res.push(op.to_polish()); + res.push(PolishToken::Sub); + } + Expr::Mul(x, y) => { + x.to_polish_(cache, res); + y.to_polish_(cache, res); + res.push(PolishToken::Mul); } Expr::Cache(id, e) => { match cache.get(id) { @@ -1527,79 +1739,92 @@ impl Expr> { } } } +} - /// The expression `beta`. - pub fn beta() -> Self { - Expr::Constant(ConstantExpr::Beta) - } - - fn evaluate_constants_(&self, c: &Constants) -> Expr { - use Expr::*; +impl + Expr, Column> +{ + fn evaluate_constants_( + &self, + c: &Constants, + chals: &dyn Index, + ) -> Expr { + use ExprInner::*; + use Operations::*; // TODO: Use cache match self { - Double(x) => x.evaluate_constants_(c).double(), - Pow(x, d) => x.evaluate_constants_(c).pow(*d), - Square(x) => x.evaluate_constants_(c).square(), - Constant(x) => Constant(x.value(c)), - Cell(v) => Cell(*v), - VanishesOnZeroKnowledgeAndPreviousRows => VanishesOnZeroKnowledgeAndPreviousRows, - UnnormalizedLagrangeBasis(i) => UnnormalizedLagrangeBasis(*i), - BinOp(Op2::Add, x, y) => x.evaluate_constants_(c) + y.evaluate_constants_(c), - BinOp(Op2::Mul, x, y) => x.evaluate_constants_(c) * y.evaluate_constants_(c), - BinOp(Op2::Sub, x, y) => x.evaluate_constants_(c) - y.evaluate_constants_(c), - Cache(id, e) => Cache(*id, Box::new(e.evaluate_constants_(c))), + Double(x) => x.evaluate_constants_(c, chals).double(), + Pow(x, d) => x.evaluate_constants_(c, chals).pow(*d), + Square(x) => x.evaluate_constants_(c, chals).square(), + Atom(Constant(x)) => Atom(Constant(x.value(c, chals))), + Atom(Cell(v)) => Atom(Cell(*v)), + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => { + Atom(VanishesOnZeroKnowledgeAndPreviousRows) + } + Atom(UnnormalizedLagrangeBasis(i)) => Atom(UnnormalizedLagrangeBasis(*i)), + Add(x, y) => x.evaluate_constants_(c, chals) + y.evaluate_constants_(c, chals), + Mul(x, y) => x.evaluate_constants_(c, chals) * y.evaluate_constants_(c, chals), + Sub(x, y) => x.evaluate_constants_(c, chals) - y.evaluate_constants_(c, chals), + Cache(id, e) => Cache(*id, Box::new(e.evaluate_constants_(c, chals))), IfFeature(feature, e1, e2) => IfFeature( *feature, - Box::new(e1.evaluate_constants_(c)), - Box::new(e2.evaluate_constants_(c)), + Box::new(e1.evaluate_constants_(c, chals)), + Box::new(e2.evaluate_constants_(c, chals)), ), } } /// Evaluate an 
expression as a field element against an environment. - pub fn evaluate( + pub fn evaluate< + 'a, + Evaluations: ColumnEvaluations, + Challenge: Index, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>, + >( &self, d: D, pt: F, - evals: &ProofEvaluations>, - env: &Environment, - ) -> Result { - self.evaluate_(d, pt, evals, &env.constants) + evals: &Evaluations, + env: &Environment, + ) -> Result> { + self.evaluate_(d, pt, evals, env.get_constants(), env.get_challenges()) } /// Evaluate an expression as a field element against the constants. - pub fn evaluate_( + pub fn evaluate_>( &self, d: D, pt: F, - evals: &ProofEvaluations>, + evals: &Evaluations, c: &Constants, - ) -> Result { - use Expr::*; + chals: &dyn Index, + ) -> Result> { + use ExprInner::*; + use Operations::*; match self { - Double(x) => x.evaluate_(d, pt, evals, c).map(|x| x.double()), - Constant(x) => Ok(x.value(c)), - Pow(x, p) => Ok(x.evaluate_(d, pt, evals, c)?.pow([*p])), - BinOp(Op2::Mul, x, y) => { - let x = (*x).evaluate_(d, pt, evals, c)?; - let y = (*y).evaluate_(d, pt, evals, c)?; + Double(x) => x.evaluate_(d, pt, evals, c, chals).map(|x| x.double()), + Atom(Constant(x)) => Ok(x.value(c, chals)), + Pow(x, p) => Ok(x.evaluate_(d, pt, evals, c, chals)?.pow([*p])), + Mul(x, y) => { + let x = (*x).evaluate_(d, pt, evals, c, chals)?; + let y = (*y).evaluate_(d, pt, evals, c, chals)?; Ok(x * y) } - Square(x) => Ok(x.evaluate_(d, pt, evals, c)?.square()), - BinOp(Op2::Add, x, y) => { - let x = (*x).evaluate_(d, pt, evals, c)?; - let y = (*y).evaluate_(d, pt, evals, c)?; + Square(x) => Ok(x.evaluate_(d, pt, evals, c, chals)?.square()), + Add(x, y) => { + let x = (*x).evaluate_(d, pt, evals, c, chals)?; + let y = (*y).evaluate_(d, pt, evals, c, chals)?; Ok(x + y) } - BinOp(Op2::Sub, x, y) => { - let x = (*x).evaluate_(d, pt, evals, c)?; - let y = (*y).evaluate_(d, pt, evals, c)?; + Sub(x, y) => { + let x = (*x).evaluate_(d, pt, evals, c, chals)?; + let y = (*y).evaluate_(d, pt, evals, c, chals)?; Ok(x - y) } - VanishesOnZeroKnowledgeAndPreviousRows => { + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => { Ok(eval_vanishes_on_last_n_rows(d, c.zk_rows + 1, pt)) } - UnnormalizedLagrangeBasis(i) => { + Atom(UnnormalizedLagrangeBasis(i)) => { let offset = if i.zk_rows { -(c.zk_rows as i32) + i.offset } else { @@ -1607,68 +1832,91 @@ impl Expr> { }; Ok(unnormalized_lagrange_basis(&d, offset, &pt)) } - Cell(v) => v.evaluate(evals), - Cache(_, e) => e.evaluate_(d, pt, evals, c), + Atom(Cell(v)) => v.evaluate(evals), + Cache(_, e) => e.evaluate_(d, pt, evals, c, chals), IfFeature(feature, e1, e2) => { if feature.is_enabled() { - e1.evaluate_(d, pt, evals, c) + e1.evaluate_(d, pt, evals, c, chals) } else { - e2.evaluate_(d, pt, evals, c) + e2.evaluate_(d, pt, evals, c, chals) } } } } /// Evaluate the constant expressions in this expression down into field elements. - pub fn evaluate_constants(&self, env: &Environment) -> Expr { - self.evaluate_constants_(&env.constants) + pub fn evaluate_constants< + 'a, + Challenge: Index, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>, + >( + &self, + env: &Environment, + ) -> Expr { + self.evaluate_constants_(env.get_constants(), env.get_challenges()) } /// Compute the polynomial corresponding to this expression, in evaluation form. 
- pub fn evaluations(&self, env: &Environment<'_, F>) -> Evaluations> { + /// The routine will first replace the constants (verifier challenges and + /// constants like the matrix used by `Poseidon`) in the expression with their + /// respective values using `evaluate_constants` and will then evaluate the + /// monomials with the corresponding column values using the method + /// `evaluations`. + pub fn evaluations< + 'a, + Challenge: Index, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>, + >( + &self, + env: &Environment, + ) -> Evaluations> { self.evaluate_constants(env).evaluations(env) } } +/// Used as the result of the expression evaluation routine. +/// For now, the left branch is the result of an evaluation and the right branch +/// is the ID of an element in the cache. enum Either { Left(A), Right(B), } -impl Expr { +impl Expr { /// Evaluate an expression into a field element. - pub fn evaluate( + pub fn evaluate>( &self, d: D, pt: F, zk_rows: u64, - evals: &ProofEvaluations>, - ) -> Result { - use Expr::*; + evals: &Evaluations, + ) -> Result> { + use ExprInner::*; + use Operations::*; match self { - Constant(x) => Ok(*x), + Atom(Constant(x)) => Ok(*x), Pow(x, p) => Ok(x.evaluate(d, pt, zk_rows, evals)?.pow([*p])), Double(x) => x.evaluate(d, pt, zk_rows, evals).map(|x| x.double()), Square(x) => x.evaluate(d, pt, zk_rows, evals).map(|x| x.square()), - BinOp(Op2::Mul, x, y) => { + Mul(x, y) => { let x = (*x).evaluate(d, pt, zk_rows, evals)?; let y = (*y).evaluate(d, pt, zk_rows, evals)?; Ok(x * y) } - BinOp(Op2::Add, x, y) => { + Add(x, y) => { let x = (*x).evaluate(d, pt, zk_rows, evals)?; let y = (*y).evaluate(d, pt, zk_rows, evals)?; Ok(x + y) } - BinOp(Op2::Sub, x, y) => { + Sub(x, y) => { let x = (*x).evaluate(d, pt, zk_rows, evals)?; let y = (*y).evaluate(d, pt, zk_rows, evals)?; Ok(x - y) } - VanishesOnZeroKnowledgeAndPreviousRows => { + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => { Ok(eval_vanishes_on_last_n_rows(d, zk_rows + 1, pt)) } - UnnormalizedLagrangeBasis(i) => { + Atom(UnnormalizedLagrangeBasis(i)) => { let offset = if i.zk_rows { -(zk_rows as i32) + i.offset } else { @@ -1676,7 +1924,7 @@ impl Expr { }; Ok(unnormalized_lagrange_basis(&d, offset, &pt)) } - Cell(v) => v.evaluate(evals), + Atom(Cell(v)) => v.evaluate(evals), Cache(_, e) => e.evaluate(d, pt, zk_rows, evals), IfFeature(feature, e1, e2) => { if feature.is_enabled() { @@ -1689,9 +1937,17 @@ impl Expr { } /// Compute the polynomial corresponding to this expression, in evaluation form. 
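+ /// A rough sketch of the degree-driven domain choice made in the body below
+ /// (`n` is the size of the `D1` domain; the exact cut-offs live in the code):
+ /// ```ignore
+ /// // deg <= n      -> evaluate over D1
+ /// // deg <= 4 * n  -> evaluate over D4
+ /// // deg <= 8 * n  -> evaluate over D8
+ /// // otherwise     -> the expression cannot be evaluated here
+ /// ```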
- pub fn evaluations(&self, env: &Environment<'_, F>) -> Evaluations> { - let d1_size = env.domain.d1.size; - let deg = self.degree(d1_size, env.constants.zk_rows); + pub fn evaluations< + 'a, + ChallengeTerm, + Challenge: Index, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>, + >( + &self, + env: &Environment, + ) -> Evaluations> { + let d1_size = env.get_domain(Domain::D1).size; + let deg = self.degree(d1_size, env.get_constants().zk_rows); let d = if deg <= d1_size { Domain::D1 } else if deg <= 4 * d1_size { @@ -1714,15 +1970,20 @@ impl Expr { assert_eq!(domain, d); evals } - EvalResult::Constant(x) => EvalResult::init_((d, get_domain(d, env)), |_| x), + EvalResult::Constant(x) => EvalResult::init_((d, env.get_domain(d)), |_| x), EvalResult::SubEvals { evals, domain: d_sub, shift: s, } => { - let res_domain = get_domain(d, env); + let res_domain = env.get_domain(d); let scale = (d_sub as usize) / (d as usize); - assert!(scale != 0); + assert!( + scale != 0, + "Check that the implementation of + column_domain and the evaluation domain of the + witnesses are the same" + ); EvalResult::init_((d, res_domain), |i| { evals.evals[(scale * i + (d_sub as usize) * s) % evals.evals.len()] }) @@ -1730,16 +1991,22 @@ impl Expr { } } - fn evaluations_helper<'a, 'b>( + fn evaluations_helper< + 'a, + 'b, + ChallengeTerm, + Challenge: Index, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>, + >( &self, cache: &'b mut HashMap>, d: Domain, - env: &Environment<'a, F>, + env: &Environment, ) -> Either, CacheId> where 'a: 'b, { - let dom = (d, get_domain(d, env)); + let dom = (d, env.get_domain(d)); let res: EvalResult<'a, F> = match self { Expr::Square(x) => match x.evaluations_helper(cache, d, env) { @@ -1801,30 +2068,30 @@ impl Expr { Expr::Pow(x, p) => { let x = x.evaluations_helper(cache, d, env); match x { - Either::Left(x) => x.pow(*p, (d, get_domain(d, env))), + Either::Left(x) => x.pow(*p, (d, env.get_domain(d))), Either::Right(id) => { - id.get_from(cache).unwrap().pow(*p, (d, get_domain(d, env))) + id.get_from(cache).unwrap().pow(*p, (d, env.get_domain(d))) } } } - Expr::VanishesOnZeroKnowledgeAndPreviousRows => EvalResult::SubEvals { + Expr::Atom(ExprInner::VanishesOnZeroKnowledgeAndPreviousRows) => EvalResult::SubEvals { domain: Domain::D8, shift: 0, - evals: env.vanishes_on_zero_knowledge_and_previous_rows, + evals: env.vanishes_on_zero_knowledge_and_previous_rows(), }, - Expr::Constant(x) => EvalResult::Constant(*x), - Expr::UnnormalizedLagrangeBasis(i) => { + Expr::Atom(ExprInner::Constant(x)) => EvalResult::Constant(*x), + Expr::Atom(ExprInner::UnnormalizedLagrangeBasis(i)) => { let offset = if i.zk_rows { - -(env.constants.zk_rows as i32) + i.offset + -(env.get_constants().zk_rows as i32) + i.offset } else { i.offset }; EvalResult::Evals { domain: d, - evals: unnormalized_lagrange_evals(env.l0_1, offset, d, env), + evals: unnormalized_lagrange_evals(env.l0_1(), offset, d, env), } } - Expr::Cell(Variable { col, row }) => { + Expr::Atom(ExprInner::Cell(Variable { col, row })) => { let evals: &'a Evaluations> = { match env.get_column(col) { None => return Either::Left(EvalResult::Constant(F::zero())), @@ -1832,18 +2099,44 @@ impl Expr { } }; EvalResult::SubEvals { - domain: col.domain(), + domain: env.column_domain(col), shift: row.shift(), evals, } } - Expr::BinOp(op, e1, e2) => { - let dom = (d, get_domain(d, env)); - let f = |x: EvalResult, y: EvalResult| match op { - Op2::Mul => x.mul(y, dom), - Op2::Add => x.add(y, 
dom), - Op2::Sub => x.sub(y, dom), - }; + Expr::Add(e1, e2) => { + let dom = (d, env.get_domain(d)); + let f = |x: EvalResult, y: EvalResult| x.add(y, dom); + let e1 = e1.evaluations_helper(cache, d, env); + let e2 = e2.evaluations_helper(cache, d, env); + use Either::*; + match (e1, e2) { + (Left(e1), Left(e2)) => f(e1, e2), + (Right(id1), Left(e2)) => f(id1.get_from(cache).unwrap(), e2), + (Left(e1), Right(id2)) => f(e1, id2.get_from(cache).unwrap()), + (Right(id1), Right(id2)) => { + f(id1.get_from(cache).unwrap(), id2.get_from(cache).unwrap()) + } + } + } + Expr::Sub(e1, e2) => { + let dom = (d, env.get_domain(d)); + let f = |x: EvalResult, y: EvalResult| x.sub(y, dom); + let e1 = e1.evaluations_helper(cache, d, env); + let e2 = e2.evaluations_helper(cache, d, env); + use Either::*; + match (e1, e2) { + (Left(e1), Left(e2)) => f(e1, e2), + (Right(id1), Left(e2)) => f(id1.get_from(cache).unwrap(), e2), + (Left(e1), Right(id2)) => f(e1, id2.get_from(cache).unwrap()), + (Right(id1), Right(id2)) => { + f(id1.get_from(cache).unwrap(), id2.get_from(cache).unwrap()) + } + } + } + Expr::Mul(e1, e2) => { + let dom = (d, env.get_domain(d)); + let f = |x: EvalResult, y: EvalResult| x.mul(y, dom); let e1 = e1.evaluations_helper(cache, d, env); let e2 = e2.evaluations_helper(cache, d, env); use Either::*; @@ -1874,12 +2167,12 @@ impl Expr { #[derive(Clone, Debug, Serialize, Deserialize)] /// A "linearization", which is linear combination with `E` coefficients of /// columns. -pub struct Linearization { +pub struct Linearization { pub constant_term: E, pub index_terms: Vec<(Column, E)>, } -impl Default for Linearization { +impl Default for Linearization { fn default() -> Self { Linearization { constant_term: E::default(), @@ -1888,9 +2181,9 @@ impl Default for Linearization { } } -impl Linearization { +impl Linearization { /// Apply a function to all the coefficients in the linearization. - pub fn map B>(&self, f: F) -> Linearization { + pub fn map B>(&self, f: F) -> Linearization { Linearization { constant_term: f(&self.constant_term), index_terms: self.index_terms.iter().map(|(c, x)| (*c, f(x))).collect(), @@ -1898,28 +2191,46 @@ impl Linearization { } } -impl Linearization>> { +impl + Linearization, Column>, Column> +{ /// Evaluate the constants in a linearization with `ConstantExpr` coefficients down /// to literal field elements. - pub fn evaluate_constants(&self, env: &Environment) -> Linearization> { + pub fn evaluate_constants< + 'a, + Challenge: Index, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>, + >( + &self, + env: &Environment, + ) -> Linearization, Column> { self.map(|e| e.evaluate_constants(env)) } } -impl Linearization>> { +impl + Linearization>, Column> +{ /// Given a linearization and an environment, compute the polynomial corresponding to the /// linearization, in evaluation form. 
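+ /// A hedged sketch of what is computed (column and coefficient names are
+ /// illustrative): for a linearization `f = constant_term + sum_i c_i * column_i`,
+ /// this returns the pair
+ /// ```ignore
+ /// // ( constant_term(pt), sum_i c_i(pt) * column_i )
+ /// ```
+ /// with the right-hand component kept in evaluation form over `D1`.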
- pub fn to_polynomial( + pub fn to_polynomial< + 'a, + Challenge: Index, + ColEvaluations: ColumnEvaluations, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>, + >( &self, - env: &Environment, + env: &Environment, pt: F, - evals: &ProofEvaluations>, + evals: &ColEvaluations, ) -> (F, Evaluations>) { - let cs = &env.constants; - let n = env.domain.d1.size(); + let cs = env.get_constants(); + let chals = env.get_challenges(); + let d1 = env.get_domain(Domain::D1); + let n = d1.size(); let mut res = vec![F::zero(); n]; self.index_terms.iter().for_each(|(idx, c)| { - let c = PolishToken::evaluate(c, env.domain.d1, pt, evals, cs).unwrap(); + let c = PolishToken::evaluate(c, d1, pt, evals, cs, chals).unwrap(); let e = env .get_column(idx) .unwrap_or_else(|| panic!("Index polynomial {idx:?} not found")); @@ -1928,28 +2239,37 @@ impl Linearization>> { .enumerate() .for_each(|(i, r)| *r += c * e.evals[scale * i]); }); - let p = Evaluations::>::from_vec_and_domain(res, env.domain.d1); + let p = Evaluations::>::from_vec_and_domain(res, d1); ( - PolishToken::evaluate(&self.constant_term, env.domain.d1, pt, evals, cs).unwrap(), + PolishToken::evaluate(&self.constant_term, d1, pt, evals, cs, chals).unwrap(), p, ) } } -impl Linearization>> { +impl + Linearization, Column>, Column> +{ /// Given a linearization and an environment, compute the polynomial corresponding to the /// linearization, in evaluation form. - pub fn to_polynomial( + pub fn to_polynomial< + 'a, + Challenge: Index, + ColEvaluations: ColumnEvaluations, + Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>, + >( &self, - env: &Environment, + env: &Environment, pt: F, - evals: &ProofEvaluations>, + evals: &ColEvaluations, ) -> (F, DensePolynomial) { - let cs = &env.constants; - let n = env.domain.d1.size(); + let cs = env.get_constants(); + let chals = env.get_challenges(); + let d1 = env.get_domain(Domain::D1); + let n = d1.size(); let mut res = vec![F::zero(); n]; self.index_terms.iter().for_each(|(idx, c)| { - let c = c.evaluate_(env.domain.d1, pt, evals, cs).unwrap(); + let c = c.evaluate_(d1, pt, evals, cs, chals).unwrap(); let e = env .get_column(idx) .unwrap_or_else(|| panic!("Index polynomial {idx:?} not found")); @@ -1958,77 +2278,81 @@ impl Linearization>> { .enumerate() .for_each(|(i, r)| *r += c * e.evals[scale * i]) }); - let p = Evaluations::>::from_vec_and_domain(res, env.domain.d1).interpolate(); + let p = Evaluations::>::from_vec_and_domain(res, d1).interpolate(); ( self.constant_term - .evaluate_(env.domain.d1, pt, evals, cs) + .evaluate_(d1, pt, evals, cs, chals) .unwrap(), p, ) } } -impl Expr { - /// Exponentiate an expression - #[must_use] - pub fn pow(self, p: u64) -> Self { - use Expr::*; - if p == 0 { - return Constant(F::one()); - } - Pow(Box::new(self), p) - } -} - -type Monomials = HashMap, Expr>; +type Monomials = HashMap>, Expr>; -fn mul_monomials + Clone + One + Zero + PartialEq>( - e1: &Monomials, - e2: &Monomials, -) -> Monomials { - let mut res: HashMap<_, Expr> = HashMap::new(); +fn mul_monomials< + F: Neg + Clone + One + Zero + PartialEq, + Column: Ord + Copy + std::hash::Hash, +>( + e1: &Monomials, + e2: &Monomials, +) -> Monomials +where + ExprInner: Literal, + as Literal>::F: Field, +{ + let mut res: HashMap<_, Expr> = HashMap::new(); for (m1, c1) in e1.iter() { for (m2, c2) in e2.iter() { let mut m = m1.clone(); m.extend(m2); m.sort(); let c1c2 = c1.clone() * c2.clone(); - let v = res.entry(m).or_insert_with(Expr::::zero); + let v = 
res.entry(m).or_insert_with(Expr::::zero); *v = v.clone() + c1c2; } } res } -impl + Clone + One + Zero + PartialEq> Expr { +impl + Clone + One + Zero + PartialEq, Column: Ord + Copy + std::hash::Hash> + Expr +where + ExprInner: Literal, + as Literal>::F: Field, +{ // TODO: This function (which takes linear time) // is called repeatedly in monomials, yielding quadratic behavior for // that function. It's ok for now as we only call that function once on // a small input when producing the verification key. fn is_constant(&self, evaluated: &HashSet) -> bool { - use Expr::*; + use ExprInner::*; + use Operations::*; match self { Pow(x, _) => x.is_constant(evaluated), Square(x) => x.is_constant(evaluated), - Constant(_) => true, - Cell(v) => evaluated.contains(&v.col), + Atom(Constant(_)) => true, + Atom(Cell(v)) => evaluated.contains(&v.col), Double(x) => x.is_constant(evaluated), - BinOp(_, x, y) => x.is_constant(evaluated) && y.is_constant(evaluated), - VanishesOnZeroKnowledgeAndPreviousRows => true, - UnnormalizedLagrangeBasis(_) => true, + Add(x, y) | Sub(x, y) | Mul(x, y) => { + x.is_constant(evaluated) && y.is_constant(evaluated) + } + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => true, + Atom(UnnormalizedLagrangeBasis(_)) => true, Cache(_, x) => x.is_constant(evaluated), IfFeature(_, e1, e2) => e1.is_constant(evaluated) && e2.is_constant(evaluated), } } - fn monomials(&self, ev: &HashSet) -> HashMap, Expr> { - let sing = |v: Vec, c: Expr| { + fn monomials(&self, ev: &HashSet) -> HashMap>, Expr> { + let sing = |v: Vec>, c: Expr| { let mut h = HashMap::new(); h.insert(v, c); h }; - let constant = |e: Expr| sing(vec![], e); - use Expr::*; + let constant = |e: Expr| sing(vec![], e); + use ExprInner::*; + use Operations::*; if self.is_constant(ev) { return constant(self.clone()); @@ -2037,7 +2361,7 @@ impl + Clone + One + Zero + PartialEq> Expr { match self { Pow(x, d) => { // Run the multiplication logic with square and multiply - let mut acc = sing(vec![], Expr::::one()); + let mut acc = sing(vec![], Expr::::one()); let mut acc_is_one = true; let x = x.monomials(ev); @@ -2059,13 +2383,13 @@ impl + Clone + One + Zero + PartialEq> Expr { HashMap::from_iter(e.monomials(ev).into_iter().map(|(m, c)| (m, c.double()))) } Cache(_, e) => e.monomials(ev), - UnnormalizedLagrangeBasis(i) => constant(UnnormalizedLagrangeBasis(*i)), - VanishesOnZeroKnowledgeAndPreviousRows => { - constant(VanishesOnZeroKnowledgeAndPreviousRows) + Atom(UnnormalizedLagrangeBasis(i)) => constant(Atom(UnnormalizedLagrangeBasis(*i))), + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => { + constant(Atom(VanishesOnZeroKnowledgeAndPreviousRows)) } - Constant(c) => constant(Constant(c.clone())), - Cell(var) => sing(vec![*var], Constant(F::one())), - BinOp(Op2::Add, e1, e2) => { + Atom(Constant(c)) => constant(Atom(Constant(c.clone()))), + Atom(Cell(var)) => sing(vec![*var], Atom(Constant(F::one()))), + Add(e1, e2) => { let mut res = e1.monomials(ev); for (m, c) in e2.monomials(ev) { let v = match res.remove(&m) { @@ -2076,7 +2400,7 @@ impl + Clone + One + Zero + PartialEq> Expr { } res } - BinOp(Op2::Sub, e1, e2) => { + Sub(e1, e2) => { let mut res = e1.monomials(ev); for (m, c) in e2.monomials(ev) { let v = match res.remove(&m) { @@ -2087,7 +2411,7 @@ impl + Clone + One + Zero + PartialEq> Expr { } res } - BinOp(Op2::Mul, e1, e2) => { + Mul(e1, e2) => { let e1 = e1.monomials(ev); let e2 = e2.monomials(ev); mul_monomials(&e1, &e2) @@ -2132,7 +2456,7 @@ impl + Clone + One + Zero + PartialEq> Expr { /// are the same thing as 
polynomials with `F[V_1]` coefficients in variables `V_0`. /// /// There is also a function - /// `lin_or_err : (F[V_1])[V_0] -> Result, &str>` + /// `lin_or_err : (F[V_1])[V_0] -> Result, &str>` /// /// which checks if the given input is in fact a degree 1 polynomial in the variables `V_0` /// (i.e., a linear combination of `V_0` elements with `F[V_1]` coefficients) @@ -2146,15 +2470,17 @@ impl + Clone + One + Zero + PartialEq> Expr { pub fn linearize( &self, evaluated: HashSet, - ) -> Result>, ExprError> { - let mut res: HashMap> = HashMap::new(); - let mut constant_term: Expr = Self::zero(); + ) -> Result, Column>, ExprError> { + let mut res: HashMap> = HashMap::new(); + let mut constant_term: Expr = Self::zero(); let monomials = self.monomials(&evaluated); for (m, c) in monomials { let (evaluated, mut unevaluated): (Vec<_>, _) = m.into_iter().partition(|v| evaluated.contains(&v.col)); - let c = evaluated.into_iter().fold(c, |acc, v| acc * Expr::Cell(v)); + let c = evaluated + .into_iter() + .fold(c, |acc, v| acc * Expr::Atom(ExprInner::Cell(v))); if unevaluated.is_empty() { constant_term += c; } else if unevaluated.len() == 1 { @@ -2195,134 +2521,59 @@ impl + Clone + One + Zero + PartialEq> Expr { // Trait implementations -impl Zero for ConstantExpr { - fn zero() -> Self { - ConstantExpr::Literal(F::zero()) - } - - fn is_zero(&self) -> bool { - match self { - ConstantExpr::Literal(x) => x.is_zero(), - _ => false, - } - } -} - -impl One for ConstantExpr { - fn one() -> Self { - ConstantExpr::Literal(F::one()) - } - - fn is_one(&self) -> bool { - match self { - ConstantExpr::Literal(x) => x.is_one(), - _ => false, - } - } -} - -impl> Neg for ConstantExpr { - type Output = ConstantExpr; - - fn neg(self) -> ConstantExpr { - match self { - ConstantExpr::Literal(x) => ConstantExpr::Literal(x.neg()), - e => ConstantExpr::Mul(Box::new(ConstantExpr::Literal(F::one().neg())), Box::new(e)), - } - } -} - -impl Add> for ConstantExpr { - type Output = ConstantExpr; - fn add(self, other: Self) -> Self { - use ConstantExpr::{Add, Literal}; - if self.is_zero() { - return other; - } - if other.is_zero() { - return self; - } - match (self, other) { - (Literal(x), Literal(y)) => Literal(x + y), - (x, y) => Add(Box::new(x), Box::new(y)), - } - } -} - -impl Sub> for ConstantExpr { - type Output = ConstantExpr; - fn sub(self, other: Self) -> Self { - use ConstantExpr::{Literal, Sub}; - if other.is_zero() { - return self; - } - match (self, other) { - (Literal(x), Literal(y)) => Literal(x - y), - (x, y) => Sub(Box::new(x), Box::new(y)), - } - } -} - -impl Mul> for ConstantExpr { - type Output = ConstantExpr; - fn mul(self, other: Self) -> Self { - use ConstantExpr::{Literal, Mul}; - if self.is_one() { - return other; - } - if other.is_one() { - return self; - } - match (self, other) { - (Literal(x), Literal(y)) => Literal(x * y), - (x, y) => Mul(Box::new(x), Box::new(y)), - } - } -} - -impl Zero for Expr { +impl Zero for Operations +where + T::F: Field, +{ fn zero() -> Self { - Expr::Constant(F::zero()) + Self::literal(T::F::zero()) } fn is_zero(&self) -> bool { - match self { - Expr::Constant(x) => x.is_zero(), - _ => false, + if let Some(x) = self.to_literal_ref() { + x.is_zero() + } else { + false } } } -impl One for Expr { +impl One for Operations +where + T::F: Field, +{ fn one() -> Self { - Expr::Constant(F::one()) + Self::literal(T::F::one()) } fn is_one(&self) -> bool { - match self { - Expr::Constant(x) => x.is_one(), - _ => false, + if let Some(x) = self.to_literal_ref() { + x.is_one() + } else { + 
false } } } -impl> Neg for Expr { - type Output = Expr; +impl Neg for Operations +where + T::F: One + Neg + Copy, +{ + type Output = Self; - fn neg(self) -> Expr { - match self { - Expr::Constant(x) => Expr::Constant(x.neg()), - e => Expr::BinOp( - Op2::Mul, - Box::new(Expr::Constant(F::one().neg())), - Box::new(e), - ), + fn neg(self) -> Self { + match self.to_literal() { + Ok(x) => Self::literal(x.neg()), + Err(x) => Operations::Mul(Box::new(Self::literal(T::F::one().neg())), Box::new(x)), } } } -impl Add> for Expr { - type Output = Expr; +impl Add for Operations +where + T::F: Field, +{ + type Output = Self; fn add(self, other: Self) -> Self { if self.is_zero() { return other; @@ -2330,22 +2581,44 @@ impl Add> for Expr { if other.is_zero() { return self; } - Expr::BinOp(Op2::Add, Box::new(self), Box::new(other)) + let (x, y) = { + match (self.to_literal(), other.to_literal()) { + (Ok(x), Ok(y)) => return Self::literal(x + y), + (Ok(x), Err(y)) => (Self::literal(x), y), + (Err(x), Ok(y)) => (x, Self::literal(y)), + (Err(x), Err(y)) => (x, y), + } + }; + Operations::Add(Box::new(x), Box::new(y)) } -} - -impl AddAssign> for Expr { - fn add_assign(&mut self, other: Self) { - if self.is_zero() { - *self = other; - } else if !other.is_zero() { - *self = Expr::BinOp(Op2::Add, Box::new(self.clone()), Box::new(other)); +} + +impl Sub for Operations +where + T::F: Field, +{ + type Output = Self; + fn sub(self, other: Self) -> Self { + if other.is_zero() { + return self; } + let (x, y) = { + match (self.to_literal(), other.to_literal()) { + (Ok(x), Ok(y)) => return Self::literal(x - y), + (Ok(x), Err(y)) => (Self::literal(x), y), + (Err(x), Ok(y)) => (x, Self::literal(y)), + (Err(x), Err(y)) => (x, y), + } + }; + Operations::Sub(Box::new(x), Box::new(y)) } } -impl Mul> for Expr { - type Output = Expr; +impl Mul for Operations +where + T::F: Field, +{ + type Output = Self; fn mul(self, other: Self) -> Self { if self.is_zero() || other.is_zero() { return Self::zero(); @@ -2357,13 +2630,38 @@ impl Mul> for Expr { if other.is_one() { return self; } - Expr::BinOp(Op2::Mul, Box::new(self), Box::new(other)) + let (x, y) = { + match (self.to_literal(), other.to_literal()) { + (Ok(x), Ok(y)) => return Self::literal(x * y), + (Ok(x), Err(y)) => (Self::literal(x), y), + (Err(x), Ok(y)) => (x, Self::literal(y)), + (Err(x), Err(y)) => (x, y), + } + }; + Operations::Mul(Box::new(x), Box::new(y)) + } +} + +impl AddAssign> for Expr +where + ExprInner: Literal, + as Literal>::F: Field, +{ + fn add_assign(&mut self, other: Self) { + if self.is_zero() { + *self = other; + } else if !other.is_zero() { + *self = Expr::Add(Box::new(self.clone()), Box::new(other)); + } } } -impl MulAssign> for Expr +impl MulAssign> for Expr where F: Zero + One + PartialEq + Clone, + Column: PartialEq + Clone, + ExprInner: Literal, + as Literal>::F: Field, { fn mul_assign(&mut self, other: Self) { if self.is_zero() || other.is_zero() { @@ -2371,44 +2669,38 @@ where } else if self.is_one() { *self = other; } else if !other.is_one() { - *self = Expr::BinOp(Op2::Mul, Box::new(self.clone()), Box::new(other)); - } - } -} - -impl Sub> for Expr { - type Output = Expr; - fn sub(self, other: Self) -> Self { - if other.is_zero() { - return self; + *self = Expr::Mul(Box::new(self.clone()), Box::new(other)); } - Expr::BinOp(Op2::Sub, Box::new(self), Box::new(other)) } } -impl From for Expr { +impl From for Expr { fn from(x: u64) -> Self { - Expr::Constant(F::from(x)) + Expr::Atom(ExprInner::Constant(F::from(x))) } } -impl From for Expr> { +impl<'a, 
F: Field, Column, ChallengeTerm: AlphaChallengeTerm<'a>> From + for Expr, Column> +{ fn from(x: u64) -> Self { - Expr::Constant(ConstantExpr::Literal(F::from(x))) + ConstantTerm::Literal(F::from(x)).into() } } -impl From for ConstantExpr { +impl From for ConstantExpr { fn from(x: u64) -> Self { - ConstantExpr::Literal(F::from(x)) + ConstantTerm::Literal(F::from(x)).into() } } -impl Mul for Expr> { - type Output = Expr>; +impl<'a, F: Field, Column: PartialEq + Copy, ChallengeTerm: AlphaChallengeTerm<'a>> Mul + for Expr, Column> +{ + type Output = Expr, Column>; fn mul(self, y: F) -> Self::Output { - Expr::Constant(ConstantExpr::Literal(y)) * self + Expr::from(ConstantTerm::Literal(y)) * self } } @@ -2416,116 +2708,309 @@ impl Mul for Expr> { // Display // -impl ConstantExpr +pub trait FormattedOutput: Sized { + fn is_alpha(&self) -> bool; + fn ocaml(&self, cache: &mut HashMap) -> String; + fn latex(&self, cache: &mut HashMap) -> String; + fn text(&self, cache: &mut HashMap) -> String; +} + +impl<'a, ChallengeTerm> FormattedOutput for ChallengeTerm where - F: PrimeField, + ChallengeTerm: AlphaChallengeTerm<'a>, { - fn ocaml(&self) -> String { - use ConstantExpr::*; + fn is_alpha(&self) -> bool { + self.eq(&ChallengeTerm::ALPHA) + } + fn ocaml(&self, _cache: &mut HashMap) -> String { + self.to_string() + } + + fn latex(&self, _cache: &mut HashMap) -> String { + "\\".to_string() + &self.to_string() + } + + fn text(&self, _cache: &mut HashMap) -> String { + self.to_string() + } +} + +impl FormattedOutput for ConstantTerm { + fn is_alpha(&self) -> bool { + false + } + fn ocaml(&self, _cache: &mut HashMap) -> String { + use ConstantTerm::*; match self { - Alpha => "alpha".to_string(), - Beta => "beta".to_string(), - Gamma => "gamma".to_string(), - JointCombiner => "joint_combiner".to_string(), EndoCoefficient => "endo_coefficient".to_string(), Mds { row, col } => format!("mds({row}, {col})"), Literal(x) => format!( "field(\"{:#066X}\")", Into::::into(x.into_bigint()) ), - Pow(x, n) => match x.as_ref() { - Alpha => format!("alpha_pow({n})"), - x => format!("pow({}, {n})", x.ocaml()), - }, - Add(x, y) => format!("({} + {})", x.ocaml(), y.ocaml()), - Mul(x, y) => format!("({} * {})", x.ocaml(), y.ocaml()), - Sub(x, y) => format!("({} - {})", x.ocaml(), y.ocaml()), } } - fn latex(&self) -> String { - use ConstantExpr::*; + fn latex(&self, _cache: &mut HashMap) -> String { + use ConstantTerm::*; match self { - Alpha => "\\alpha".to_string(), - Beta => "\\beta".to_string(), - Gamma => "\\gamma".to_string(), - JointCombiner => "joint\\_combiner".to_string(), EndoCoefficient => "endo\\_coefficient".to_string(), Mds { row, col } => format!("mds({row}, {col})"), Literal(x) => format!("\\mathbb{{F}}({})", x.into_bigint().into()), - Pow(x, n) => match x.as_ref() { - Alpha => format!("\\alpha^{{{n}}}"), - x => format!("{}^{n}", x.ocaml()), - }, - Add(x, y) => format!("({} + {})", x.ocaml(), y.ocaml()), - Mul(x, y) => format!("({} \\cdot {})", x.ocaml(), y.ocaml()), - Sub(x, y) => format!("({} - {})", x.ocaml(), y.ocaml()), } } - fn text(&self) -> String { - use ConstantExpr::*; + fn text(&self, _cache: &mut HashMap) -> String { + use ConstantTerm::*; match self { - Alpha => "alpha".to_string(), - Beta => "beta".to_string(), - Gamma => "gamma".to_string(), - JointCombiner => "joint_combiner".to_string(), EndoCoefficient => "endo_coefficient".to_string(), Mds { row, col } => format!("mds({row}, {col})"), Literal(x) => format!("0x{}", x.to_hex()), - Pow(x, n) => match x.as_ref() { - Alpha => 
format!("alpha^{n}"), - x => format!("{}^{n}", x.text()), - }, - Add(x, y) => format!("({} + {})", x.text(), y.text()), - Mul(x, y) => format!("({} * {})", x.text(), y.text()), - Sub(x, y) => format!("({} - {})", x.text(), y.text()), } } } -impl Expr> +impl<'a, F: PrimeField, ChallengeTerm> FormattedOutput for ConstantExprInner where - F: PrimeField, + ChallengeTerm: AlphaChallengeTerm<'a>, { - /// Converts the expression in OCaml code - pub fn ocaml_str(&self) -> String { - let mut env = HashMap::new(); - let e = self.ocaml(&mut env); + fn is_alpha(&self) -> bool { + use ConstantExprInner::*; + match self { + Challenge(x) => x.is_alpha(), + Constant(x) => x.is_alpha(), + } + } + fn ocaml(&self, cache: &mut HashMap) -> String { + use ConstantExprInner::*; + match self { + Challenge(x) => { + let mut inner_cache = HashMap::new(); + let res = x.ocaml(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Challenge(v)); + }); + res + } + Constant(x) => { + let mut inner_cache = HashMap::new(); + let res = x.ocaml(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Constant(v)); + }); + res + } + } + } + fn latex(&self, cache: &mut HashMap) -> String { + use ConstantExprInner::*; + match self { + Challenge(x) => { + let mut inner_cache = HashMap::new(); + let res = x.latex(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Challenge(v)); + }); + res + } + Constant(x) => { + let mut inner_cache = HashMap::new(); + let res = x.latex(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Constant(v)); + }); + res + } + } + } + fn text(&self, cache: &mut HashMap) -> String { + use ConstantExprInner::*; + match self { + Challenge(x) => { + let mut inner_cache = HashMap::new(); + let res = x.text(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Challenge(v)); + }); + res + } + Constant(x) => { + let mut inner_cache = HashMap::new(); + let res = x.text(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Constant(v)); + }); + res + } + } + } +} - let mut env: Vec<_> = env.into_iter().collect(); - // HashMap deliberately uses an unstable order; here we sort to ensure that the output is - // consistent when printing. 
- env.sort_by(|(x, _), (y, _)| x.cmp(y)); +impl FormattedOutput for Variable { + fn is_alpha(&self) -> bool { + false + } - let mut res = String::new(); - for (k, v) in env { - let rhs = v.ocaml_str(); - let cached = format!("let {} = {rhs} in ", k.var_name()); - res.push_str(&cached); + fn ocaml(&self, _cache: &mut HashMap) -> String { + format!("var({:?}, {:?})", self.col, self.row) + } + + fn latex(&self, _cache: &mut HashMap) -> String { + let col = self.col.latex(&mut HashMap::new()); + match self.row { + Curr => col, + Next => format!("\\tilde{{{col}}}"), } + } - res.push_str(&e); - res + fn text(&self, _cache: &mut HashMap) -> String { + let col = self.col.text(&mut HashMap::new()); + match self.row { + Curr => format!("Curr({col})"), + Next => format!("Next({col})"), + } + } +} + +impl FormattedOutput for Operations { + fn is_alpha(&self) -> bool { + match self { + Operations::Atom(x) => x.is_alpha(), + _ => false, + } + } + fn ocaml(&self, cache: &mut HashMap) -> String { + use Operations::*; + match self { + Atom(x) => { + let mut inner_cache = HashMap::new(); + let res = x.ocaml(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Atom(v)); + }); + res + } + Pow(x, n) => { + if x.is_alpha() { + format!("alpha_pow({n})") + } else { + format!("pow({}, {n})", x.ocaml(cache)) + } + } + Add(x, y) => format!("({} + {})", x.ocaml(cache), y.ocaml(cache)), + Mul(x, y) => format!("({} * {})", x.ocaml(cache), y.ocaml(cache)), + Sub(x, y) => format!("({} - {})", x.ocaml(cache), y.ocaml(cache)), + Double(x) => format!("double({})", x.ocaml(cache)), + Square(x) => format!("square({})", x.ocaml(cache)), + Cache(id, e) => { + cache.insert(*id, e.as_ref().clone()); + id.var_name() + } + IfFeature(feature, e1, e2) => { + format!( + "if_feature({:?}, (fun () -> {}), (fun () -> {}))", + feature, + e1.ocaml(cache), + e2.ocaml(cache) + ) + } + } + } + + fn latex(&self, cache: &mut HashMap) -> String { + use Operations::*; + match self { + Atom(x) => { + let mut inner_cache = HashMap::new(); + let res = x.latex(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Atom(v)); + }); + res + } + Pow(x, n) => format!("{}^{{{n}}}", x.latex(cache)), + Add(x, y) => format!("({} + {})", x.latex(cache), y.latex(cache)), + Mul(x, y) => format!("({} \\cdot {})", x.latex(cache), y.latex(cache)), + Sub(x, y) => format!("({} - {})", x.latex(cache), y.latex(cache)), + Double(x) => format!("2 ({})", x.latex(cache)), + Square(x) => format!("({})^2", x.latex(cache)), + Cache(id, e) => { + cache.insert(*id, e.as_ref().clone()); + id.var_name() + } + IfFeature(feature, _, _) => format!("{feature:?}"), + } + } + + fn text(&self, cache: &mut HashMap) -> String { + use Operations::*; + match self { + Atom(x) => { + let mut inner_cache = HashMap::new(); + let res = x.text(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Atom(v)); + }); + res + } + Pow(x, n) => format!("{}^{n}", x.text(cache)), + Add(x, y) => format!("({} + {})", x.text(cache), y.text(cache)), + Mul(x, y) => format!("({} * {})", x.text(cache), y.text(cache)), + Sub(x, y) => format!("({} - {})", x.text(cache), y.text(cache)), + Double(x) => format!("double({})", x.text(cache)), + Square(x) => format!("square({})", x.text(cache)), + Cache(id, e) => { + cache.insert(*id, e.as_ref().clone()); + id.var_name() + } + IfFeature(feature, _, _) => format!("{feature:?}"), + } } +} +impl<'a, F, Column: FormattedOutput + Debug + Clone, ChallengeTerm> 
FormattedOutput + for Expr, Column> +where + F: PrimeField, + ChallengeTerm: AlphaChallengeTerm<'a>, +{ + fn is_alpha(&self) -> bool { + use ExprInner::*; + use Operations::*; + match self { + Atom(Constant(x)) => x.is_alpha(), + _ => false, + } + } + /// Converts the expression in OCaml code /// Recursively print the expression, /// except for the cached expression that are stored in the `cache`. - fn ocaml(&self, cache: &mut HashMap>>) -> String { - use Expr::*; + fn ocaml( + &self, + cache: &mut HashMap, Column>>, + ) -> String { + use ExprInner::*; + use Operations::*; match self { Double(x) => format!("double({})", x.ocaml(cache)), - Constant(x) => x.ocaml(), - Cell(v) => format!("cell({})", v.ocaml()), - UnnormalizedLagrangeBasis(i) => { + Atom(Constant(x)) => { + let mut inner_cache = HashMap::new(); + let res = x.ocaml(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Atom(Constant(v))); + }); + res + } + Atom(Cell(v)) => format!("cell({})", v.ocaml(&mut HashMap::new())), + Atom(UnnormalizedLagrangeBasis(i)) => { format!("unnormalized_lagrange_basis({}, {})", i.zk_rows, i.offset) } - VanishesOnZeroKnowledgeAndPreviousRows => { + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => { "vanishes_on_zero_knowledge_and_previous_rows".to_string() } - BinOp(Op2::Add, x, y) => format!("({} + {})", x.ocaml(cache), y.ocaml(cache)), - BinOp(Op2::Mul, x, y) => format!("({} * {})", x.ocaml(cache), y.ocaml(cache)), - BinOp(Op2::Sub, x, y) => format!("({} - {})", x.ocaml(cache), y.ocaml(cache)), + Add(x, y) => format!("({} + {})", x.ocaml(cache), y.ocaml(cache)), + Mul(x, y) => format!("({} * {})", x.ocaml(cache), y.ocaml(cache)), + Sub(x, y) => format!("({} - {})", x.ocaml(cache), y.ocaml(cache)), Pow(x, d) => format!("pow({}, {d})", x.ocaml(cache)), Square(x) => format!("square({})", x.ocaml(cache)), Cache(id, e) => { @@ -2543,51 +3028,41 @@ where } } - /// Converts the expression in LaTeX - pub fn latex_str(&self) -> Vec { - let mut env = HashMap::new(); - let e = self.latex(&mut env); - - let mut env: Vec<_> = env.into_iter().collect(); - // HashMap deliberately uses an unstable order; here we sort to ensure that the output is - // consistent when printing. 
- env.sort_by(|(x, _), (y, _)| x.cmp(y)); - - let mut res = vec![]; - for (k, v) in env { - let mut rhs = v.latex_str(); - let last = rhs.pop().expect("returned an empty expression"); - res.push(format!("{} = {last}", k.latex_name())); - res.extend(rhs); - } - res.push(e); - res - } - - fn latex(&self, cache: &mut HashMap>>) -> String { - use Expr::*; + fn latex( + &self, + cache: &mut HashMap, Column>>, + ) -> String { + use ExprInner::*; + use Operations::*; match self { Double(x) => format!("2 ({})", x.latex(cache)), - Constant(x) => x.latex(), - Cell(v) => v.latex(), - UnnormalizedLagrangeBasis(RowOffset { + Atom(Constant(x)) => { + let mut inner_cache = HashMap::new(); + let res = x.latex(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Atom(Constant(v))); + }); + res + } + Atom(Cell(v)) => v.latex(&mut HashMap::new()), + Atom(UnnormalizedLagrangeBasis(RowOffset { zk_rows: true, offset: i, - }) => { + })) => { format!("unnormalized\\_lagrange\\_basis(zk\\_rows + {})", *i) } - UnnormalizedLagrangeBasis(RowOffset { + Atom(UnnormalizedLagrangeBasis(RowOffset { zk_rows: false, offset: i, - }) => { + })) => { format!("unnormalized\\_lagrange\\_basis({})", *i) } - VanishesOnZeroKnowledgeAndPreviousRows => { + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => { "vanishes\\_on\\_zero\\_knowledge\\_and\\_previous\\_row".to_string() } - BinOp(Op2::Add, x, y) => format!("({} + {})", x.latex(cache), y.latex(cache)), - BinOp(Op2::Mul, x, y) => format!("({} \\cdot {})", x.latex(cache), y.latex(cache)), - BinOp(Op2::Sub, x, y) => format!("({} - {})", x.latex(cache), y.latex(cache)), + Add(x, y) => format!("({} + {})", x.latex(cache), y.latex(cache)), + Mul(x, y) => format!("({} \\cdot {})", x.latex(cache), y.latex(cache)), + Sub(x, y) => format!("({} - {})", x.latex(cache), y.latex(cache)), Pow(x, d) => format!("{}^{{{d}}}", x.latex(cache)), Square(x) => format!("({})^2", x.latex(cache)), Cache(id, e) => { @@ -2600,32 +3075,43 @@ where /// Recursively print the expression, /// except for the cached expression that are stored in the `cache`. 
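+ /// A hedged example of the output shape, assuming a column type whose
+ /// `FormattedOutput` impl renders witness columns as `Witness(i)`:
+ /// ```ignore
+ /// // w0 * w1 renders as "(Curr(Witness(0)) * Curr(Witness(1)))"
+ /// ```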
- fn text(&self, cache: &mut HashMap>>) -> String { - use Expr::*; + fn text( + &self, + cache: &mut HashMap, Column>>, + ) -> String { + use ExprInner::*; + use Operations::*; match self { Double(x) => format!("double({})", x.text(cache)), - Constant(x) => x.text(), - Cell(v) => v.text(), - UnnormalizedLagrangeBasis(RowOffset { + Atom(Constant(x)) => { + let mut inner_cache = HashMap::new(); + let res = x.text(&mut inner_cache); + inner_cache.into_iter().for_each(|(k, v)| { + let _ = cache.insert(k, Atom(Constant(v))); + }); + res + } + Atom(Cell(v)) => v.text(&mut HashMap::new()), + Atom(UnnormalizedLagrangeBasis(RowOffset { zk_rows: true, offset: i, - }) => match i.cmp(&0) { + })) => match i.cmp(&0) { Ordering::Greater => format!("unnormalized_lagrange_basis(zk_rows + {})", *i), Ordering::Equal => "unnormalized_lagrange_basis(zk_rows)".to_string(), Ordering::Less => format!("unnormalized_lagrange_basis(zk_rows - {})", (-*i)), }, - UnnormalizedLagrangeBasis(RowOffset { + Atom(UnnormalizedLagrangeBasis(RowOffset { zk_rows: false, offset: i, - }) => { + })) => { format!("unnormalized_lagrange_basis({})", *i) } - VanishesOnZeroKnowledgeAndPreviousRows => { + Atom(VanishesOnZeroKnowledgeAndPreviousRows) => { "vanishes_on_zero_knowledge_and_previous_rows".to_string() } - BinOp(Op2::Add, x, y) => format!("({} + {})", x.text(cache), y.text(cache)), - BinOp(Op2::Mul, x, y) => format!("({} * {})", x.text(cache), y.text(cache)), - BinOp(Op2::Sub, x, y) => format!("({} - {})", x.text(cache), y.text(cache)), + Add(x, y) => format!("({} + {})", x.text(cache), y.text(cache)), + Mul(x, y) => format!("({} * {})", x.text(cache), y.text(cache)), + Sub(x, y) => format!("({} - {})", x.text(cache), y.text(cache)), Pow(x, d) => format!("pow({}, {d})", x.text(cache)), Square(x) => format!("square({})", x.text(cache)), Cache(id, e) => { @@ -2635,24 +3121,33 @@ where IfFeature(feature, _, _) => format!("{feature:?}"), } } +} - /// Converts the expression to a text string - pub fn text_str(&self) -> String { +impl<'a, F, Column: FormattedOutput + Debug + Clone, ChallengeTerm> + Expr, Column> +where + F: PrimeField, + ChallengeTerm: AlphaChallengeTerm<'a>, +{ + /// Converts the expression in LaTeX + // It is only used by visual tooling like kimchi-visu + pub fn latex_str(&self) -> Vec { let mut env = HashMap::new(); - let e = self.text(&mut env); + let e = self.latex(&mut env); let mut env: Vec<_> = env.into_iter().collect(); - // HashMap deliberately uses an unstable order; here we sort to ensure that the output is - // consistent when printing. + // HashMap deliberately uses an unstable order; here we sort to ensure + // that the output is consistent when printing. env.sort_by(|(x, _), (y, _)| x.cmp(y)); - let mut res = String::new(); + let mut res = vec![]; for (k, v) in env { - let str = format!("{} = {}", k.text_name(), v.text_str()); - res.push_str(&str); + let mut rhs = v.latex_str(); + let last = rhs.pop().expect("returned an empty expression"); + res.push(format!("{} = {last}", k.latex_name())); + res.extend(rhs); } - - res.push_str(&e); + res.push(e); res } } @@ -2669,17 +3164,18 @@ pub mod constraints { use std::fmt; use super::*; + use crate::circuits::berkeley_columns::{coeff, witness}; /// This trait defines a common arithmetic operations interface /// that can be used by constraints. It allows us to reuse /// constraint code for witness computation. 
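+ /// A hedged sketch of the reuse this enables (the helper and its bounds are
+ /// illustrative, not part of this change): the same generic code can either
+ /// build a symbolic constraint (`T = Expr<..>`) or evaluate it directly on
+ /// witness values (`T = F`):
+ /// ```ignore
+ /// fn double_plus_one<F: PrimeField, T: ExprOps<F, BerkeleyChallengeTerm>>(x: &T) -> T {
+ ///     x.double() + T::one()
+ /// }
+ /// ```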
- pub trait ExprOps: - std::ops::Add - + std::ops::Sub - + std::ops::Neg - + std::ops::Mul - + std::ops::AddAssign - + std::ops::MulAssign + pub trait ExprOps: + Add + + Sub + + Neg + + Mul + + AddAssign + + MulAssign + Clone + Zero + One @@ -2727,30 +3223,42 @@ pub mod constraints { fn coeff(col: usize, env: Option<&ArgumentData>) -> Self; /// Create a constant - fn constant(expr: ConstantExpr, env: Option<&ArgumentData>) -> Self; + fn constant(expr: ConstantExpr, env: Option<&ArgumentData>) -> Self; /// Cache item fn cache(&self, cache: &mut Cache) -> Self; } - - impl ExprOps for Expr> + // TODO generalize with generic Column/challengeterm + // We need to create a trait for berkeley_columns::Environment + impl ExprOps + for Expr, berkeley_columns::Column> where F: PrimeField, + // TODO remove + Expr, berkeley_columns::Column>: std::fmt::Display, { fn two_pow(pow: u64) -> Self { - Expr::>::literal(>::two_pow(pow)) + Expr::, berkeley_columns::Column>::literal( + >::two_pow(pow), + ) } fn two_to_limb() -> Self { - Expr::>::literal(>::two_to_limb()) + Expr::, berkeley_columns::Column>::literal( + KimchiForeignElement::::two_to_limb(), + ) } fn two_to_2limb() -> Self { - Expr::>::literal(>::two_to_2limb()) + Expr::, berkeley_columns::Column>::literal( + KimchiForeignElement::::two_to_2limb(), + ) } fn two_to_3limb() -> Self { - Expr::>::literal(>::two_to_3limb()) + Expr::, berkeley_columns::Column>::literal( + KimchiForeignElement::::two_to_3limb(), + ) } fn double(&self) -> Self { @@ -2774,7 +3282,7 @@ pub mod constraints { } fn literal(x: F) -> Self { - Expr::Constant(ConstantExpr::Literal(x)) + ConstantTerm::Literal(x).into() } fn witness(row: CurrOrNext, col: usize, _: Option<&ArgumentData>) -> Self { @@ -2785,30 +3293,34 @@ pub mod constraints { coeff(col) } - fn constant(expr: ConstantExpr, _: Option<&ArgumentData>) -> Self { - Expr::Constant(expr) + fn constant( + expr: ConstantExpr, + _: Option<&ArgumentData>, + ) -> Self { + Expr::from(expr) } fn cache(&self, cache: &mut Cache) -> Self { Expr::Cache(cache.next_id(), Box::new(self.clone())) } } - - impl ExprOps for F { + // TODO generalize with generic Column/challengeterm + // We need to generalize argument.rs + impl ExprOps for F { fn two_pow(pow: u64) -> Self { >::two_pow(pow) } fn two_to_limb() -> Self { - >::two_to_limb() + KimchiForeignElement::::two_to_limb() } fn two_to_2limb() -> Self { - >::two_to_2limb() + KimchiForeignElement::::two_to_2limb() } fn two_to_3limb() -> Self { - >::two_to_3limb() + KimchiForeignElement::::two_to_3limb() } fn double(&self) -> Self { @@ -2849,9 +3361,12 @@ pub mod constraints { } } - fn constant(expr: ConstantExpr, env: Option<&ArgumentData>) -> Self { + fn constant( + expr: ConstantExpr, + env: Option<&ArgumentData>, + ) -> Self { match env { - Some(data) => expr.value(&data.constants), + Some(data) => expr.value(&data.constants, &data.challenges), None => panic!("Missing constants"), } } @@ -2862,12 +3377,12 @@ pub mod constraints { } /// Creates a constraint to enforce that b is either 0 or 1. - pub fn boolean>(b: &T) -> T { + pub fn boolean>(b: &T) -> T { b.square() - b.clone() } /// Crumb constraint for 2-bit value x - pub fn crumb>(x: &T) -> T { + pub fn crumb>(x: &T) -> T { // Assert x \in [0,3] i.e. 
assert x*(x - 1)*(x - 2)*(x - 3) == 0 x.clone() * (x.clone() - 1u64.into()) @@ -2876,47 +3391,14 @@ pub mod constraints { } /// lo + mi * 2^{LIMB_BITS} - pub fn compact_limb>(lo: &T, mi: &T) -> T { + pub fn compact_limb>( + lo: &T, + mi: &T, + ) -> T { lo.clone() + mi.clone() * T::two_to_limb() } } -// -// Helpers -// - -/// An alias for the intended usage of the expression type in constructing constraints. -pub type E = Expr>; - -/// Convenience function to create a constant as [Expr]. -pub fn constant(x: F) -> E { - Expr::Constant(ConstantExpr::Literal(x)) -} - -/// Helper function to quickly create an expression for a witness. -pub fn witness(i: usize, row: CurrOrNext) -> E { - E::::cell(Column::Witness(i), row) -} - -/// Same as [witness] but for the current row. -pub fn witness_curr(i: usize) -> E { - witness(i, CurrOrNext::Curr) -} - -/// Same as [witness] but for the next row. -pub fn witness_next(i: usize) -> E { - witness(i, CurrOrNext::Next) -} - -/// Handy function to quickly create an expression for a gate. -pub fn index(g: GateType) -> E { - E::::cell(Column::Index(g), CurrOrNext::Curr) -} - -pub fn coeff(i: usize) -> E { - E::::cell(Column::Coefficient(i), CurrOrNext::Curr) -} - /// Auto clone macro - Helps make constraints more readable /// by eliminating requirement to .clone() all the time #[macro_export] @@ -2945,155 +3427,8 @@ pub use auto_clone_array; /// You can import this module like `use kimchi::circuits::expr::prologue::*` to obtain a number of handy aliases and helpers pub mod prologue { - pub use super::{coeff, constant, index, witness, witness_curr, witness_next, FeatureFlag, E}; -} - -#[cfg(test)] -pub mod test { - use super::*; - use crate::{ - circuits::{ - constraints::ConstraintSystem, expr::constraints::ExprOps, gate::CircuitGate, - polynomials::generic::GenericGateSpec, wires::Wire, - }, - curve::KimchiCurve, - prover_index::ProverIndex, + pub use super::{ + berkeley_columns::{coeff, constant, index, witness, witness_curr, witness_next, E}, + FeatureFlag, }; - use ark_ff::UniformRand; - use mina_curves::pasta::{Fp, Pallas, Vesta}; - use poly_commitment::{ - evaluation_proof::OpeningProof, - srs::{endos, SRS}, - }; - use std::{array, sync::Arc}; - - #[test] - #[should_panic] - fn test_failed_linearize() { - // w0 * w1 - let mut expr: E = E::zero(); - expr += witness_curr(0); - expr *= witness_curr(1); - - // since none of w0 or w1 is evaluated this should panic - let evaluated = HashSet::new(); - expr.linearize(evaluated).unwrap(); - } - - #[test] - #[should_panic] - fn test_degree_tracking() { - // The selector CompleteAdd has degree n-1 (so can be tracked with n evaluations in the domain d1 of size n). - // Raising a polynomial of degree n-1 to the power 8 makes it degree 8*(n-1) (and so it needs `8(n-1) + 1` evaluations). - // Since `d8` is of size `8n`, we are still good with that many evaluations to track the new polynomial. - // Raising it to the power 9 pushes us out of the domain d8, which will panic. 
- let mut expr: E = E::zero(); - expr += index(GateType::CompleteAdd); - let expr = expr.pow(9); - - // create a dummy env - let one = Fp::from(1u32); - let gates = vec![ - CircuitGate::create_generic_gadget( - Wire::for_row(0), - GenericGateSpec::Const(1u32.into()), - None, - ), - CircuitGate::create_generic_gadget( - Wire::for_row(1), - GenericGateSpec::Const(1u32.into()), - None, - ), - ]; - let index = { - let constraint_system = ConstraintSystem::fp_for_testing(gates); - let srs = SRS::::create(constraint_system.domain.d1.size()); - srs.get_lagrange_basis(constraint_system.domain.d1); - let srs = Arc::new(srs); - - let (endo_q, _endo_r) = endos::(); - ProverIndex::>::create(constraint_system, endo_q, srs) - }; - - let witness_cols: [_; COLUMNS] = array::from_fn(|_| DensePolynomial::zero()); - let permutation = DensePolynomial::zero(); - let domain_evals = index.cs.evaluate(&witness_cols, &permutation); - - let env = Environment { - constants: Constants { - alpha: one, - beta: one, - gamma: one, - joint_combiner: None, - endo_coefficient: one, - mds: &Vesta::sponge_params().mds, - zk_rows: 3, - }, - witness: &domain_evals.d8.this.w, - coefficient: &index.column_evaluations.coefficients8, - vanishes_on_zero_knowledge_and_previous_rows: &index - .cs - .precomputations() - .vanishes_on_zero_knowledge_and_previous_rows, - z: &domain_evals.d8.this.z, - l0_1: l0_1(index.cs.domain.d1), - domain: index.cs.domain, - index: HashMap::new(), - lookup: None, - }; - - // this should panic as we don't have a domain large enough - expr.evaluations(&env); - } - - #[test] - fn test_unnormalized_lagrange_basis() { - let zk_rows = 3; - let domain = EvaluationDomains::::create(2usize.pow(10) + zk_rows) - .expect("failed to create evaluation domain"); - let rng = &mut o1_utils::tests::make_test_rng(None); - - // Check that both ways of computing lagrange basis give the same result - let d1_size: i32 = domain.d1.size().try_into().expect("domain size too big"); - for i in 1..d1_size { - let pt = Fp::rand(rng); - assert_eq!( - unnormalized_lagrange_basis(&domain.d1, d1_size - i, &pt), - unnormalized_lagrange_basis(&domain.d1, -i, &pt) - ); - } - } - - #[test] - fn test_arithmetic_ops() { - fn test_1>() -> T { - T::zero() + T::one() - } - assert_eq!(test_1::>(), E::zero() + E::one()); - assert_eq!(test_1::(), Fp::one()); - - fn test_2>() -> T { - T::one() + T::one() - } - assert_eq!(test_2::>(), E::one() + E::one()); - assert_eq!(test_2::(), Fp::from(2u64)); - - fn test_3>(x: T) -> T { - T::from(2u64) * x - } - assert_eq!( - test_3::>(E::from(3u64)), - E::from(2u64) * E::from(3u64) - ); - assert_eq!(test_3(Fp::from(3u64)), Fp::from(6u64)); - - fn test_4>(x: T) -> T { - x.clone() * (x.square() + T::from(7u64)) - } - assert_eq!( - test_4::>(E::from(5u64)), - E::from(5u64) * (Expr::square(E::from(5u64)) + E::from(7u64)) - ); - assert_eq!(test_4::(Fp::from(5u64)), Fp::from(160u64)); - } } diff --git a/kimchi/src/circuits/gate.rs b/kimchi/src/circuits/gate.rs index a03481b04c..69eb4c4894 100644 --- a/kimchi/src/circuits/gate.rs +++ b/kimchi/src/circuits/gate.rs @@ -3,10 +3,11 @@ use crate::{ circuits::{ argument::{Argument, ArgumentEnv}, + berkeley_columns::BerkeleyChallenges, constraints::ConstraintSystem, polynomials::{ - complete_add, endomul_scalar, endosclmul, foreign_field_add, foreign_field_mul, - poseidon, range_check, turshi, varbasemul, + complete_add, endomul_scalar, endosclmul, foreign_field_add, foreign_field_mul, keccak, + poseidon, range_check, rot, turshi, varbasemul, xor, }, wires::*, }, @@ -20,11 
+21,7 @@ use serde::{Deserialize, Serialize}; use serde_with::serde_as; use thiserror::Error; -use super::{ - argument::ArgumentWitness, - expr, - polynomials::{rot, xor}, -}; +use super::{argument::ArgumentWitness, expr}; /// A row accessible from a given row, corresponds to the fact that we open all polynomials /// at `zeta` **and** `omega * zeta`. @@ -111,6 +108,8 @@ pub enum GateType { // Gates for Keccak Xor16, Rot64, + KeccakRound, + KeccakSponge, } /// Gate error @@ -212,6 +211,12 @@ impl CircuitGate { Rot64 => self .verify_witness::(row, witness, &index.cs, public) .map_err(|e| e.to_string()), + KeccakRound => self + .verify_witness::(row, witness, &index.cs, public) + .map_err(|e| e.to_string()), + KeccakSponge => self + .verify_witness::(row, witness, &index.cs, public) + .map_err(|e| e.to_string()), } } @@ -228,16 +233,24 @@ impl CircuitGate { // Set up the constants. Note that alpha, beta, gamma and joint_combiner // are one because this function is not running the prover. let constants = expr::Constants { - alpha: F::one(), - beta: F::one(), - gamma: F::one(), - joint_combiner: Some(F::one()), endo_coefficient: cs.endo, mds: &G::sponge_params().mds, zk_rows: cs.zk_rows, }; + // TODO: use generic challenges, since we do not need those here + let challenges = BerkeleyChallenges { + alpha: F::one(), + beta: F::one(), + gamma: F::one(), + joint_combiner: F::one(), + }; // Create the argument environment for the constraints over field elements - let env = ArgumentEnv::::create(argument_witness, self.coeffs.clone(), constants); + let env = ArgumentEnv::::create( + argument_witness, + self.coeffs.clone(), + constants, + challenges, + ); // Check the wiring (i.e. copy constraints) for this gate // Note: Gates can operate on row Curr or Curr and Next. @@ -306,6 +319,12 @@ impl CircuitGate { } GateType::Xor16 => xor::Xor16::constraint_checks(&env, &mut cache), GateType::Rot64 => rot::Rot64::constraint_checks(&env, &mut cache), + GateType::KeccakRound => { + keccak::circuitgates::KeccakRound::constraint_checks(&env, &mut cache) + } + GateType::KeccakSponge => { + keccak::circuitgates::KeccakSponge::constraint_checks(&env, &mut cache) + } }; // Check for failed constraints @@ -568,14 +587,6 @@ mod tests { use proptest::prelude::*; use rand::SeedableRng as _; - // TODO: move to mina-curves - prop_compose! { - pub fn arb_fp()(seed: [u8; 32]) -> Fp { - let rng = &mut rand::rngs::StdRng::from_seed(seed); - Fp::rand(rng) - } - } - prop_compose!
{ fn arb_fp_vec(max: usize)(seed: [u8; 32], num in 0..max) -> Vec { let rng = &mut rand::rngs::StdRng::from_seed(seed); diff --git a/kimchi/src/circuits/lookup/constraints.rs b/kimchi/src/circuits/lookup/constraints.rs index 3b5c38b4d4..5aca05eee5 100644 --- a/kimchi/src/circuits/lookup/constraints.rs +++ b/kimchi/src/circuits/lookup/constraints.rs @@ -1,6 +1,7 @@ use crate::{ circuits::{ - expr::{prologue::*, Column, ConstantExpr, RowOffset}, + berkeley_columns::{BerkeleyChallengeTerm, Column}, + expr::{prologue::*, ConstantExpr, ConstantTerm, ExprInner, RowOffset}, gate::{CircuitGate, CurrOrNext}, lookup::lookups::{ JointLookup, JointLookupSpec, JointLookupValue, LocalPosition, LookupInfo, @@ -392,8 +393,10 @@ pub fn constraints( let column = |col: Column| E::cell(col, Curr); // gamma * (beta + 1) - let gammabeta1 = - E::::Constant(ConstantExpr::Gamma * (ConstantExpr::Beta + ConstantExpr::one())); + let gammabeta1 = E::::from( + ConstantExpr::from(BerkeleyChallengeTerm::Gamma) + * (ConstantExpr::from(BerkeleyChallengeTerm::Beta) + ConstantExpr::one()), + ); // the numerator part in the multiset check of plookup let numerator = { @@ -420,7 +423,7 @@ pub fn constraints( E::one() - lookup_indicator }; - let joint_combiner = E::Constant(ConstantExpr::JointCombiner); + let joint_combiner = E::from(BerkeleyChallengeTerm::JointCombiner); let table_id_combiner = // Compute `joint_combiner.pow(lookup_info.max_joint_size)`, injecting feature flags if // needed. @@ -443,16 +446,16 @@ pub fn constraints( .dummy_lookup .entry .iter() - .map(|x| E::Constant(ConstantExpr::Literal(*x))) + .map(|x| ConstantTerm::Literal(*x).into()) .collect(), - table_id: E::Constant(ConstantExpr::Literal(configuration.dummy_lookup.table_id)), + table_id: ConstantTerm::Literal(configuration.dummy_lookup.table_id).into(), }; expr_dummy.evaluate(&joint_combiner, &table_id_combiner) }; // (1 + beta)^max_per_row let beta1_per_row: E = { - let beta1 = E::Constant(ConstantExpr::one() + ConstantExpr::Beta); + let beta1 = E::from(ConstantExpr::one() + BerkeleyChallengeTerm::Beta.into()); // Compute beta1.pow(lookup_info.max_per_row) let mut res = beta1.clone(); for i in 1..lookup_info.max_per_row { @@ -470,11 +473,11 @@ pub fn constraints( }; // pre-compute the padding dummies we can use depending on the number of lookups to the `max_per_row` lookups - // each value is also multipled with (1 + beta)^max_per_row + // each value is also multiplied with (1 + beta)^max_per_row // as we need to multiply the denominator with this eventually let dummy_padding = |spec_len| { let mut res = E::one(); - let dummy = E::Constant(ConstantExpr::Gamma) + dummy_lookup.clone(); + let dummy: E<_> = E::from(BerkeleyChallengeTerm::Gamma) + dummy_lookup.clone(); for i in spec_len..lookup_info.max_per_row { let mut dummy_used = dummy.clone(); if generate_feature_flags { @@ -507,10 +510,10 @@ pub fn constraints( let eval = |pos: LocalPosition| witness(pos.column, pos.row); spec.iter() .map(|j| { - E::Constant(ConstantExpr::Gamma) + E::from(BerkeleyChallengeTerm::Gamma) + j.evaluate(&joint_combiner, &table_id_combiner, &eval) }) - .fold(padding, |acc: E, x| acc * x) + .fold(padding, |acc: E, x: E| acc * x) }; // f part of the numerator @@ -539,7 +542,7 @@ pub fn constraints( // t part of the numerator let t_chunk = gammabeta1.clone() + E::cell(Column::LookupTable, Curr) - + E::beta() * E::cell(Column::LookupTable, Next); + + E::from(BerkeleyChallengeTerm::Beta) * E::cell(Column::LookupTable, Next); // return the numerator f_chunk * t_chunk @@ -586,7 
+589,7 @@ pub fn constraints( // gamma * (beta + 1) + sorted[i](x w) + beta * sorted[i](x) let mut expr = gammabeta1.clone() + E::cell(Column::LookupSorted(i), s1) - + E::beta() * E::cell(Column::LookupSorted(i), s2); + + E::from(BerkeleyChallengeTerm::Beta) * E::cell(Column::LookupSorted(i), s2); if generate_feature_flags { expr = E::IfFeature( FeatureFlag::LookupsPerRow(i as isize), @@ -610,14 +613,14 @@ pub fn constraints( let mut res = vec![ // the accumulator except for the last zk_rows+1 rows // (contains the zk-rows and the last value of the accumulator) - E::VanishesOnZeroKnowledgeAndPreviousRows * aggreg_equation, + E::Atom(ExprInner::VanishesOnZeroKnowledgeAndPreviousRows) * aggreg_equation, // the initial value of the accumulator - E::UnnormalizedLagrangeBasis(RowOffset { + E::Atom(ExprInner::UnnormalizedLagrangeBasis(RowOffset { zk_rows: false, offset: 0, - }) * (E::cell(Column::LookupAggreg, Curr) - E::one()), + })) * (E::cell(Column::LookupAggreg, Curr) - E::one()), // Check that the final value of the accumulator is 1 - E::UnnormalizedLagrangeBasis(final_lookup_row) + E::Atom(ExprInner::UnnormalizedLagrangeBasis(final_lookup_row)) * (E::cell(Column::LookupAggreg, Curr) - E::one()), ]; @@ -634,7 +637,7 @@ pub fn constraints( offset: 0, } }; - let mut expr = E::UnnormalizedLagrangeBasis(first_or_last) + let mut expr = E::Atom(ExprInner::UnnormalizedLagrangeBasis(first_or_last)) * (column(Column::LookupSorted(i)) - column(Column::LookupSorted(i + 1))); if generate_feature_flags { expr = E::IfFeature( diff --git a/kimchi/src/circuits/lookup/runtime_tables.rs b/kimchi/src/circuits/lookup/runtime_tables.rs index 98b018ed8a..ef2d9f7d15 100644 --- a/kimchi/src/circuits/lookup/runtime_tables.rs +++ b/kimchi/src/circuits/lookup/runtime_tables.rs @@ -2,10 +2,10 @@ //! The setup has to prepare for their presence using [`RuntimeTableCfg`]. //! At proving time, the prover can use [`RuntimeTable`] to specify the actual tables. -use crate::circuits::{ - expr::{prologue::*, Column}, - gate::CurrOrNext, -}; +// TODO: write cargo specifications + +use crate::circuits::{berkeley_columns::Column, expr::prologue::*, gate::CurrOrNext}; + use ark_ff::Field; use serde::{Deserialize, Serialize}; diff --git a/kimchi/src/circuits/lookup/tables/mod.rs b/kimchi/src/circuits/lookup/tables/mod.rs index cd183cb714..2c5b3f3f92 100644 --- a/kimchi/src/circuits/lookup/tables/mod.rs +++ b/kimchi/src/circuits/lookup/tables/mod.rs @@ -5,6 +5,9 @@ use serde::{Deserialize, Serialize}; pub mod range_check; pub mod xor; +// If you add new tables, update ../../../../../book/src/kimchi/lookup.md +// accordingly + //~ spec:startcode /// The table ID associated with the XOR lookup table. 
pub const XOR_TABLE_ID: i32 = 0; diff --git a/kimchi/src/circuits/mod.rs b/kimchi/src/circuits/mod.rs index 83dbc91c63..7bad1cfd9f 100644 --- a/kimchi/src/circuits/mod.rs +++ b/kimchi/src/circuits/mod.rs @@ -2,6 +2,7 @@ pub mod macros; pub mod argument; +pub mod berkeley_columns; pub mod constraints; pub mod domain_constant_evaluation; pub mod domains; diff --git a/kimchi/src/circuits/polynomials/complete_add.rs b/kimchi/src/circuits/polynomials/complete_add.rs index 015bbe2b22..90afcc9847 100644 --- a/kimchi/src/circuits/polynomials/complete_add.rs +++ b/kimchi/src/circuits/polynomials/complete_add.rs @@ -16,6 +16,7 @@ //~ use crate::circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, expr::{constraints::ExprOps, Cache}, gate::{CircuitGate, GateType}, wires::COLUMNS, @@ -30,7 +31,7 @@ use std::marker::PhantomData; /// Additionally, if r == 0, then `z_inv` = 1 / z. /// /// If r == 1 however (i.e., if z == 0), then z_inv is unconstrained. -fn zero_check>(z: T, z_inv: T, r: T) -> Vec { +fn zero_check>(z: T, z_inv: T, r: T) -> Vec { vec![z_inv * z.clone() - (T::one() - r.clone()), r * z] } @@ -98,7 +99,10 @@ where const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::CompleteAdd); const CONSTRAINTS: u32 = 7; - fn constraint_checks>(env: &ArgumentEnv, cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + cache: &mut Cache, + ) -> Vec { // This function makes 2 + 1 + 1 + 1 + 2 = 7 constraints let x1 = env.witness_curr(0); let y1 = env.witness_curr(1); diff --git a/kimchi/src/circuits/polynomials/endomul_scalar.rs b/kimchi/src/circuits/polynomials/endomul_scalar.rs index 64ab7a8e6b..156f39270e 100644 --- a/kimchi/src/circuits/polynomials/endomul_scalar.rs +++ b/kimchi/src/circuits/polynomials/endomul_scalar.rs @@ -4,6 +4,7 @@ use crate::{ circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, constraints::ConstraintSystem, expr::{constraints::ExprOps, Cache}, gate::{CircuitGate, GateType}, @@ -49,7 +50,7 @@ impl CircuitGate { } } -fn polynomial>(coeffs: &[F], x: &T) -> T { +fn polynomial>(coeffs: &[F], x: &T) -> T { coeffs .iter() .rev() @@ -166,7 +167,10 @@ where const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMulScalar); const CONSTRAINTS: u32 = 11; - fn constraint_checks>(env: &ArgumentEnv, cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + cache: &mut Cache, + ) -> Vec { let n0 = env.witness_curr(0); let n8 = env.witness_curr(1); let a0 = env.witness_curr(2); diff --git a/kimchi/src/circuits/polynomials/endosclmul.rs b/kimchi/src/circuits/polynomials/endosclmul.rs index 5a95be20b1..d9b2db9943 100644 --- a/kimchi/src/circuits/polynomials/endosclmul.rs +++ b/kimchi/src/circuits/polynomials/endosclmul.rs @@ -5,6 +5,7 @@ use crate::{ circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::{BerkeleyChallengeTerm, BerkeleyChallenges}, constraints::ConstraintSystem, expr::{ self, @@ -139,21 +140,23 @@ impl CircuitGate { let pt = F::from(123456u64); let constants = expr::Constants { - alpha: F::zero(), - beta: F::zero(), - gamma: F::zero(), - joint_combiner: None, mds: &G::sponge_params().mds, endo_coefficient: cs.endo, zk_rows: cs.zk_rows, }; + let challenges = BerkeleyChallenges { + alpha: F::zero(), + beta: F::zero(), + gamma: F::zero(), + joint_combiner: F::zero(), + }; let evals: ProofEvaluations> = ProofEvaluations::dummy_with_witness_evaluations(this, next); let constraints = 
EndosclMul::constraints(&mut Cache::default()); for (i, c) in constraints.iter().enumerate() { - match c.evaluate_(cs.domain.d1, pt, &evals, &constants) { + match c.evaluate_(cs.domain.d1, pt, &evals, &constants, &challenges) { Ok(x) => { if x != F::zero() { return Err(format!("Bad endo equation {i}")); @@ -186,7 +189,10 @@ where const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMul); const CONSTRAINTS: u32 = 11; - fn constraint_checks>(env: &ArgumentEnv, cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + cache: &mut Cache, + ) -> Vec { let b1 = env.witness_curr(11); let b2 = env.witness_curr(12); let b3 = env.witness_curr(13); diff --git a/kimchi/src/circuits/polynomials/foreign_field_add/circuitgates.rs b/kimchi/src/circuits/polynomials/foreign_field_add/circuitgates.rs index 8e41a1f724..504fa46822 100644 --- a/kimchi/src/circuits/polynomials/foreign_field_add/circuitgates.rs +++ b/kimchi/src/circuits/polynomials/foreign_field_add/circuitgates.rs @@ -2,14 +2,15 @@ use crate::circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, expr::{ constraints::{compact_limb, ExprOps}, Cache, }, gate::GateType, + polynomials::foreign_field_common::LIMB_COUNT, }; use ark_ff::PrimeField; -use o1_utils::LIMB_COUNT; use std::{array, marker::PhantomData}; //~ These circuit gates are used to constrain that @@ -142,7 +143,10 @@ where const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::ForeignFieldAdd); const CONSTRAINTS: u32 = 4; - fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { let foreign_modulus: [T; LIMB_COUNT] = array::from_fn(|i| env.coeff(i)); // stored as coefficient for better correspondence with the relation being proved @@ -204,7 +208,7 @@ where } // Auxiliary function to obtain the constraints to check a carry flag -fn is_carry>(flag: &T) -> T { +fn is_carry>(flag: &T) -> T { // Carry bits are -1, 0, or 1.
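// The product below vanishes exactly on that set: each of -1, 0, and 1
// zeroes one of the three factors, while any other value leaves all three
// factors non-zero.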
flag.clone() * (flag.clone() - T::one()) * (flag.clone() + T::one()) } diff --git a/kimchi/src/circuits/polynomials/foreign_field_add/gadget.rs b/kimchi/src/circuits/polynomials/foreign_field_add/gadget.rs index 96af81e165..4b6c847b83 100644 --- a/kimchi/src/circuits/polynomials/foreign_field_add/gadget.rs +++ b/kimchi/src/circuits/polynomials/foreign_field_add/gadget.rs @@ -2,10 +2,10 @@ use ark_ff::PrimeField; use num_bigint::BigUint; -use o1_utils::foreign_field::BigUintForeignFieldHelpers; use crate::circuits::{ gate::{CircuitGate, Connect, GateType}, + polynomials::foreign_field_common::BigUintForeignFieldHelpers, wires::Wire, }; diff --git a/kimchi/src/circuits/polynomials/foreign_field_add/witness.rs b/kimchi/src/circuits/polynomials/foreign_field_add/witness.rs index e948633370..c10db4d294 100644 --- a/kimchi/src/circuits/polynomials/foreign_field_add/witness.rs +++ b/kimchi/src/circuits/polynomials/foreign_field_add/witness.rs @@ -4,15 +4,16 @@ use crate::{ circuits::{ expr::constraints::compact_limb, polynomial::COLUMNS, + polynomials::foreign_field_common::{ + BigUintForeignFieldHelpers, KimchiForeignElement, HI, LIMB_BITS, LO, MI, + }, witness::{self, ConstantCell, VariableCell, Variables, WitnessCell}, }, variable_map, }; use ark_ff::PrimeField; use num_bigint::BigUint; -use o1_utils::foreign_field::{ - BigUintForeignFieldHelpers, ForeignElement, ForeignFieldHelpers, HI, LO, MI, -}; +use o1_utils::foreign_field::{ForeignElement, ForeignFieldHelpers}; use std::array; /// All foreign field operations allowed @@ -42,17 +43,17 @@ impl FFOps { // - the overflow flag // - the carry value fn compute_ffadd_values( - left_input: &ForeignElement, - right_input: &ForeignElement, + left_input: &ForeignElement, + right_input: &ForeignElement, opcode: FFOps, - foreign_modulus: &ForeignElement, -) -> (ForeignElement, F, F, F) { + foreign_modulus: &ForeignElement, +) -> (ForeignElement, F, F, F) { // Compute bigint version of the inputs let left = left_input.to_biguint(); let right = right_input.to_biguint(); // Clarification: - let right_hi = right_input[3] * F::two_to_limb() + right_input[HI]; // This allows to store 2^88 in the high limb + let right_hi = right_input[3] * KimchiForeignElement::::two_to_limb() + right_input[HI]; // This allows to store 2^88 in the high limb let modulus = foreign_modulus.to_biguint(); @@ -110,7 +111,7 @@ fn compute_ffadd_values( + compact_limb(&right_input[LO], &right_input[MI]) * sign - compact_limb(&foreign_modulus[LO], &foreign_modulus[MI]) * field_overflow - compact_limb(&result[LO], &result[MI])) - / F::two_to_2limb(); + / KimchiForeignElement::::two_to_2limb(); let carry_top: F = result[HI] - left_input[HI] - sign * right_hi + field_overflow * foreign_modulus[HI]; @@ -184,9 +185,9 @@ fn init_ffadd_row( overflow: F, carry: F, ) { - let witness_shape: Vec<[Box>; COLUMNS]> = vec![ + let layout: [Vec>>; 1] = [ // ForeignFieldAdd row - [ + vec![ VariableCell::create("left_lo"), VariableCell::create("left_mi"), VariableCell::create("left_hi"), @@ -208,7 +209,7 @@ fn init_ffadd_row( witness::init( witness, offset, - &witness_shape, + &layout, &variable_map!["left_lo" => left[LO], "left_mi" => left[MI], "left_hi" => left[HI], "right_lo" => right[LO], "right_mi" => right[MI], "right_hi" => right[HI], "overflow" => overflow, "carry" => carry], ); } @@ -220,16 +221,16 @@ fn init_bound_rows( bound: &[F; 3], carry: &F, ) { - let witness_shape: Vec<[Box>; COLUMNS]> = vec![ - [ + let layout: [Vec>>; 2] = [ + vec![ // ForeignFieldAdd row 
VariableCell::create("result_lo"), VariableCell::create("result_mi"), VariableCell::create("result_hi"), - ConstantCell::create(F::zero()), // 0 - ConstantCell::create(F::zero()), // 0 - ConstantCell::create(F::two_to_limb()), // 2^88 - ConstantCell::create(F::one()), // field_overflow + ConstantCell::create(F::zero()), // 0 + ConstantCell::create(F::zero()), // 0 + ConstantCell::create(KimchiForeignElement::::two_to_limb()), // 2^88 + ConstantCell::create(F::one()), // field_overflow VariableCell::create("carry"), ConstantCell::create(F::zero()), ConstantCell::create(F::zero()), @@ -239,7 +240,7 @@ fn init_bound_rows( ConstantCell::create(F::zero()), ConstantCell::create(F::zero()), ], - [ + vec![ // Zero Row VariableCell::create("bound_lo"), VariableCell::create("bound_mi"), @@ -262,7 +263,7 @@ fn init_bound_rows( witness::init( witness, offset, - &witness_shape, + &layout, &variable_map!["carry" => *carry, "result_lo" => result[LO], "result_mi" => result[MI], "result_hi" => result[HI], "bound_lo" => bound[LO], "bound_mi" => bound[MI], "bound_hi" => bound[HI]], ); } @@ -274,8 +275,8 @@ pub fn extend_witness_bound_addition( foreign_field_modulus: &[F; 3], ) { // Convert to types used by this module - let fe = ForeignElement::::new(*limbs); - let foreign_field_modulus = ForeignElement::::new(*foreign_field_modulus); + let fe = ForeignElement::::new(*limbs); + let foreign_field_modulus = ForeignElement::::new(*foreign_field_modulus); if foreign_field_modulus.to_biguint() > BigUint::max_foreign_field_modulus::() { panic!( "foreign_field_modulus exceeds maximum: {} > {}", @@ -285,7 +286,7 @@ pub fn extend_witness_bound_addition( } // Compute values for final bound check, needs a 4 limb right input - let right_input = ForeignElement::::from_biguint(BigUint::binary_modulus()); + let right_input = ForeignElement::::from_biguint(BigUint::binary_modulus()); // Compute the bound and related witness data let (bound_output, bound_sign, bound_ovf, bound_carry) = diff --git a/kimchi/src/circuits/polynomials/foreign_field_common.rs b/kimchi/src/circuits/polynomials/foreign_field_common.rs new file mode 100644 index 0000000000..0d0cfb103f --- /dev/null +++ b/kimchi/src/circuits/polynomials/foreign_field_common.rs @@ -0,0 +1,366 @@ +//! Common parameters and functions for kimchi's foreign field circuits. + +use o1_utils::{ + field_helpers::FieldHelpers, + foreign_field::{ForeignElement, ForeignFieldHelpers}, +}; + +use ark_ff::{Field, PrimeField}; +use num_bigint::BigUint; +use num_traits::{One, Zero}; +use std::array; + +/// Index of low limb (in 3-limb foreign elements) +pub const LO: usize = 0; +/// Index of middle limb (in 3-limb foreign elements) +pub const MI: usize = 1; +/// Index of high limb (in 3-limb foreign elements) +pub const HI: usize = 2; + +/// Limb length for foreign field elements +pub const LIMB_BITS: usize = 88; + +/// Two to the power of the limb length +pub const TWO_TO_LIMB: u128 = 2u128.pow(LIMB_BITS as u32); + +/// Number of desired limbs for foreign field elements +pub const LIMB_COUNT: usize = 3; + +/// Exponent of binary modulus (i.e. 
t) +pub const BINARY_MODULUS_EXP: usize = LIMB_BITS * LIMB_COUNT; + +pub type KimchiForeignElement = ForeignElement; + +/// Foreign field helpers +pub trait BigUintForeignFieldHelpers { + /// 2 + fn two() -> Self; + + /// 2^{LIMB_BITS} + fn two_to_limb() -> Self; + + /// 2^{2 * LIMB_BITS} + fn two_to_2limb() -> Self; + + /// 2^t + fn binary_modulus() -> Self; + + /// 2^259 (see foreign field multiplication RFC) + fn max_foreign_field_modulus() -> Self; + + /// Convert to 3 limbs of LIMB_BITS each + fn to_limbs(&self) -> [BigUint; 3]; + + /// Convert to 2 limbs of 2 * LIMB_BITS each. The compressed term is the bottom part + fn to_compact_limbs(&self) -> [BigUint; 2]; + + /// Convert to 3 PrimeField limbs of LIMB_BITS each + fn to_field_limbs(&self) -> [F; 3]; + + /// Convert to 2 PrimeField limbs of 2 * LIMB_BITS each. The compressed term is the bottom part. + fn to_compact_field_limbs(&self) -> [F; 2]; + + /// Negate: 2^T - self + fn negate(&self) -> BigUint; +} + +// @volhovm can we remove this? +// /// Two to the power of the limb length +// pub fn two_to_limb() -> BigUint { +// BigUint::from(TWO_TO_LIMB) +// } + +impl BigUintForeignFieldHelpers for BigUint { + fn two() -> Self { + Self::from(2u32) + } + + fn two_to_limb() -> Self { + BigUint::two().pow(LIMB_BITS as u32) + } + + fn two_to_2limb() -> Self { + BigUint::two().pow(2 * LIMB_BITS as u32) + } + + fn binary_modulus() -> Self { + BigUint::two().pow(3 * LIMB_BITS as u32) + } + + fn max_foreign_field_modulus() -> Self { + // For simplicity and efficiency we use the approximation m = 2^259 - 1 + BigUint::two().pow(259) - BigUint::one() + } + + fn to_limbs(&self) -> [Self; 3] { + let mut limbs = biguint_to_limbs(self, LIMB_BITS); + assert!(limbs.len() <= 3); + limbs.resize(3, BigUint::zero()); + + array::from_fn(|i| limbs[i].clone()) + } + + fn to_compact_limbs(&self) -> [Self; 2] { + let mut limbs = biguint_to_limbs(self, 2 * LIMB_BITS); + assert!(limbs.len() <= 2); + limbs.resize(2, BigUint::zero()); + + array::from_fn(|i| limbs[i].clone()) + } + + fn to_field_limbs(&self) -> [F; 3] { + self.to_limbs().to_field_limbs() + } + + fn to_compact_field_limbs(&self) -> [F; 2] { + self.to_compact_limbs().to_field_limbs() + } + + fn negate(&self) -> BigUint { + assert!(*self < BigUint::binary_modulus()); + let neg_self = BigUint::binary_modulus() - self; + assert_eq!(neg_self.bits(), BINARY_MODULUS_EXP as u64); + neg_self + } +} + +/// PrimeField array BigUint helpers +pub trait FieldArrayBigUintHelpers { + /// Convert limbs from field elements to BigUint + fn to_limbs(&self) -> [BigUint; N]; + + /// Alias for to_limbs + fn to_biguints(&self) -> [BigUint; N] { + self.to_limbs() + } +} + +impl FieldArrayBigUintHelpers for [F; N] { + fn to_limbs(&self) -> [BigUint; N] { + array::from_fn(|i| self[i].to_biguint()) + } +} + +/// PrimeField array compose BigUint +pub trait FieldArrayCompose { + /// Compose field limbs into BigUint + fn compose(&self) -> BigUint; +} + +impl FieldArrayCompose for [F; 2] { + fn compose(&self) -> BigUint { + fields_compose(self, &BigUint::two_to_2limb()) + } +} + +impl FieldArrayCompose for [F; 3] { + fn compose(&self) -> BigUint { + fields_compose(self, &BigUint::two_to_limb()) + } +} + +/// PrimeField array compact limbs +pub trait FieldArrayCompact { + /// Compose field limbs into BigUint + fn to_compact_limbs(&self) -> [F; 2]; +} + +impl FieldArrayCompact for [F; 3] { + fn to_compact_limbs(&self) -> [F; 2] { + [ + self[0] + KimchiForeignElement::::two_to_limb() * self[1], + self[2], + ] + } +} + +/// BigUint 
array PrimeField helpers +pub trait BigUintArrayFieldHelpers { + /// Convert limbs from BigUint to field element + fn to_field_limbs(&self) -> [F; N]; + + /// Alias for to_field_limbs + fn to_fields(&self) -> [F; N] { + self.to_field_limbs() + } +} + +impl BigUintArrayFieldHelpers for [BigUint; N] { + fn to_field_limbs(&self) -> [F; N] { + biguints_to_fields(self) + } +} + +/// BigUint array compose helper +pub trait BigUintArrayCompose { + /// Compose limbs into BigUint + fn compose(&self) -> BigUint; +} + +impl BigUintArrayCompose<2> for [BigUint; 2] { + fn compose(&self) -> BigUint { + bigunits_compose(self, &BigUint::two_to_2limb()) + } +} + +impl BigUintArrayCompose<3> for [BigUint; 3] { + fn compose(&self) -> BigUint { + bigunits_compose(self, &BigUint::two_to_limb()) + } +} + +// Compose field limbs into BigUint value +fn fields_compose(limbs: &[F; N], base: &BigUint) -> BigUint { + limbs + .iter() + .cloned() + .enumerate() + .fold(BigUint::zero(), |x, (i, limb)| { + x + base.pow(i as u32) * limb.to_biguint() + }) +} + +// Convert array of BigUint to an array of PrimeField +fn biguints_to_fields(limbs: &[BigUint; N]) -> [F; N] { + array::from_fn(|i| { + F::from_random_bytes(&limbs[i].to_bytes_le()) + .expect("failed to convert BigUint to field element") + }) +} + +// Compose limbs into BigUint value +fn bigunits_compose(limbs: &[BigUint; N], base: &BigUint) -> BigUint { + limbs + .iter() + .cloned() + .enumerate() + .fold(BigUint::zero(), |x, (i, limb)| { + x + base.pow(i as u32) * limb + }) +} + +// Split a BigUint up into limbs of size limb_size (in little-endian order) +fn biguint_to_limbs(x: &BigUint, limb_bits: usize) -> Vec { + let bytes = x.to_bytes_le(); + let chunks: Vec<&[u8]> = bytes.chunks(limb_bits / 8).collect(); + chunks + .iter() + .map(|chunk| BigUint::from_bytes_le(chunk)) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + use ark_ec::AffineRepr; + use ark_ff::One; + use mina_curves::pasta::Pallas as CurvePoint; + use num_bigint::RandBigInt; + + /// Base field element type + pub type BaseField = ::BaseField; + + fn secp256k1_modulus() -> BigUint { + BigUint::from_bytes_be(&secp256k1::constants::FIELD_SIZE) + } + + #[test] + fn test_negate_modulus_safe1() { + secp256k1_modulus().negate(); + } + + #[test] + fn test_negate_modulus_safe2() { + BigUint::binary_modulus().sqrt().negate(); + } + + #[test] + fn test_negate_modulus_safe3() { + (BigUint::binary_modulus() / BigUint::from(2u32)).negate(); + } + + #[test] + #[should_panic] + fn test_negate_modulus_unsafe1() { + (BigUint::binary_modulus() - BigUint::one()).negate(); + } + + #[test] + #[should_panic] + fn test_negate_modulus_unsafe2() { + (BigUint::binary_modulus() + BigUint::one()).negate(); + } + + #[test] + #[should_panic] + fn test_negate_modulus_unsafe3() { + BigUint::binary_modulus().negate(); + } + + #[test] + fn check_negation() { + let mut rng = o1_utils::tests::make_test_rng(None); + for _ in 0..10 { + rng.gen_biguint(256).negate(); + } + } + + #[test] + fn check_good_limbs() { + let mut rng = o1_utils::tests::make_test_rng(None); + for _ in 0..100 { + let x = rng.gen_biguint(264); + assert_eq!(x.to_limbs().len(), 3); + assert_eq!(x.to_limbs().compose(), x); + assert_eq!(x.to_compact_limbs().len(), 2); + assert_eq!(x.to_compact_limbs().compose(), x); + assert_eq!(x.to_compact_limbs().compose(), x.to_limbs().compose()); + + assert_eq!(x.to_field_limbs::().len(), 3); + assert_eq!(x.to_field_limbs::().compose(), x); + assert_eq!(x.to_compact_field_limbs::().len(), 2); + 
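// Compact limbs pack the low and middle limbs together as
// [lo + 2^LIMB_BITS * mi, hi], so composing with base 2^(2 * LIMB_BITS)
// recovers the same value as composing the three 88-bit limbs.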
assert_eq!(x.to_compact_field_limbs::().compose(), x); + assert_eq!( + x.to_compact_field_limbs::().compose(), + x.to_field_limbs::().compose() + ); + + assert_eq!(x.to_limbs().to_fields::(), x.to_field_limbs()); + assert_eq!(x.to_field_limbs::().to_biguints(), x.to_limbs()); + } + } + + #[test] + #[should_panic] + fn check_bad_limbs_1() { + let mut rng = o1_utils::tests::make_test_rng(None); + assert_ne!(rng.gen_biguint(265).to_limbs().len(), 3); + } + + #[test] + #[should_panic] + fn check_bad_limbs_2() { + let mut rng = o1_utils::tests::make_test_rng(None); + assert_ne!(rng.gen_biguint(265).to_compact_limbs().len(), 2); + } + + #[test] + #[should_panic] + fn check_bad_limbs_3() { + let mut rng = o1_utils::tests::make_test_rng(None); + assert_ne!(rng.gen_biguint(265).to_field_limbs::().len(), 3); + } + + #[test] + #[should_panic] + fn check_bad_limbs_4() { + let mut rng = o1_utils::tests::make_test_rng(None); + assert_ne!( + rng.gen_biguint(265) + .to_compact_field_limbs::() + .len(), + 2 + ); + } +} diff --git a/kimchi/src/circuits/polynomials/foreign_field_mul/circuitgates.rs b/kimchi/src/circuits/polynomials/foreign_field_mul/circuitgates.rs index 444872e2da..e0e8b75a8e 100644 --- a/kimchi/src/circuits/polynomials/foreign_field_mul/circuitgates.rs +++ b/kimchi/src/circuits/polynomials/foreign_field_mul/circuitgates.rs @@ -93,6 +93,7 @@ use crate::{ auto_clone_array, circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, expr::{constraints::ExprOps, Cache}, gate::GateType, }, @@ -103,9 +104,9 @@ use std::{array, marker::PhantomData}; /// Compute non-zero intermediate products /// /// For more details see the "Intermediate products" Section of -/// the [Foreign Field Multiplication RFC](../rfcs/foreign_field_mul.md) +/// the [Foreign Field Multiplication RFC](https://github.com/o1-labs/rfcs/blob/main/0006-ffmul-revised.md) /// -pub fn compute_intermediate_products>( +pub fn compute_intermediate_products>( left_input: &[T; 3], right_input: &[T; 3], quotient: &[T; 3], @@ -135,7 +136,7 @@ pub fn compute_intermediate_products>( } // Compute native modulus values -pub fn compute_native_modulus_values>( +pub fn compute_native_modulus_values>( left_input: &[T; 3], right_input: &[T; 3], quotient: &[T; 3], @@ -165,7 +166,7 @@ pub fn compute_native_modulus_values>( } /// Composes the 91-bit carry1 value from its parts -pub fn compose_carry>(carry: &[T; 11]) -> T { +pub fn compose_carry>(carry: &[T; 11]) -> T { auto_clone_array!(carry); carry(0) + T::two_pow(12) * carry(1) @@ -194,7 +195,10 @@ where const CONSTRAINTS: u32 = 11; // DEGREE is 4 - fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { let mut constraints = vec![]; // diff --git a/kimchi/src/circuits/polynomials/foreign_field_mul/gadget.rs b/kimchi/src/circuits/polynomials/foreign_field_mul/gadget.rs index 41afcc39e0..872fb3fe1b 100644 --- a/kimchi/src/circuits/polynomials/foreign_field_mul/gadget.rs +++ b/kimchi/src/circuits/polynomials/foreign_field_mul/gadget.rs @@ -2,19 +2,23 @@ use ark_ff::PrimeField; use num_bigint::BigUint; -use o1_utils::foreign_field::{BigUintForeignFieldHelpers, ForeignFieldHelpers}; +use o1_utils::foreign_field::ForeignFieldHelpers; use crate::{ alphas::Alphas, circuits::{ argument::Argument, - expr::{Cache, E}, + berkeley_columns::E, + expr::Cache, gate::{CircuitGate, GateType}, lookup::{ self, tables::{GateLookupTable, LookupTable}, }, - 
polynomials::generic::GenericGateSpec, + polynomials::{ + foreign_field_common::{BigUintForeignFieldHelpers, KimchiForeignElement}, + generic::GenericGateSpec, + }, wires::Wire, }, }; @@ -76,7 +80,7 @@ impl CircuitGate { ) { let r = gates.len(); let hi_fmod = foreign_field_modulus.to_field_limbs::()[2]; - let hi_limb: F = F::two_to_limb() - hi_fmod - F::one(); + let hi_limb: F = KimchiForeignElement::::two_to_limb() - hi_fmod - F::one(); let g = GenericGateSpec::Plus(hi_limb); CircuitGate::extend_generic(gates, curr_row, Wire::for_row(r), g.clone(), Some(g)); } diff --git a/kimchi/src/circuits/polynomials/foreign_field_mul/witness.rs b/kimchi/src/circuits/polynomials/foreign_field_mul/witness.rs index e202f3f7f7..16cb810d05 100644 --- a/kimchi/src/circuits/polynomials/foreign_field_mul/witness.rs +++ b/kimchi/src/circuits/polynomials/foreign_field_mul/witness.rs @@ -4,7 +4,14 @@ use crate::{ auto_clone_array, circuits::{ polynomial::COLUMNS, - polynomials::{foreign_field_add, range_check}, + polynomials::{ + foreign_field_add, + foreign_field_common::{ + BigUintArrayFieldHelpers, BigUintForeignFieldHelpers, FieldArrayBigUintHelpers, + KimchiForeignElement, + }, + range_check, + }, witness::{self, ConstantCell, VariableBitsCell, VariableCell, Variables, WitnessCell}, }, variable_map, @@ -14,10 +21,7 @@ use num_bigint::BigUint; use num_integer::Integer; use num_traits::One; -use o1_utils::foreign_field::{ - BigUintArrayFieldHelpers, BigUintForeignFieldHelpers, FieldArrayBigUintHelpers, - ForeignFieldHelpers, -}; +use o1_utils::foreign_field::ForeignFieldHelpers; use std::{array, ops::Div}; use super::circuitgates; @@ -41,10 +45,10 @@ use super::circuitgates; // // so that most significant limb, q2, is in W[2][0]. // -fn create_layout() -> [[Box>; COLUMNS]; 2] { +fn create_layout() -> [Vec>>; 2] { [ // ForeignFieldMul row - [ + vec![ // Copied for multi-range-check VariableCell::create("left_input0"), VariableCell::create("left_input1"), @@ -64,7 +68,7 @@ fn create_layout() -> [[Box>; COLUMNS]; 2] { VariableBitsCell::create("carry1", 90, None), ], // Zero row - [ + vec![ // Copied for multi-range-check VariableCell::create("remainder01"), VariableCell::create("remainder2"), @@ -321,7 +325,9 @@ impl ExternalChecks { witness: &mut [Vec; COLUMNS], foreign_field_modulus: &BigUint, ) { - let hi_limb = F::two_to_limb() - foreign_field_modulus.to_field_limbs::()[2] - F::one(); + let hi_limb = KimchiForeignElement::::two_to_limb() + - foreign_field_modulus.to_field_limbs::()[2] + - F::one(); for chunk in self.high_bounds.clone().chunks(2) { // Extend the witness for the generic gate for col in witness.iter_mut().take(COLUMNS) { diff --git a/kimchi/src/circuits/polynomials/generic.rs b/kimchi/src/circuits/polynomials/generic.rs index ad07a866cf..855adf8942 100644 --- a/kimchi/src/circuits/polynomials/generic.rs +++ b/kimchi/src/circuits/polynomials/generic.rs @@ -36,6 +36,7 @@ use crate::{ circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, expr::{constraints::ExprOps, Cache}, gate::{CircuitGate, GateType}, polynomial::COLUMNS, @@ -77,7 +78,10 @@ where const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::Generic); const CONSTRAINTS: u32 = 2; - fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { // First generic gate let left_coeff1 = env.coeff(0); let right_coeff1 = env.coeff(1); diff --git a/kimchi/src/circuits/polynomials/keccak.rs 
b/kimchi/src/circuits/polynomials/keccak.rs deleted file mode 100644 index 637509344a..0000000000 --- a/kimchi/src/circuits/polynomials/keccak.rs +++ /dev/null @@ -1,85 +0,0 @@ -//! Keccak gadget -use std::array; - -use ark_ff::{PrimeField, SquareRootField}; - -use crate::circuits::{ - gate::{CircuitGate, Connect}, - polynomial::COLUMNS, - polynomials::{ - generic::GenericGateSpec, - rot::{self, RotMode}, - }, - wires::Wire, -}; - -/// Creates the 5x5 table of rotation bits for Keccak modulo 64 -/// | x \ y | 0 | 1 | 2 | 3 | 4 | -/// | ----- | -- | -- | -- | -- | -- | -/// | 0 | 0 | 36 | 3 | 41 | 18 | -/// | 1 | 1 | 44 | 10 | 45 | 2 | -/// | 2 | 62 | 6 | 43 | 15 | 61 | -/// | 3 | 28 | 55 | 25 | 21 | 56 | -/// | 4 | 27 | 20 | 39 | 8 | 14 | -pub const ROT_TAB: [[u32; 5]; 5] = [ - [0, 36, 3, 41, 18], - [1, 44, 10, 45, 2], - [62, 6, 43, 15, 61], - [28, 55, 25, 21, 56], - [27, 20, 39, 8, 14], -]; - -impl CircuitGate { - /// Creates Keccak gadget. - /// Right now it only creates an initial generic gate with all zeros starting on `new_row` and then - /// calls the Keccak rotation gadget - pub fn create_keccak(new_row: usize) -> (usize, Vec) { - // Initial Generic gate to constrain the prefix of the output to be zero - let mut gates = vec![CircuitGate::::create_generic_gadget( - Wire::for_row(new_row), - GenericGateSpec::Pub, - None, - )]; - Self::create_keccak_rot(&mut gates, new_row + 1, new_row) - } - - /// Creates Keccak rotation gates for the whole table (skipping the rotation by 0) - pub fn create_keccak_rot( - gates: &mut Vec, - new_row: usize, - zero_row: usize, - ) -> (usize, Vec) { - let mut rot_row = new_row; - for row in ROT_TAB { - for rot in row { - // if rotation by 0 bits, no need to create a gate for it - if rot == 0 { - continue; - } - let mut rot64_gates = Self::create_rot64(rot_row, rot); - rot_row += rot64_gates.len(); - // Append them to the full gates vector - gates.append(&mut rot64_gates); - // Check that 2 most significant limbs of shifted are zero - gates.connect_64bit(zero_row, rot_row - 1); - } - } - (rot_row, gates.to_vec()) - } -} - -/// Create a Keccak rotation (whole table) -/// Input: state (5x5) array of words to be rotated -pub fn create_witness_keccak_rot(state: [[u64; 5]; 5]) -> [Vec; COLUMNS] { - // First generic gate with all zeros to constrain that the two most significant limbs of shifted output are zeros - let mut witness: [Vec; COLUMNS] = array::from_fn(|_| vec![F::zero()]); - for (x, row) in ROT_TAB.iter().enumerate() { - for (y, &rot) in row.iter().enumerate() { - if rot == 0 { - continue; - } - rot::extend_rot(&mut witness, state[x][y], rot, RotMode::Left); - } - } - witness -} diff --git a/kimchi/src/circuits/polynomials/keccak/circuitgates.rs b/kimchi/src/circuits/polynomials/keccak/circuitgates.rs new file mode 100644 index 0000000000..b35b14333c --- /dev/null +++ b/kimchi/src/circuits/polynomials/keccak/circuitgates.rs @@ -0,0 +1,324 @@ +//! Keccak gadget +use crate::{ + auto_clone, auto_clone_array, + circuits::{ + argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, + expr::{ + constraints::{boolean, ExprOps}, + Cache, + }, + gate::GateType, + polynomials::keccak::{constants::*, OFF}, + }, + grid, +}; +use ark_ff::PrimeField; +use std::marker::PhantomData; + +#[macro_export] +macro_rules! 
from_quarters { + ($quarters:ident, $x:ident) => { + $quarters($x, 0) + + T::two_pow(16) * $quarters($x, 1) + + T::two_pow(32) * $quarters($x, 2) + + T::two_pow(48) * $quarters($x, 3) + }; + ($quarters:ident, $y:ident, $x:ident) => { + $quarters($y, $x, 0) + + T::two_pow(16) * $quarters($y, $x, 1) + + T::two_pow(32) * $quarters($y, $x, 2) + + T::two_pow(48) * $quarters($y, $x, 3) + }; +} + +#[macro_export] +macro_rules! from_shifts { + ($shifts:ident, $i:ident) => { + $shifts($i) + + T::two_pow(1) * $shifts(100 + $i) + + T::two_pow(2) * $shifts(200 + $i) + + T::two_pow(3) * $shifts(300 + $i) + }; + ($shifts:ident, $x:ident, $q:ident) => { + $shifts(0, $x, $q) + + T::two_pow(1) * $shifts(1, $x, $q) + + T::two_pow(2) * $shifts(2, $x, $q) + + T::two_pow(3) * $shifts(3, $x, $q) + }; + ($shifts:ident, $y:ident, $x:ident, $q:ident) => { + $shifts(0, $y, $x, $q) + + T::two_pow(1) * $shifts(1, $y, $x, $q) + + T::two_pow(2) * $shifts(2, $y, $x, $q) + + T::two_pow(3) * $shifts(3, $y, $x, $q) + }; +} + +//~ +//~ | `KeccakRound` | [0...265) | [265...1165) | [1165...1965) | +//~ | ------------- | --------- | ------------ | ------------- | +//~ | Curr | theta | pirho | chi | +//~ +//~ | `KeccakRound` | [0...100) | +//~ | ------------- | --------- | +//~ | Next | iota | +//~ +//~ ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- +//~ +//~ | Columns | [0...100) | [100...180) | [180...200) | [200...205) | [205...225) | [225...245) | [245...265) | +//~ | -------- | --------- | ----------- | ----------- | ----------- | ------------ | ------------ | ------------ | +//~ | theta | state_a | shifts_c | dense_c | quotient_c | remainder_c | dense_rot_c | expand_rot_c | +//~ +//~ | Columns | [265...665) | [665...765) | [765...865) | [865...965) | [965...1065) | [1065...1165) | +//~ | -------- | ----------- | ----------- | ------------ | ----------- | ------------ | ------------- | +//~ | pirho | shifts_e | dense_e | quotient_e | remainder_e | dense_rot_e | expand_rot_e | +//~ +//~ | Columns | [1165...1565) | [1565...1965) | +//~ | -------- | ------------- | ------------- | +//~ | chi | shifts_b | shifts_sum | +//~ +//~ | Columns | [0...4) | [4...100) | +//~ | -------- | ------- | --------- | +//~ | iota | g00 | rest_g | +//~ +#[derive(Default)] +pub struct KeccakRound(PhantomData); + +impl Argument for KeccakRound +where + F: PrimeField, +{ + const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::KeccakRound); + const CONSTRAINTS: u32 = 389; + + // Constraints for one round of the Keccak permutation function + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { + let mut constraints = vec![]; + + // DEFINE ROUND CONSTANT + let rc = [env.coeff(0), env.coeff(1), env.coeff(2), env.coeff(3)]; + + // LOAD STATES FROM WITNESS LAYOUT + // THETA + let state_a = grid!( + 100, + env.witness_curr_chunk(THETA_STATE_A_OFF, THETA_SHIFTS_C_OFF) + ); + let shifts_c = grid!( + 80, + env.witness_curr_chunk(THETA_SHIFTS_C_OFF, THETA_DENSE_C_OFF) + ); + let dense_c = grid!( + 20, + env.witness_curr_chunk(THETA_DENSE_C_OFF, THETA_QUOTIENT_C_OFF) + ); + let quotient_c = grid!( + 5, + env.witness_curr_chunk(THETA_QUOTIENT_C_OFF, THETA_REMAINDER_C_OFF) + ); + let remainder_c = grid!( + 20, + env.witness_curr_chunk(THETA_REMAINDER_C_OFF, THETA_DENSE_ROT_C_OFF) + ); + let dense_rot_c = grid!( + 20, + env.witness_curr_chunk(THETA_DENSE_ROT_C_OFF, THETA_EXPAND_ROT_C_OFF) + ); + let 
expand_rot_c = grid!( + 20, + env.witness_curr_chunk(THETA_EXPAND_ROT_C_OFF, PIRHO_SHIFTS_E_OFF) + ); + // PI-RHO + let shifts_e = grid!( + 400, + env.witness_curr_chunk(PIRHO_SHIFTS_E_OFF, PIRHO_DENSE_E_OFF) + ); + let dense_e = grid!( + 100, + env.witness_curr_chunk(PIRHO_DENSE_E_OFF, PIRHO_QUOTIENT_E_OFF) + ); + let quotient_e = grid!( + 100, + env.witness_curr_chunk(PIRHO_QUOTIENT_E_OFF, PIRHO_REMAINDER_E_OFF) + ); + let remainder_e = grid!( + 100, + env.witness_curr_chunk(PIRHO_REMAINDER_E_OFF, PIRHO_DENSE_ROT_E_OFF) + ); + let dense_rot_e = grid!( + 100, + env.witness_curr_chunk(PIRHO_DENSE_ROT_E_OFF, PIRHO_EXPAND_ROT_E_OFF) + ); + let expand_rot_e = grid!( + 100, + env.witness_curr_chunk(PIRHO_EXPAND_ROT_E_OFF, CHI_SHIFTS_B_OFF) + ); + // CHI + let shifts_b = grid!( + 400, + env.witness_curr_chunk(CHI_SHIFTS_B_OFF, CHI_SHIFTS_SUM_OFF) + ); + let shifts_sum = grid!( + 400, + env.witness_curr_chunk(CHI_SHIFTS_SUM_OFF, IOTA_STATE_G_OFF) + ); + // IOTA + let state_g = grid!(100, env.witness_next_chunk(0, IOTA_STATE_G_LEN)); + + // Define vectors containing witness expressions which are not in the layout for efficiency + let mut state_c: Vec> = vec![vec![T::zero(); QUARTERS]; DIM]; + let mut state_d: Vec> = vec![vec![T::zero(); QUARTERS]; DIM]; + let mut state_e: Vec>> = vec![vec![vec![T::zero(); QUARTERS]; DIM]; DIM]; + let mut state_b: Vec>> = vec![vec![vec![T::zero(); QUARTERS]; DIM]; DIM]; + let mut state_f: Vec>> = vec![vec![vec![T::zero(); QUARTERS]; DIM]; DIM]; + + // STEP theta: 5 * ( 3 + 4 * 1 ) = 35 constraints + for x in 0..DIM { + let word_c = from_quarters!(dense_c, x); + let rem_c = from_quarters!(remainder_c, x); + let rot_c = from_quarters!(dense_rot_c, x); + + constraints + .push(word_c * T::two_pow(1) - (quotient_c(x) * T::two_pow(64) + rem_c.clone())); + constraints.push(rot_c - (quotient_c(x) + rem_c)); + constraints.push(boolean(&quotient_c(x))); + + for q in 0..QUARTERS { + state_c[x][q] = state_a(0, x, q) + + state_a(1, x, q) + + state_a(2, x, q) + + state_a(3, x, q) + + state_a(4, x, q); + constraints.push(state_c[x][q].clone() - from_shifts!(shifts_c, x, q)); + + state_d[x][q] = + shifts_c(0, (x + DIM - 1) % DIM, q) + expand_rot_c((x + 1) % DIM, q); + + for (y, column_e) in state_e.iter_mut().enumerate() { + column_e[x][q] = state_a(y, x, q) + state_d[x][q].clone(); + } + } + } // END theta + + // STEP pirho: 5 * 5 * (2 + 4 * 1) = 150 constraints + for (y, col) in OFF.iter().enumerate() { + for (x, off) in col.iter().enumerate() { + let word_e = from_quarters!(dense_e, y, x); + let quo_e = from_quarters!(quotient_e, y, x); + let rem_e = from_quarters!(remainder_e, y, x); + let rot_e = from_quarters!(dense_rot_e, y, x); + + constraints.push( + word_e * T::two_pow(*off) - (quo_e.clone() * T::two_pow(64) + rem_e.clone()), + ); + constraints.push(rot_e - (quo_e.clone() + rem_e)); + + for q in 0..QUARTERS { + constraints.push(state_e[y][x][q].clone() - from_shifts!(shifts_e, y, x, q)); + state_b[(2 * x + 3 * y) % DIM][y][q] = expand_rot_e(y, x, q); + } + } + } // END pirho + + // STEP chi: 4 * 5 * 5 * 2 = 200 constraints + for q in 0..QUARTERS { + for x in 0..DIM { + for y in 0..DIM { + let not = T::literal(F::from(0x1111111111111111u64)) + - shifts_b(0, y, (x + 1) % DIM, q); + let sum = not + shifts_b(0, y, (x + 2) % DIM, q); + let and = shifts_sum(1, y, x, q); + + constraints.push(state_b[y][x][q].clone() - from_shifts!(shifts_b, y, x, q)); + constraints.push(sum - from_shifts!(shifts_sum, y, x, q)); + state_f[y][x][q] = shifts_b(0, y, x, q) + and; + } + } + } // END chi
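// Reference semantics of chi on plain u64 words (a sketch, with the same
// [y][x] indexing as above): f[y][x] = b[y][x] ^ (!b[y][(x + 1) % 5] & b[y][(x + 2) % 5]).
// In the sparse encoding, NOT is 0x1111111111111111 - b, the AND of two
// operands appears as the second shift of their sum, and the closing XOR is
// the plain addition `shifts_b(0, y, x, q) + and`; carry bits accumulate in
// the higher shifts and are cleared by the next decomposition into shifts.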
+ + // STEP iota: 4 constraints + for (q, c) in rc.iter().enumerate() { + constraints.push(state_g(0, 0, q) - (state_f[0][0][q].clone() + c.clone())); + } // END iota + + constraints + } +} + +//~ +//~ | `KeccakSponge` | [0...100) | [100...168) | [168...200) | [200...400) | [400...800) | +//~ | -------------- | --------- | ----------- | ----------- | ----------- | ----------- | +//~ | Curr | old_state | new_block | zeros | bytes | shifts | +//~ | Next | xor_state | +//~ +#[derive(Default)] +pub struct KeccakSponge(PhantomData); + +impl Argument for KeccakSponge +where + F: PrimeField, +{ + const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::KeccakSponge); + const CONSTRAINTS: u32 = 532; + + // Constraints for the Keccak sponge + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { + let mut constraints = vec![]; + + // LOAD WITNESS + let old_state = env.witness_curr_chunk(SPONGE_OLD_STATE_OFF, SPONGE_NEW_STATE_OFF); + let new_state = env.witness_curr_chunk(SPONGE_NEW_STATE_OFF, SPONGE_BYTES_OFF); + let zeros = env.witness_curr_chunk(SPONGE_ZEROS_OFF, SPONGE_BYTES_OFF); + let xor_state = env.witness_next_chunk(0, SPONGE_XOR_STATE_LEN); + let bytes = env.witness_curr_chunk(SPONGE_BYTES_OFF, SPONGE_SHIFTS_OFF); + let shifts = + env.witness_curr_chunk(SPONGE_SHIFTS_OFF, SPONGE_SHIFTS_OFF + SPONGE_SHIFTS_LEN); + auto_clone_array!(old_state); + auto_clone_array!(new_state); + auto_clone_array!(xor_state); + auto_clone_array!(bytes); + auto_clone_array!(shifts); + + // LOAD COEFFICIENTS + let absorb = env.coeff(0); + let squeeze = env.coeff(1); + let root = env.coeff(2); + let flags = env.coeff_chunk(4, 140); + let pad = env.coeff_chunk(200, 336); + auto_clone!(root); + auto_clone!(absorb); + auto_clone!(squeeze); + auto_clone_array!(flags); + auto_clone_array!(pad); + + // 32 + 100 * 3 + 64 + 136 = 532 + for z in zeros { + // The absorb phase pads the new state with zeros + constraints.push(absorb() * z); + } + for i in 0..STATE_LEN { + // In first absorb, root state is all zeros + constraints.push(root() * old_state(i)); + // Absorbs the new block by performing XOR with the old state + constraints.push(absorb() * (xor_state(i) - (old_state(i) + new_state(i)))); + // In absorb, check that the shifts correspond to the decomposition of the new state + constraints.push(absorb() * (new_state(i) - from_shifts!(shifts, i))); + } + for i in 0..64 { + // In squeeze, check that the shifts correspond to the 256-bit prefix digest of the old state (current) + constraints.push(squeeze() * (old_state(i) - from_shifts!(shifts, i))); + } + for i in 0..RATE_IN_BYTES { + // Check padding + constraints.push(flags(i) * (pad(i) - bytes(i))); + } + + constraints + } +} diff --git a/kimchi/src/circuits/polynomials/keccak/constants.rs b/kimchi/src/circuits/polynomials/keccak/constants.rs new file mode 100644 index 0000000000..e8416cf2db --- /dev/null +++ b/kimchi/src/circuits/polynomials/keccak/constants.rs @@ -0,0 +1,83 @@ +/// Constants for each witness's index offsets and lengths + +// KECCAK PARAMETERS +/// The dimension of the Keccak state +pub const DIM: usize = 5; + +/// An element of the Keccak state is 64 bits. However, we split the state into +/// quarters of 16 bits to represent XOR and AND using finite field +/// addition. See the Keccak RFC for more information. +pub const QUARTERS: usize = 4; + +pub const SHIFTS: usize = 4; + +/// The number of rounds in the Keccak permutation +pub const ROUNDS: usize = 24; + +/// The number of bytes that can be processed by the Keccak permutation.
It is +/// the rate of the sponge configuration. +pub const RATE_IN_BYTES: usize = 1088 / 8; + +/// The number of bytes used as a capacity in the sponge. +pub const CAPACITY_IN_BYTES: usize = 512 / 8; + +/// The number of columns the Keccak circuit uses. +pub const KECCAK_COLS: usize = 1965; + +/// The number of field elements used to represent the whole Keccak state. +pub const STATE_LEN: usize = QUARTERS * DIM * DIM; + +pub const SHIFTS_LEN: usize = SHIFTS * STATE_LEN; + +// ROUND INDICES + +pub const THETA_STATE_A_OFF: usize = 0; +pub const THETA_STATE_A_LEN: usize = STATE_LEN; +pub const THETA_SHIFTS_C_OFF: usize = THETA_STATE_A_LEN; +pub const THETA_SHIFTS_C_LEN: usize = SHIFTS * DIM * QUARTERS; +pub const THETA_DENSE_C_OFF: usize = THETA_SHIFTS_C_OFF + THETA_SHIFTS_C_LEN; +pub const THETA_DENSE_C_LEN: usize = QUARTERS * DIM; +pub const THETA_QUOTIENT_C_OFF: usize = THETA_DENSE_C_OFF + THETA_DENSE_C_LEN; +pub const THETA_QUOTIENT_C_LEN: usize = DIM; +pub const THETA_REMAINDER_C_OFF: usize = THETA_QUOTIENT_C_OFF + THETA_QUOTIENT_C_LEN; +pub const THETA_REMAINDER_C_LEN: usize = QUARTERS * DIM; +pub const THETA_DENSE_ROT_C_OFF: usize = THETA_REMAINDER_C_OFF + THETA_REMAINDER_C_LEN; +pub const THETA_DENSE_ROT_C_LEN: usize = QUARTERS * DIM; +pub const THETA_EXPAND_ROT_C_OFF: usize = THETA_DENSE_ROT_C_OFF + THETA_DENSE_ROT_C_LEN; +pub const THETA_EXPAND_ROT_C_LEN: usize = QUARTERS * DIM; +pub const PIRHO_SHIFTS_E_OFF: usize = THETA_EXPAND_ROT_C_OFF + THETA_EXPAND_ROT_C_LEN; +pub const PIRHO_SHIFTS_E_LEN: usize = SHIFTS_LEN; +pub const PIRHO_DENSE_E_OFF: usize = PIRHO_SHIFTS_E_OFF + PIRHO_SHIFTS_E_LEN; +pub const PIRHO_DENSE_E_LEN: usize = STATE_LEN; +pub const PIRHO_QUOTIENT_E_OFF: usize = PIRHO_DENSE_E_OFF + PIRHO_DENSE_E_LEN; +pub const PIRHO_QUOTIENT_E_LEN: usize = STATE_LEN; +pub const PIRHO_REMAINDER_E_OFF: usize = PIRHO_QUOTIENT_E_OFF + PIRHO_QUOTIENT_E_LEN; +pub const PIRHO_REMAINDER_E_LEN: usize = STATE_LEN; +pub const PIRHO_DENSE_ROT_E_OFF: usize = PIRHO_REMAINDER_E_OFF + PIRHO_REMAINDER_E_LEN; +pub const PIRHO_DENSE_ROT_E_LEN: usize = STATE_LEN; +pub const PIRHO_EXPAND_ROT_E_OFF: usize = PIRHO_DENSE_ROT_E_OFF + PIRHO_DENSE_ROT_E_LEN; +pub const PIRHO_EXPAND_ROT_E_LEN: usize = STATE_LEN; +pub const CHI_SHIFTS_B_OFF: usize = PIRHO_EXPAND_ROT_E_OFF + PIRHO_EXPAND_ROT_E_LEN; +pub const CHI_SHIFTS_B_LEN: usize = SHIFTS_LEN; +pub const CHI_SHIFTS_SUM_OFF: usize = CHI_SHIFTS_B_OFF + CHI_SHIFTS_B_LEN; +pub const CHI_SHIFTS_SUM_LEN: usize = SHIFTS_LEN; +pub const IOTA_STATE_G_OFF: usize = 0; +pub const IOTA_STATE_G_LEN: usize = STATE_LEN; +// SPONGE INDICES +pub const SPONGE_OLD_STATE_OFF: usize = 0; +pub const SPONGE_OLD_STATE_LEN: usize = STATE_LEN; +pub const SPONGE_NEW_STATE_OFF: usize = SPONGE_OLD_STATE_OFF + SPONGE_OLD_STATE_LEN; +pub const SPONGE_NEW_STATE_LEN: usize = STATE_LEN; +pub const SPONGE_NEW_BLOCK_OFF: usize = SPONGE_NEW_STATE_OFF; +pub const SPONGE_NEW_BLOCK_LEN: usize = 68; +pub const SPONGE_ZEROS_OFF: usize = SPONGE_NEW_BLOCK_OFF + SPONGE_NEW_BLOCK_LEN; +pub const SPONGE_ZEROS_LEN: usize = STATE_LEN - SPONGE_NEW_BLOCK_LEN; +pub const SPONGE_BYTES_OFF: usize = SPONGE_NEW_STATE_OFF + SPONGE_NEW_STATE_LEN; +pub const SPONGE_BYTES_LEN: usize = 2 * STATE_LEN; +pub const SPONGE_SHIFTS_OFF: usize = SPONGE_BYTES_OFF + SPONGE_BYTES_LEN; +pub const SPONGE_SHIFTS_LEN: usize = SHIFTS_LEN; +pub const SPONGE_XOR_STATE_OFF: usize = 0; +pub const SPONGE_XOR_STATE_LEN: usize = STATE_LEN; + +/// The number of columns the Sponge circuit uses. 
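/// With the offsets above this works out to 400 + 400 = 800, matching the
/// right endpoint of the `[400...800)` shifts segment in the sponge layout
/// table of circuitgates.rs.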
+pub const SPONGE_COLS: usize = SPONGE_SHIFTS_OFF + SPONGE_SHIFTS_LEN; diff --git a/kimchi/src/circuits/polynomials/keccak/gadget.rs b/kimchi/src/circuits/polynomials/keccak/gadget.rs new file mode 100644 index 0000000000..f204771fe9 --- /dev/null +++ b/kimchi/src/circuits/polynomials/keccak/gadget.rs @@ -0,0 +1,93 @@ +//! Keccak gadget +use crate::circuits::{ + gate::{CircuitGate, GateType}, + wires::Wire, +}; +use ark_ff::PrimeField; + +use super::{ + constants::{RATE_IN_BYTES, ROUNDS}, + Keccak, RC, +}; + +const SPONGE_COEFFS: usize = 336; + +impl CircuitGate { + /// Extends a Keccak circuit to hash one message + /// Note: + /// Requires at least one more row after the Keccak gadget so that + /// constraints can access the next row in the squeeze + pub fn extend_keccak(circuit: &mut Vec, bytelength: usize) -> usize { + let mut gates = Self::create_keccak(circuit.len(), bytelength); + circuit.append(&mut gates); + circuit.len() + } + + /// Creates a Keccak256 circuit, capacity 512 bits, rate 1088 bits, message of a given bytelength + fn create_keccak(new_row: usize, bytelength: usize) -> Vec { + let padded_len = Keccak::padded_length(bytelength); + let extra_bytes = padded_len - bytelength; + let num_blocks = padded_len / RATE_IN_BYTES; + let mut gates = vec![]; + for block in 0..num_blocks { + let root = block == 0; + let pad = block == num_blocks - 1; + gates.push(Self::create_keccak_absorb( + new_row + gates.len(), + root, + pad, + extra_bytes, + )); + for round in 0..ROUNDS { + gates.push(Self::create_keccak_round(new_row + gates.len(), round)); + } + } + gates.push(Self::create_keccak_squeeze(new_row + gates.len())); + gates + } + + fn create_keccak_squeeze(new_row: usize) -> Self { + CircuitGate { + typ: GateType::KeccakSponge, + wires: Wire::for_row(new_row), + coeffs: { + let mut c = vec![F::zero(); SPONGE_COEFFS]; + c[1] = F::one(); + c + }, + } + } + + fn create_keccak_absorb(new_row: usize, root: bool, pad: bool, pad_bytes: usize) -> Self { + let mut coeffs = vec![F::zero(); SPONGE_COEFFS]; + coeffs[0] = F::one(); // absorb + if root { + coeffs[2] = F::one(); // root + } + if pad { + // Check pad 0x01 (0x00 ... 0x00)* 0x80 or 0x81 if only one byte for padding + for i in 0..pad_bytes { + coeffs[140 - i] = F::one(); // flag for padding + if i == 0 { + coeffs[SPONGE_COEFFS - 1 - i] += F::from(0x80u8); // pad + } + if i == pad_bytes - 1 { + coeffs[SPONGE_COEFFS - 1 - i] += F::one(); // pad + } + } + } + CircuitGate { + typ: GateType::KeccakSponge, + wires: Wire::for_row(new_row), + coeffs, + } + } + + fn create_keccak_round(new_row: usize, round: usize) -> Self { + CircuitGate { + typ: GateType::KeccakRound, + wires: Wire::for_row(new_row), + coeffs: Keccak::expand_word(RC[round]), + } + } +} diff --git a/kimchi/src/circuits/polynomials/keccak/mod.rs b/kimchi/src/circuits/polynomials/keccak/mod.rs new file mode 100644 index 0000000000..ec7a094c25 --- /dev/null +++ b/kimchi/src/circuits/polynomials/keccak/mod.rs @@ -0,0 +1,428 @@ +//! Keccak hash module +pub mod circuitgates; +pub mod constants; +pub mod gadget; +pub mod witness; + +use crate::circuits::expr::constraints::ExprOps; +use ark_ff::PrimeField; + +use self::constants::{DIM, QUARTERS, RATE_IN_BYTES, ROUNDS}; +use super::super::berkeley_columns::BerkeleyChallengeTerm; + +#[macro_export] +macro_rules! 
grid { + (5, $v:expr) => {{ + |x: usize| $v[x].clone() + }}; + (20, $v:expr) => {{ + |x: usize, q: usize| $v[q + QUARTERS * x].clone() + }}; + (80, $v:expr) => {{ + |i: usize, x: usize, q: usize| $v[q + QUARTERS * (x + DIM * i)].clone() + }}; + (100, $v:expr) => {{ + |y: usize, x: usize, q: usize| $v[q + QUARTERS * (x + DIM * y)].clone() + }}; + (400, $v:expr) => {{ + |i: usize, y: usize, x: usize, q: usize| { + $v[q + QUARTERS * (x + DIM * (y + DIM * i))].clone() + } + }}; +} + +/// Creates the 5x5 table of rotation bits for Keccak modulo 64 +/// | x \ y | 0 | 1 | 2 | 3 | 4 | +/// | ----- | -- | -- | -- | -- | -- | +/// | 0 | 0 | 36 | 3 | 41 | 18 | +/// | 1 | 1 | 44 | 10 | 45 | 2 | +/// | 2 | 62 | 6 | 43 | 15 | 61 | +/// | 3 | 28 | 55 | 25 | 21 | 56 | +/// | 4 | 27 | 20 | 39 | 8 | 14 | +/// Note that the order of the indexing is `[y][x]` to match the encoding of the witness algorithm +pub const OFF: [[u64; DIM]; DIM] = [ + [0, 1, 62, 28, 27], + [36, 44, 6, 55, 20], + [3, 10, 43, 25, 39], + [41, 45, 15, 21, 8], + [18, 2, 61, 56, 14], +]; + +/// Contains the 24 round constants for Keccak +pub const RC: [u64; ROUNDS] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808a, + 0x8000000080008000, + 0x000000000000808b, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008a, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000a, + 0x000000008000808b, + 0x800000000000008b, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800a, + 0x800000008000000a, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, +]; + +/// Naive Keccak structure +pub struct Keccak {} + +/// Trait containing common operations for optimized Keccak +impl Keccak { + /// Composes a vector of 4 dense quarters into the dense full u64 word + pub fn compose(quarters: &[u64]) -> u64 { + quarters[0] + (1 << 16) * quarters[1] + (1 << 32) * quarters[2] + (1 << 48) * quarters[3] + } + /// Takes a dense u64 word and decomposes it into a vector of 4 dense quarters. + /// The first element of the vector corresponds to the 16 least significant bits. + pub fn decompose(word: u64) -> Vec { + vec![ + word % (1 << 16), + (word / (1 << 16)) % (1 << 16), + (word / (1 << 32)) % (1 << 16), + (word / (1 << 48)) % (1 << 16), + ] + } + + /// Expands a quarter of a word into the sparse representation as a u64 + pub fn expand(quarter: u64) -> u64 { + u64::from_str_radix(&format!("{:b}", quarter), 16).unwrap() + } + + /// Expands a u64 word into a vector of 4 sparse u64 quarters + pub fn expand_word>(word: u64) -> Vec { + Self::decompose(word) + .iter() + .map(|q| T::literal(F::from(Self::expand(*q)))) + .collect::>() + } + + /// Returns the expansion of the 4 dense decomposed quarters of a word where + /// the first expanded element corresponds to the 16 least significant bits of the word. + pub fn sparse(word: u64) -> Vec { + Self::decompose(word) + .iter() + .map(|q| Self::expand(*q)) + .collect::>() + } + /// From each quarter in sparse representation, it computes its 4 resets. + /// The resulting vector contains 4 times as many elements as the input. 
+ /// The output is placed in the vector as [shift0, shift1, shift2, shift3] + pub fn shift(state: &[u64]) -> Vec { + let n = state.len(); + let mut shifts = vec![0; QUARTERS * n]; + let aux = Self::expand(0xFFFF); + for (i, term) in state.iter().enumerate() { + shifts[i] = aux & term; // shift0 = reset0 + shifts[n + i] = ((aux << 1) & term) / 2; // shift1 = reset1/2 + shifts[2 * n + i] = ((aux << 2) & term) / 4; // shift2 = reset2/4 + shifts[3 * n + i] = ((aux << 3) & term) / 8; // shift3 = reset3/8 + } + shifts + } + + /// From a vector of shifts, resets the underlying value returning only shift0 + /// Note that shifts is always a vector whose length is a multiple of 4. + pub fn reset(shifts: &[u64]) -> Vec { + shifts[0..shifts.len() / QUARTERS].to_vec() + } + + /// From a canonical expanded state, obtain the corresponding 16-bit dense terms + pub fn collapse(state: &[u64]) -> Vec { + state + .iter() + .map(|&reset| u64::from_str_radix(&format!("{:x}", reset), 2).unwrap()) + .collect::>() + } + + /// Outputs the state into dense quarters of 16-bits each in little endian order + pub fn quarters(state: &[u8]) -> Vec { + let mut quarters = vec![]; + for pair in state.chunks(2) { + quarters.push(u16::from_le_bytes([pair[0], pair[1]]) as u64); + } + quarters + } + + /// On input a vector of 16-bit dense quarters, outputs a vector of 8-bit bytes in the right order for Keccak + pub fn bytestring(dense: &[u64]) -> Vec { + dense + .iter() + .map(|x| vec![x % 256, x / 256]) + .collect::>>() + .iter() + .flatten() + .copied() + .collect() + } + + /// On input a 200-byte vector, generates a vector of 100 expanded quarters representing the 1600-bit state + pub fn expand_state(state: &[u8]) -> Vec { + let mut expanded = vec![]; + for pair in state.chunks(2) { + let quarter = u16::from_le_bytes([pair[0], pair[1]]); + expanded.push(Self::expand(quarter as u64)); + } + expanded + } + + /// On input a length, returns the smallest multiple of RATE_IN_BYTES that is greater than the bytelength. + /// That means that if the input has a length that is a multiple of the RATE_IN_BYTES, then + /// it needs to add one whole block of RATE_IN_BYTES bytes just for padding purposes. + pub fn padded_length(bytelength: usize) -> usize { + Self::num_blocks(bytelength) * RATE_IN_BYTES + } + + /// Pads the message with the 10*1 rule until reaching a length that is a multiple of the rate + pub fn pad(message: &[u8]) -> Vec { + let msg_len = message.len(); + let pad_len = Self::padded_length(msg_len); + let mut padded = vec![0; pad_len]; + for (i, byte) in message.iter().enumerate() { + padded[i] = *byte; + } + padded[msg_len] = 0x01; + padded[pad_len - 1] += 0x80; + + padded + } + + /// Number of blocks to be absorbed on input a given preimage bytelength + pub fn num_blocks(bytelength: usize) -> usize { + bytelength / RATE_IN_BYTES + 1 + } +} + +#[cfg(test)] +mod tests { + + use rand::{rngs::StdRng, thread_rng, Rng}; + use rand_core::SeedableRng; + + use super::*; + + #[test] + // Shows that the expansion of the 16-bit dense quarters into 64-bit sparse quarters + // corresponds to the binary representation of the 16-bit dense quarter. + fn test_bitwise_sparse_representation() { + assert_eq!(Keccak::expand(0xFFFF), 0x1111111111111111); + assert_eq!(Keccak::expand(0x0000), 0x0000000000000000); + assert_eq!(Keccak::expand(0x1234), 0x0001001000110100) + } + + #[test] + // Tests that composing and decomposition are the inverse of each other, + // and the order of the quarters is the desired one. 
+ fn test_compose_decompose() { + let word: u64 = 0x70d324ac9215fd8e; + let dense = Keccak::decompose(word); + let expected_dense = [0xfd8e, 0x9215, 0x24ac, 0x70d3]; + for i in 0..QUARTERS { + assert_eq!(dense[i], expected_dense[i]); + } + assert_eq!(word, Keccak::compose(&dense)); + } + + #[test] + // Tests that expansion works as expected with one quarter word + fn test_quarter_expansion() { + let quarter: u16 = 0b01011010111011011; // 0xB5DB + let expected_expansion = 0b0001000000010001000000010000000100010001000000010001000000010001; // 0x01011010111011011 + assert_eq!(expected_expansion, Keccak::expand(quarter as u64)); + } + + #[test] + // Tests that expansion of decomposed quarters works as expected + fn test_sparse() { + let word: u64 = 0x1234567890abcdef; + let sparse = Keccak::sparse(word); + let expected_sparse: Vec = vec![ + 0x1100110111101111, // 0xcdef + 0x1001000010101011, // 0x90ab + 0x0101011001111000, // 0x5678 + 0x0001001000110100, // 0x1234 + ]; + for i in 0..QUARTERS { + assert_eq!(sparse[i], expected_sparse[i]); + } + } + + #[test] + // Tests that the shifts are computed correctly + fn test_shifts() { + let seed: [u8; 32] = thread_rng().gen(); + eprintln!("Seed: {:?}", seed); + let mut rng = StdRng::from_seed(seed); + let word: u64 = rng.gen_range(0..2u128.pow(64)) as u64; + let sparse = Keccak::sparse(word); + let shifts = Keccak::shift(&sparse); + for i in 0..QUARTERS { + assert_eq!( + sparse[i], + shifts[i] + shifts[4 + i] * 2 + shifts[8 + i] * 4 + shifts[12 + i] * 8 + ) + } + } + + #[test] + // Checks that reset function returns shift0, as the first positions of the shifts vector + fn test_reset() { + let seed: [u8; 32] = thread_rng().gen(); + eprintln!("Seed: {:?}", seed); + let mut rng = StdRng::from_seed(seed); + let word: u64 = rng.gen_range(0..2u128.pow(64)) as u64; + let shifts = Keccak::shift(&Keccak::sparse(word)); + let reset = Keccak::reset(&shifts); + assert_eq!(reset.len(), 4); + assert_eq!(shifts.len(), 16); + for i in 0..QUARTERS { + assert_eq!(reset[i], shifts[i]) + } + } + + #[test] + // Checks that one can obtain the original word from the resets of the expanded word + fn test_collapse() { + let seed: [u8; 32] = thread_rng().gen(); + eprintln!("Seed: {:?}", seed); + let mut rng = StdRng::from_seed(seed); + let word: u64 = rng.gen_range(0..2u128.pow(64)) as u64; + let dense = Keccak::compose(&Keccak::collapse(&Keccak::reset(&Keccak::shift( + &Keccak::sparse(word), + )))); + assert_eq!(word, dense); + } + + #[test] + // Checks that concatenating the maximum number of carries (15 per bit) result + // in the same original dense word, and just one more carry results in a different word + fn test_max_carries() { + let seed: [u8; 32] = thread_rng().gen(); + eprintln!("Seed: {:?}", seed); + let mut rng = StdRng::from_seed(seed); + let word: u64 = rng.gen_range(0..2u128.pow(64)) as u64; + let carries = 0xEEEE; + // add a few carry bits to the canonical representation + let mut sparse = Keccak::sparse(word) + .iter() + .map(|x| *x + carries) + .collect::>(); + let dense = Keccak::compose(&Keccak::collapse(&Keccak::reset(&Keccak::shift(&sparse)))); + assert_eq!(word, dense); + + sparse[0] += 1; + let wrong_dense = + Keccak::compose(&Keccak::collapse(&Keccak::reset(&Keccak::shift(&sparse)))); + assert_ne!(word, wrong_dense); + } + + #[test] + // Tests that the XOR can be represented in the 4i-th + // positions of the addition of sparse representations + fn test_sparse_xor() { + let seed: [u8; 32] = thread_rng().gen(); + eprintln!("Seed: {:?}", seed); + let 
mut rng = StdRng::from_seed(seed); + let a: u64 = rng.gen_range(0..2u128.pow(64)) as u64; + let b: u64 = rng.gen_range(0..2u128.pow(64)) as u64; + let xor = a ^ b; + + let sparse_a = Keccak::sparse(a); + let sparse_b = Keccak::sparse(b); + + // compute xor as sum of a and b + let sparse_sum = sparse_a + .iter() + .zip(sparse_b.iter()) + .map(|(a, b)| a + b) + .collect::>(); + let reset_sum = Keccak::reset(&Keccak::shift(&sparse_sum)); + + assert_eq!(Keccak::sparse(xor), reset_sum); + } + + #[test] + // Tests that the AND can be represented in the (4i+1)-th positions of the + // addition of canonical sparse representations + fn test_sparse_and() { + let seed: [u8; 32] = thread_rng().gen(); + eprintln!("Seed: {:?}", seed); + let mut rng = StdRng::from_seed(seed); + let a: u64 = rng.gen_range(0..2u128.pow(64)) as u64; + let b: u64 = rng.gen_range(0..2u128.pow(64)) as u64; + let and = a & b; + + let sparse_a = Keccak::sparse(a); + let sparse_b = Keccak::sparse(b); + + // compute and as carries of sum of a and b + let sparse_sum = sparse_a + .iter() + .zip(sparse_b.iter()) + .map(|(a, b)| a + b) + .collect::>(); + let carries_sum = &Keccak::shift(&sparse_sum)[4..8]; + + assert_eq!(Keccak::sparse(and), carries_sum); + } + + #[test] + // Tests that the NOT can be represented as subtraction with the expansion of + // the 16-bit dense quarter. + fn test_sparse_not() { + let seed: [u8; 32] = thread_rng().gen(); + eprintln!("Seed: {:?}", seed); + let mut rng = StdRng::from_seed(seed); + let word = rng.gen_range(0..2u64.pow(16)); + let expanded = Keccak::expand(word); + + // compute not as subtraction with expand all ones + let all_ones = 0xFFFF; + let not = all_ones - word; + let sparse_not = Keccak::expand(all_ones) - expanded; + + assert_eq!(not, Keccak::collapse(&[sparse_not])[0]); + } + + #[test] + // Checks that the padding length is correctly computed + fn test_pad_length() { + assert_eq!(Keccak::padded_length(0), RATE_IN_BYTES); + assert_eq!(Keccak::padded_length(1), RATE_IN_BYTES); + assert_eq!(Keccak::padded_length(RATE_IN_BYTES - 1), RATE_IN_BYTES); + // If input is already a multiple of RATE bytes, it needs to add a whole new block just for padding + assert_eq!(Keccak::padded_length(RATE_IN_BYTES), 2 * RATE_IN_BYTES); + assert_eq!( + Keccak::padded_length(RATE_IN_BYTES * 2 - 1), + 2 * RATE_IN_BYTES + ); + assert_eq!(Keccak::padded_length(RATE_IN_BYTES * 2), 3 * RATE_IN_BYTES); + } + + #[test] + // Checks that the padding is correctly computed + fn test_pad() { + let message = vec![0xFF; RATE_IN_BYTES - 1]; + let padded = Keccak::pad(&message); + assert_eq!(padded.len(), RATE_IN_BYTES); + assert_eq!(padded[padded.len() - 1], 0x81); + + let message = vec![0x01; RATE_IN_BYTES]; + let padded = Keccak::pad(&message); + assert_eq!(padded.len(), 2 * RATE_IN_BYTES); + assert_eq!(padded[message.len()], 0x01); + assert_eq!(padded[padded.len() - 1], 0x80); + } +} diff --git a/kimchi/src/circuits/polynomials/keccak/witness.rs b/kimchi/src/circuits/polynomials/keccak/witness.rs new file mode 100644 index 0000000000..632b7b4c76 --- /dev/null +++ b/kimchi/src/circuits/polynomials/keccak/witness.rs @@ -0,0 +1,657 @@ +//! 
Keccak witness computation + +use crate::{ + auto_clone, + circuits::{ + polynomials::keccak::{ + constants::{ + CAPACITY_IN_BYTES, DIM, KECCAK_COLS, QUARTERS, RATE_IN_BYTES, ROUNDS, STATE_LEN, + }, + Keccak, OFF, + }, + witness::{self, IndexCell, Variables, WitnessCell}, + }, + grid, variable_map, +}; +use ark_ff::PrimeField; +use num_bigint::BigUint; +use std::array; + +pub(crate) const SPARSE_RC: [[u64; QUARTERS]; ROUNDS] = [ + [ + 0x0000000000000001, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x1000000010000010, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x1000000010001010, + 0x0000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000000000000, + 0x1000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000010001011, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x0000000000000001, + 0x1000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x1000000010000001, + 0x1000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000000001001, + 0x0000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x0000000010001010, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x0000000010001000, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x1000000000001001, + 0x1000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x0000000000001010, + 0x1000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x1000000010001011, + 0x1000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x0000000010001011, + 0x0000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000010001001, + 0x0000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000000000011, + 0x0000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000000000010, + 0x0000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x0000000010000000, + 0x0000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000000001010, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x0000000000001010, + 0x1000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000010000001, + 0x1000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x1000000010000000, + 0x0000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], + [ + 0x0000000000000001, + 0x1000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ], + [ + 0x1000000000001000, + 0x1000000000000000, + 0x0000000000000000, + 0x1000000000000000, + ], +]; + +type Layout = Vec, COLUMNS>>>; + +fn layout_round() -> [Layout; 1] { + [vec![ + IndexCell::create("state_a", 0, 100), + IndexCell::create("shifts_c", 100, 180), + IndexCell::create("dense_c", 180, 200), + IndexCell::create("quotient_c", 200, 205), + IndexCell::create("remainder_c", 205, 225), + IndexCell::create("dense_rot_c", 225, 245), + IndexCell::create("expand_rot_c", 245, 265), + IndexCell::create("shifts_e", 265, 665), + IndexCell::create("dense_e", 665, 765), + IndexCell::create("quotient_e", 765, 865), + IndexCell::create("remainder_e", 865, 965), + IndexCell::create("dense_rot_e", 965, 1065), + IndexCell::create("expand_rot_e", 1065, 1165), + IndexCell::create("shifts_b", 1165, 1565), + IndexCell::create("shifts_sum", 1565, 1965), + ]] +} + +fn layout_sponge() -> 
[Layout; 1] { + [vec![ + IndexCell::create("old_state", 0, 100), + IndexCell::create("new_state", 100, 200), + IndexCell::create("bytes", 200, 400), + IndexCell::create("shifts", 400, 800), + ]] +} + +// Transforms a vector of u64 into a vector of field elements +fn field(input: &[u64]) -> Vec { + input.iter().map(|x| F::from(*x)).collect::>() +} + +// Contains the quotient, remainder, bound, dense rotated as quarters of at most 16 bits each +// Contains the expansion of the rotated word +pub struct Rotation { + quotient: Vec, + remainder: Vec, + dense_rot: Vec, + expand_rot: Vec, +} + +impl Rotation { + // On input the dense quarters of a word, rotate the word offset bits to the left + fn new(dense: &[u64], offset: u32) -> Self { + let word = Keccak::compose(dense); + let rem = word as u128 * 2u128.pow(offset) % 2u128.pow(64); + let quo = (word as u128) / 2u128.pow(64 - offset); + let rot = rem + quo; + assert!(rot as u64 == word.rotate_left(offset)); + + Self { + quotient: Keccak::decompose(quo as u64), + remainder: Keccak::decompose(rem as u64), + dense_rot: Keccak::decompose(rot as u64), + expand_rot: Keccak::decompose(rot as u64) + .iter() + .map(|x| Keccak::expand(*x)) + .collect(), + } + } + + // On input the dense quarters of many words, rotate the word offset bits to the left + fn many(words: &[u64], offsets: &[u32]) -> Self { + assert!(words.len() == QUARTERS * offsets.len()); + let mut quotient = vec![]; + let mut remainder = vec![]; + let mut dense_rot = vec![]; + let mut expand_rot = vec![]; + for (word, offset) in words.chunks(QUARTERS).zip(offsets.iter()) { + let mut rot = Self::new(word, *offset); + quotient.append(&mut rot.quotient); + remainder.append(&mut rot.remainder); + dense_rot.append(&mut rot.dense_rot); + expand_rot.append(&mut rot.expand_rot); + } + Self { + quotient, + remainder, + dense_rot, + expand_rot, + } + } +} + +/// Values involved in Theta permutation step +pub struct Theta { + shifts_c: Vec, + dense_c: Vec, + quotient_c: Vec, + remainder_c: Vec, + dense_rot_c: Vec, + expand_rot_c: Vec, + state_e: Vec, +} + +impl Theta { + pub fn create(state_a: &[u64]) -> Self { + let state_c = Self::compute_state_c(state_a); + let shifts_c = Keccak::shift(&state_c); + let dense_c = Keccak::collapse(&Keccak::reset(&shifts_c)); + let rotation_c = Rotation::many(&dense_c, &[1; DIM]); + let state_d = Self::compute_state_d(&shifts_c, &rotation_c.expand_rot); + let state_e = Self::compute_state_e(state_a, &state_d); + let quotient_c = vec![ + rotation_c.quotient[0], + rotation_c.quotient[4], + rotation_c.quotient[8], + rotation_c.quotient[12], + rotation_c.quotient[16], + ]; + Self { + shifts_c, + dense_c, + quotient_c, + remainder_c: rotation_c.remainder, + dense_rot_c: rotation_c.dense_rot, + expand_rot_c: rotation_c.expand_rot, + state_e, + } + } + + pub fn shifts_c(&self, i: usize, x: usize, q: usize) -> u64 { + let shifts_c = grid!(80, &self.shifts_c); + shifts_c(i, x, q) + } + + pub fn dense_c(&self, x: usize, q: usize) -> u64 { + let dense_c = grid!(20, &self.dense_c); + dense_c(x, q) + } + + pub fn quotient_c(&self, x: usize) -> u64 { + self.quotient_c[x] + } + + pub fn remainder_c(&self, x: usize, q: usize) -> u64 { + let remainder_c = grid!(20, &self.remainder_c); + remainder_c(x, q) + } + + pub fn dense_rot_c(&self, x: usize, q: usize) -> u64 { + let dense_rot_c = grid!(20, &self.dense_rot_c); + dense_rot_c(x, q) + } + + pub fn expand_rot_c(&self, x: usize, q: usize) -> u64 { + let expand_rot_c = grid!(20, &self.expand_rot_c); + expand_rot_c(x, q) + } + + pub fn 
state_e(&self) -> Vec { + self.state_e.clone() + } + + fn compute_state_c(state_a: &[u64]) -> Vec { + let state_a = grid!(100, state_a); + let mut state_c = vec![]; + for x in 0..DIM { + for q in 0..QUARTERS { + state_c.push( + state_a(0, x, q) + + state_a(1, x, q) + + state_a(2, x, q) + + state_a(3, x, q) + + state_a(4, x, q), + ); + } + } + state_c + } + + fn compute_state_d(shifts_c: &[u64], expand_rot_c: &[u64]) -> Vec { + let shifts_c = grid!(20, shifts_c); + let expand_rot_c = grid!(20, expand_rot_c); + let mut state_d = vec![]; + for x in 0..DIM { + for q in 0..QUARTERS { + state_d.push(shifts_c((x + DIM - 1) % DIM, q) + expand_rot_c((x + 1) % DIM, q)); + } + } + state_d + } + + fn compute_state_e(state_a: &[u64], state_d: &[u64]) -> Vec { + let state_a = grid!(100, state_a); + let state_d = grid!(20, state_d); + let mut state_e = vec![]; + for y in 0..DIM { + for x in 0..DIM { + for q in 0..QUARTERS { + state_e.push(state_a(y, x, q) + state_d(x, q)); + } + } + } + state_e + } +} + +/// Values involved in PiRho permutation step +pub struct PiRho { + shifts_e: Vec, + dense_e: Vec, + quotient_e: Vec, + remainder_e: Vec, + dense_rot_e: Vec, + expand_rot_e: Vec, + state_b: Vec, +} + +impl PiRho { + pub fn create(state_e: &[u64]) -> Self { + let shifts_e = Keccak::shift(state_e); + let dense_e = Keccak::collapse(&Keccak::reset(&shifts_e)); + let rotation_e = Rotation::many( + &dense_e, + &OFF.iter() + .flatten() + .map(|x| *x as u32) + .collect::>(), + ); + + let mut state_b = vec![vec![vec![0; QUARTERS]; DIM]; DIM]; + let aux = grid!(100, rotation_e.expand_rot); + for y in 0..DIM { + for x in 0..DIM { + for q in 0..QUARTERS { + state_b[(2 * x + 3 * y) % DIM][y][q] = aux(y, x, q); + } + } + } + let state_b = state_b.iter().flatten().flatten().copied().collect(); + + Self { + shifts_e, + dense_e, + quotient_e: rotation_e.quotient, + remainder_e: rotation_e.remainder, + dense_rot_e: rotation_e.dense_rot, + expand_rot_e: rotation_e.expand_rot, + state_b, + } + } + + pub fn shifts_e(&self, i: usize, y: usize, x: usize, q: usize) -> u64 { + let shifts_e = grid!(400, &self.shifts_e); + shifts_e(i, y, x, q) + } + + pub fn dense_e(&self, y: usize, x: usize, q: usize) -> u64 { + let dense_e = grid!(100, &self.dense_e); + dense_e(y, x, q) + } + + pub fn quotient_e(&self, y: usize, x: usize, q: usize) -> u64 { + let quotient_e = grid!(100, &self.quotient_e); + quotient_e(y, x, q) + } + + pub fn remainder_e(&self, y: usize, x: usize, q: usize) -> u64 { + let remainder_e = grid!(100, &self.remainder_e); + remainder_e(y, x, q) + } + + pub fn dense_rot_e(&self, y: usize, x: usize, q: usize) -> u64 { + let dense_rot_e = grid!(100, &self.dense_rot_e); + dense_rot_e(y, x, q) + } + + pub fn expand_rot_e(&self, y: usize, x: usize, q: usize) -> u64 { + let expand_rot_e = grid!(100, &self.expand_rot_e); + expand_rot_e(y, x, q) + } + + pub fn state_b(&self) -> Vec { + self.state_b.clone() + } +} + +/// Values involved in Chi permutation step +pub struct Chi { + shifts_b: Vec, + shifts_sum: Vec, + state_f: Vec, +} + +impl Chi { + pub fn create(state_b: &[u64]) -> Self { + let shifts_b = Keccak::shift(state_b); + let shiftsb = grid!(400, shifts_b); + let mut sum = vec![]; + for y in 0..DIM { + for x in 0..DIM { + for q in 0..QUARTERS { + let not = 0x1111111111111111u64 - shiftsb(0, y, (x + 1) % DIM, q); + sum.push(not + shiftsb(0, y, (x + 2) % DIM, q)); + } + } + } + let shifts_sum = Keccak::shift(&sum); + let shiftsum = grid!(400, shifts_sum); + let mut state_f = vec![]; + for y in 0..DIM { + for x in 0..DIM { 
+ for q in 0..QUARTERS { + let and = shiftsum(1, y, x, q); + state_f.push(shiftsb(0, y, x, q) + and); + } + } + } + + Self { + shifts_b, + shifts_sum, + state_f, + } + } + + pub fn shifts_b(&self, i: usize, y: usize, x: usize, q: usize) -> u64 { + let shifts_b = grid!(400, &self.shifts_b); + shifts_b(i, y, x, q) + } + + pub fn shifts_sum(&self, i: usize, y: usize, x: usize, q: usize) -> u64 { + let shifts_sum = grid!(400, &self.shifts_sum); + shifts_sum(i, y, x, q) + } + + pub fn state_f(&self) -> Vec { + self.state_f.clone() + } +} + +/// Values involved in Iota permutation step +pub struct Iota { + state_g: Vec, + round_constants: [u64; QUARTERS], +} + +impl Iota { + pub fn create(state_f: &[u64], round: usize) -> Self { + let round_constants = SPARSE_RC[round]; + let mut state_g = state_f.to_vec(); + for (i, c) in round_constants.iter().enumerate() { + state_g[i] = state_f[i] + *c; + } + Self { + state_g, + round_constants, + } + } + + pub fn state_g(&self) -> Vec { + self.state_g.clone() + } + + pub fn round_constants(&self, i: usize) -> u64 { + self.round_constants[i] + } +} + +/// Creates a witness for the Keccak hash function +/// Input: +/// - message: the message to be hashed +/// Note: +/// Requires at least one more row after the keccak gadget so that +/// constraints can access the next row in the squeeze +pub fn extend_keccak_witness(witness: &mut [Vec; KECCAK_COLS], message: BigUint) { + let padded = Keccak::pad(&message.to_bytes_be()); + let chunks = padded.chunks(RATE_IN_BYTES); + + // The number of rows that need to be added to the witness correspond to + // - Absorb phase: + // - 1 per block for the sponge row + // - 24 for the rounds + // - Squeeze phase: + // - 1 for the final sponge row + let rows: usize = chunks.len() * (ROUNDS + 1) + 1; + + let mut keccak_witness = array::from_fn(|_| vec![F::zero(); rows]); + + // Absorb phase + let mut row = 0; + let mut state = vec![0; QUARTERS * DIM * DIM]; + for chunk in chunks { + let mut block = chunk.to_vec(); + // Pad the block until reaching 200 bytes + block.append(&mut vec![0; CAPACITY_IN_BYTES]); + let new_state = Keccak::expand_state(&block); + auto_clone!(new_state); + let shifts = Keccak::shift(&new_state()); + let bytes = block.iter().map(|b| *b as u64).collect::>(); + + // Initialize the absorb sponge row + witness::init( + &mut keccak_witness, + row, + &layout_sponge(), + &variable_map!["old_state" => field(&state), "new_state" => field(&new_state()), "bytes" => field(&bytes), "shifts" => field(&shifts)], + ); + row += 1; + + let xor_state = state + .iter() + .zip(new_state()) + .map(|(x, y)| x + y) + .collect::>(); + + let mut ini_state = xor_state.clone(); + + for round in 0..ROUNDS { + // Theta + let theta = Theta::create(&ini_state); + + // PiRho + let pirho = PiRho::create(&theta.state_e); + + // Chi + let chi = Chi::create(&pirho.state_b); + + // Iota + let iota = Iota::create(&chi.state_f, round); + + // Initialize the round row + witness::init( + &mut keccak_witness, + row, + &layout_round(), + &variable_map![ + "state_a" => field(&ini_state), + "shifts_c" => field(&theta.shifts_c), + "dense_c" => field(&theta.dense_c), + "quotient_c" => field(&theta.quotient_c), + "remainder_c" => field(&theta.remainder_c), + "dense_rot_c" => field(&theta.dense_rot_c), + "expand_rot_c" => field(&theta.expand_rot_c), + "shifts_e" => field(&pirho.shifts_e), + "dense_e" => field(&pirho.dense_e), + "quotient_e" => field(&pirho.quotient_e), + "remainder_e" => field(&pirho.remainder_e), + "dense_rot_e" => 
field(&pirho.dense_rot_e), + "expand_rot_e" => field(&pirho.expand_rot_e), + "shifts_b" => field(&chi.shifts_b), + "shifts_sum" => field(&chi.shifts_sum) + ], + ); + row += 1; + ini_state = iota.state_g; + } + // update state after rounds + state = ini_state; + } + + // Squeeze phase + + let new_state = vec![0; STATE_LEN]; + let shifts = Keccak::shift(&state); + let dense = Keccak::collapse(&Keccak::reset(&shifts)); + let bytes = Keccak::bytestring(&dense); + + // Initialize the squeeze sponge row + witness::init( + &mut keccak_witness, + row, + &layout_sponge(), + &variable_map!["old_state" => field(&state), "new_state" => field(&new_state), "bytes" => field(&bytes), "shifts" => field(&shifts)], + ); + + for col in 0..KECCAK_COLS { + witness[col].extend(keccak_witness[col].iter()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuits::polynomials::keccak::RC; + + #[test] + fn test_sparse_round_constants() { + for round in 0..ROUNDS { + let round_constants = Keccak::sparse(RC[round]); + for (i, rc) in round_constants.iter().enumerate().take(QUARTERS) { + assert_eq!(*rc, SPARSE_RC[round][i]); + } + } + } +} diff --git a/kimchi/src/circuits/polynomials/mod.rs b/kimchi/src/circuits/polynomials/mod.rs index bbbe57939d..fbfc712ad9 100644 --- a/kimchi/src/circuits/polynomials/mod.rs +++ b/kimchi/src/circuits/polynomials/mod.rs @@ -3,8 +3,10 @@ pub mod complete_add; pub mod endomul_scalar; pub mod endosclmul; pub mod foreign_field_add; +pub mod foreign_field_common; pub mod foreign_field_mul; pub mod generic; +pub mod keccak; pub mod not; pub mod permutation; pub mod poseidon; diff --git a/kimchi/src/circuits/polynomials/not.rs b/kimchi/src/circuits/polynomials/not.rs index 4db55e4d6d..9773afd24f 100644 --- a/kimchi/src/circuits/polynomials/not.rs +++ b/kimchi/src/circuits/polynomials/not.rs @@ -90,7 +90,7 @@ impl CircuitGate { // Extends a double generic gate for negation for one or two words // Input: - // - gates : full ciricuit to which the not will be added + // - gates : full circuit to which the not will be added // - new_row : row to start the double NOT generic gate // - pub_row : row where the public inputs is stored (the 2^bits - 1) value (in the column 0 of that row) // - double_generic : whether to perform two NOTs or only one inside the generic gate diff --git a/kimchi/src/circuits/polynomials/poseidon.rs b/kimchi/src/circuits/polynomials/poseidon.rs index 587346d953..7f4290a458 100644 --- a/kimchi/src/circuits/polynomials/poseidon.rs +++ b/kimchi/src/circuits/polynomials/poseidon.rs @@ -28,6 +28,7 @@ use crate::{ circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, expr::{constraints::ExprOps, Cache}, gate::{CircuitGate, CurrOrNext, GateType}, polynomial::COLUMNS, @@ -339,7 +340,10 @@ where const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::Poseidon); const CONSTRAINTS: u32 = 15; - fn constraint_checks>(env: &ArgumentEnv, cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + cache: &mut Cache, + ) -> Vec { let mut res = vec![]; let mut idx = 0; diff --git a/kimchi/src/circuits/polynomials/range_check/circuitgates.rs b/kimchi/src/circuits/polynomials/range_check/circuitgates.rs index 5abf525e4c..2a9b3104b1 100644 --- a/kimchi/src/circuits/polynomials/range_check/circuitgates.rs +++ b/kimchi/src/circuits/polynomials/range_check/circuitgates.rs @@ -118,6 +118,7 @@ use std::marker::PhantomData; use crate::circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + 
berkeley_columns::BerkeleyChallengeTerm, expr::{ constraints::{crumb, ExprOps}, Cache, @@ -178,7 +179,10 @@ where // * Operates on Curr row // * Range constrain all limbs except vp0 and vp1 (barring plookup constraints, which are done elsewhere) // * Constrain that combining all limbs equals the limb stored in column 0 - fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { // 1) Apply range constraints on the limbs // * Columns 1-2 are 12-bit copy constraints // * They are copied 3 rows ahead (to the final row) and are constrained by lookups @@ -279,7 +283,10 @@ where // * Operates on Curr and Next row // * Range constrain all limbs (barring plookup constraints, which are done elsewhere) // * Constrain that combining all limbs equals the value v2 stored in row Curr, column 0 - fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { // 1) Apply range constraints on limbs for Curr row // * Column 2 is a 2-bit crumb let mut constraints = vec![crumb(&env.witness_curr(2))]; @@ -330,7 +337,7 @@ where power_of_2 *= 4u64.into(); // 2 bits } - // Next row: Sum remaining 2-bit limbs vc10 and vc11 + // Next row: Sum remaining 2-bit limbs v2c9, v2c10, and v2c11 (reverse order) for i in (0..=2).rev() { sum_of_limbs += power_of_2.clone() * env.witness_next(i); power_of_2 *= 4u64.into(); // 2 bits diff --git a/kimchi/src/circuits/polynomials/range_check/gadget.rs b/kimchi/src/circuits/polynomials/range_check/gadget.rs index 5815488441..309f4de50e 100644 --- a/kimchi/src/circuits/polynomials/range_check/gadget.rs +++ b/kimchi/src/circuits/polynomials/range_check/gadget.rs @@ -6,7 +6,8 @@ use crate::{ alphas::Alphas, circuits::{ argument::Argument, - expr::{Cache, E}, + berkeley_columns::E, + expr::Cache, gate::{CircuitGate, Connect, GateType}, lookup::{ self, diff --git a/kimchi/src/circuits/polynomials/range_check/witness.rs b/kimchi/src/circuits/polynomials/range_check/witness.rs index e133c8b372..38ba5178a3 100644 --- a/kimchi/src/circuits/polynomials/range_check/witness.rs +++ b/kimchi/src/circuits/polynomials/range_check/witness.rs @@ -9,11 +9,11 @@ use std::array; use crate::{ circuits::{ polynomial::COLUMNS, + polynomials::foreign_field_common::{BigUintForeignFieldHelpers, LIMB_BITS}, witness::{init_row, CopyBitsCell, CopyCell, VariableCell, Variables, WitnessCell}, }, variable_map, variables, }; -use o1_utils::foreign_field::BigUintForeignFieldHelpers; /// Witness layout /// * The values and cell contents are in little-endian order. @@ -26,14 +26,14 @@ use o1_utils::foreign_field::BigUintForeignFieldHelpers; /// For example, we can convert the `RangeCheck0` circuit gate into /// a 64-bit lookup by adding two copy constraints to constrain /// columns 1 and 2 to zero. 
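Before the layout change below, it helps to replay the arithmetic it encodes: a `RangeCheck0` row consumes an 88-bit limb as two 12-bit copy limbs, four 12-bit plookup limbs, and eight 2-bit crumbs (2·12 + 4·12 + 8·2 = 88). A minimal standalone sketch, assuming most-significant-piece-first ordering purely for illustration (the gadget's actual column ordering is a layout detail):

```rust
// Split an 88-bit value the way a RangeCheck0 row budgets its bits:
// 2 copy limbs of 12 bits, 4 plookup limbs of 12 bits, 8 crumbs of 2 bits.
fn split_88_bit(v: u128) -> (Vec<u64>, Vec<u64>) {
    assert!(v < 1u128 << 88);
    let mut bits_left = 88u32;
    let mut take = |n: u32| -> u64 {
        bits_left -= n;
        ((v >> bits_left) & ((1u128 << n) - 1)) as u64
    };
    let limbs: Vec<u64> = (0..6).map(|_| take(12)).collect(); // 2 copies + 4 plookups
    let crumbs: Vec<u64> = (0..8).map(|_| take(2)).collect(); // 2-bit crumbs
    (limbs, crumbs)
}

fn main() {
    let v = (1u128 << 88) - 12345;
    let (limbs, crumbs) = split_88_bit(v);
    // Recombining every piece with matching powers of two recovers the limb,
    // which is what the "combined limbs equal column 0" constraint enforces.
    let mut acc = 0u128;
    for l in &limbs {
        acc = (acc << 12) | *l as u128;
    }
    for c in &crumbs {
        acc = (acc << 2) | *c as u128;
    }
    assert_eq!(acc, v);
}
```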
-fn layout() -> Vec<[Box>; COLUMNS]> { - vec![ +fn layout() -> [Vec>>; 4] { + [ /* row 1, RangeCheck0 row */ range_check_0_row("v0", 0), /* row 2, RangeCheck0 row */ range_check_0_row("v1", 1), /* row 3, RangeCheck1 row */ - [ + vec![ VariableCell::create("v2"), VariableCell::create("v12"), // optional /* 2-bit crumbs (placed here to keep lookup pattern */ @@ -55,7 +55,7 @@ fn layout() -> Vec<[Box>; COLUMNS]> { CopyBitsCell::create(2, 0, 22, 24), ], /* row 4, Zero row */ - [ + vec![ CopyBitsCell::create(2, 0, 20, 22), /* 2-bit crumbs (placed here to keep lookup pattern */ /* the same as RangeCheck0) */ @@ -83,8 +83,8 @@ fn layout() -> Vec<[Box>; COLUMNS]> { pub fn range_check_0_row( limb_name: &'static str, row: usize, -) -> [Box>; COLUMNS] { - [ +) -> Vec>> { + vec![ VariableCell::create(limb_name), /* 12-bit copies */ // Copy cells are required because we have a limit @@ -211,7 +211,7 @@ pub fn extend_multi_compact_limbs(witness: &mut [Vec; COLUMNS] /// Extend an existing witness with a multi-range-check gadget for ForeignElement pub fn extend_multi_from_fe( witness: &mut [Vec; COLUMNS], - fe: &ForeignElement, + fe: &ForeignElement, ) { extend_multi(witness, fe.limbs[0], fe.limbs[1], fe.limbs[2]); } diff --git a/kimchi/src/circuits/polynomials/rot.rs b/kimchi/src/circuits/polynomials/rot.rs index 4ed7089997..b07f2e00fa 100644 --- a/kimchi/src/circuits/polynomials/rot.rs +++ b/kimchi/src/circuits/polynomials/rot.rs @@ -4,6 +4,7 @@ use super::range_check::witness::range_check_0_row; use crate::{ circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, expr::{ constraints::{crumb, ExprOps}, Cache, @@ -31,11 +32,14 @@ pub enum RotMode { impl CircuitGate { /// Creates a Rot64 gadget to rotate a word /// It will need: - /// - 1 Generic gate to constrain to zero the top 2 limbs of the shifted witness of the rotation + /// - 1 Generic gate to constrain to zero the top 2 limbs of the shifted and excess witness of the rotation /// /// It has: /// - 1 Rot64 gate to rotate the word /// - 1 RangeCheck0 to constrain the size of the shifted witness of the rotation + /// - 1 RangeCheck0 to constrain the size of the excess witness of the rotation + /// Assumes: + /// - the witness word is 64-bits, otherwise, will need to append a new RangeCheck0 for the word pub fn create_rot64(new_row: usize, rot: u32) -> Vec { vec![ CircuitGate { @@ -48,6 +52,11 @@ impl CircuitGate { wires: Wire::for_row(new_row + 1), coeffs: vec![F::zero()], }, + CircuitGate { + typ: GateType::RangeCheck0, + wires: Wire::for_row(new_row + 2), + coeffs: vec![F::zero()], + }, ] } @@ -65,8 +74,12 @@ impl CircuitGate { pub fn extend_rot(gates: &mut Vec, rot: u32, side: RotMode, zero_row: usize) -> usize { let (new_row, mut rot_gates) = Self::create_rot(gates.len(), rot, side); gates.append(&mut rot_gates); - // Check that 2 most significant limbs of shifted are zero + // Check that 2 most significant limbs of shifted and excess are zero + gates.connect_64bit(zero_row, new_row - 2); gates.connect_64bit(zero_row, new_row - 1); + // Connect excess with the Rot64 gate + gates.connect_cell_pair((new_row - 3, 2), (new_row - 1, 0)); + gates.len() } @@ -140,25 +153,25 @@ pub fn lookup_table() -> LookupTable { //~ is almost empty, we can use it to perform the range check within the same gate. 
Then, using the following layout //~ and assuming that the gate has a coefficient storing the value $2^{rot}$, which is publicly known //~ -//~ | Gate | `Rot64` | `RangeCheck0` | -//~ | ------ | ------------------- | ---------------- | -//~ | Column | `Curr` | `Next` | -//~ | ------ | ------------------- | ---------------- | -//~ | 0 | copy `word` |`shifted` | -//~ | 1 | copy `rotated` | 0 | -//~ | 2 | `excess` | 0 | -//~ | 3 | `bound_limb0` | `shifted_limb0` | -//~ | 4 | `bound_limb1` | `shifted_limb1` | -//~ | 5 | `bound_limb2` | `shifted_limb2` | -//~ | 6 | `bound_limb3` | `shifted_limb3` | -//~ | 7 | `bound_crumb0` | `shifted_crumb0` | -//~ | 8 | `bound_crumb1` | `shifted_crumb1` | -//~ | 9 | `bound_crumb2` | `shifted_crumb2` | -//~ | 10 | `bound_crumb3` | `shifted_crumb3` | -//~ | 11 | `bound_crumb4` | `shifted_crumb4` | -//~ | 12 | `bound_crumb5` | `shifted_crumb5` | -//~ | 13 | `bound_crumb6` | `shifted_crumb6` | -//~ | 14 | `bound_crumb7` | `shifted_crumb7` | +//~ | Gate | `Rot64` | `RangeCheck0` gadgets (designer's duty) | +//~ | ------ | ------------------- | --------------------------------------------------------- | +//~ | Column | `Curr` | `Next` | `Next` + 1 | `Next`+ 2, if needed | +//~ | ------ | ------------------- | ---------------- | --------------- | -------------------- | +//~ | 0 | copy `word` |`shifted` | copy `excess` | copy `word` | +//~ | 1 | copy `rotated` | 0 | 0 | 0 | +//~ | 2 | `excess` | 0 | 0 | 0 | +//~ | 3 | `bound_limb0` | `shifted_limb0` | `excess_limb0` | `word_limb0` | +//~ | 4 | `bound_limb1` | `shifted_limb1` | `excess_limb1` | `word_limb1` | +//~ | 5 | `bound_limb2` | `shifted_limb2` | `excess_limb2` | `word_limb2` | +//~ | 6 | `bound_limb3` | `shifted_limb3` | `excess_limb3` | `word_limb3` | +//~ | 7 | `bound_crumb0` | `shifted_crumb0` | `excess_crumb0` | `word_crumb0` | +//~ | 8 | `bound_crumb1` | `shifted_crumb1` | `excess_crumb1` | `word_crumb1` | +//~ | 9 | `bound_crumb2` | `shifted_crumb2` | `excess_crumb2` | `word_crumb2` | +//~ | 10 | `bound_crumb3` | `shifted_crumb3` | `excess_crumb3` | `word_crumb3` | +//~ | 11 | `bound_crumb4` | `shifted_crumb4` | `excess_crumb4` | `word_crumb4` | +//~ | 12 | `bound_crumb5` | `shifted_crumb5` | `excess_crumb5` | `word_crumb5` | +//~ | 13 | `bound_crumb6` | `shifted_crumb6` | `excess_crumb6` | `word_crumb6` | +//~ | 14 | `bound_crumb7` | `shifted_crumb7` | `excess_crumb7` | `word_crumb7` | //~ //~ In Keccak, rotations are performed over a 5x5 matrix state of w-bit words each cell. The values used //~ to perform the rotation are fixed, public, and known in advance, according to the following table, @@ -201,7 +214,10 @@ where // (stored in coefficient as a power-of-two form) // * Operates on Curr row // * Shifts the words by `rot` bits and then adds the excess to obtain the rotated word. 
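The relation these columns encode can be replayed with plain integers. A minimal sketch (not the gate's actual code) of how `shifted`, `excess`, and the rotated word relate, for 1 ≤ rot ≤ 63:

```rust
// word * 2^rot = excess * 2^64 + shifted, where excess < 2^rot and the low
// `rot` bits of `shifted` are zero; the rotated word is shifted + excess.
fn rot64_parts(word: u64, rot: u32) -> (u64, u64, u64) {
    assert!(rot > 0 && rot < 64);
    let shifted = ((word as u128) * (1u128 << rot) % (1u128 << 64)) as u64;
    let excess = word >> (64 - rot);
    let rotated = shifted + excess; // no overflow: low `rot` bits of shifted are zero
    assert_eq!(rotated, word.rotate_left(rot));
    // The gate also range-checks `excess` through the bound 2^64 - 2^rot,
    // the right input of the "FFAdd" bound equation mentioned in the text.
    (shifted, excess, rotated)
}
```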
- fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { // Check that the last 8 columns are 2-bit crumbs // C1..C8: x * (x - 1) * (x - 2) * (x - 3) = 0 let mut constraints = (7..COLUMNS) @@ -254,12 +270,16 @@ where // ROTATION WITNESS COMPUTATION -fn layout_rot64(curr_row: usize) -> [[Box>; COLUMNS]; 2] { - [rot_row(), range_check_0_row("shifted", curr_row + 1)] +fn layout_rot64(curr_row: usize) -> [Vec>>; 3] { + [ + rot_row(), + range_check_0_row("shifted", curr_row + 1), + range_check_0_row("excess", curr_row + 2), + ] } -fn rot_row() -> [Box>; COLUMNS] { - [ +fn rot_row() -> Vec>> { + vec![ VariableCell::create("word"), VariableCell::create("rotated"), VariableCell::create("excess"), @@ -312,8 +332,8 @@ pub fn extend_rot( rot: u32, side: RotMode, ) { - assert!(rot < 64, "Rotation value must be less than 64"); - assert_ne!(rot, 0, "Rotation value must be non-zero"); + assert!(rot <= 64, "Rotation value must be less or equal than 64"); + let rot = if side == RotMode::Right { 64 - rot } else { @@ -327,15 +347,15 @@ pub fn extend_rot( // shifted [------] * 2^rot // rot = [------|000] // + [---] excess - let shifted = (word as u128 * 2u128.pow(rot) % 2u128.pow(64)) as u64; - let excess = word / 2u64.pow(64 - rot); + let shifted = (word as u128) * 2u128.pow(rot) % 2u128.pow(64); + let excess = (word as u128) / 2u128.pow(64 - rot); let rotated = shifted + excess; // Value for the added value for the bound // Right input of the "FFAdd" for the bound equation let bound = 2u128.pow(64) - 2u128.pow(rot); let rot_row = witness[0].len(); - let rot_witness: [Vec; COLUMNS] = array::from_fn(|_| vec![F::zero(); 2]); + let rot_witness: [Vec; COLUMNS] = array::from_fn(|_| vec![F::zero(); 3]); for col in 0..COLUMNS { witness[col].extend(rot_witness[col].iter()); } diff --git a/kimchi/src/circuits/polynomials/turshi.rs b/kimchi/src/circuits/polynomials/turshi.rs index 6a581df924..d8f8551ca5 100644 --- a/kimchi/src/circuits/polynomials/turshi.rs +++ b/kimchi/src/circuits/polynomials/turshi.rs @@ -20,7 +20,7 @@ //! · ap += <`op`> //! · ap++ //! -//! A Cairo program runs accross a number of state transitions. +//! A Cairo program runs across a number of state transitions. //! Each state transition has the following structure: //! //! 
* Has access to a read-only memory @@ -82,8 +82,9 @@ use crate::{ alphas::Alphas, circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::{BerkeleyChallengeTerm, BerkeleyChallenges, Column, E}, constraints::ConstraintSystem, - expr::{self, constraints::ExprOps, Cache, Column, E}, + expr::{self, constraints::ExprOps, Cache}, gate::{CircuitGate, GateType}, wires::{GateWires, Wire, COLUMNS}, }, @@ -91,6 +92,7 @@ use crate::{ proof::ProofEvaluations, }; use ark_ff::{FftField, Field, PrimeField}; +use log::error; use std::{array, marker::PhantomData}; use turshi::{ runner::{CairoInstruction, CairoProgram, Pointers}, @@ -218,21 +220,23 @@ impl CircuitGate { // Setup circuit constants let constants = expr::Constants { - alpha: F::rand(rng), - beta: F::rand(rng), - gamma: F::rand(rng), - joint_combiner: None, endo_coefficient: cs.endo, mds: &G::sponge_params().mds, zk_rows: 3, }; + let challenges = BerkeleyChallenges { + alpha: F::rand(rng), + beta: F::rand(rng), + gamma: F::rand(rng), + joint_combiner: F::zero(), + }; let pt = F::rand(rng); // Evaluate constraints match linearized .constant_term - .evaluate_(cs.domain.d1, pt, &evals, &constants) + .evaluate_(cs.domain.d1, pt, &evals, &constants, &challenges) { Ok(x) => { if x == F::zero() { @@ -242,7 +246,7 @@ impl CircuitGate { } } Err(x) => { - println!("{x:?}"); + error!("{x:?}"); Err(format!("Failed to evaluate {:?} constraint", self.typ)) } } @@ -728,7 +732,7 @@ pub mod testing { // CONSTRAINTS-RELATED /// Returns the expression corresponding to the literal "2" -fn two>() -> T { +fn two>() -> T { T::literal(2u64.into()) // 2 } @@ -763,7 +767,10 @@ where /// Generates the constraints for the Cairo initial claim and first memory checks /// Accesses Curr and Next rows - fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { let pc_ini = env.witness_curr(0); // copy from public input let ap_ini = env.witness_curr(1); // copy from public input let pc_fin = env.witness_curr(2); // copy from public input @@ -800,7 +807,10 @@ where /// Generates the constraints for the Cairo instruction /// Accesses Curr and Next rows - fn constraint_checks>(env: &ArgumentEnv, cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + cache: &mut Cache, + ) -> Vec { // load all variables of the witness corresponding to Cairoinstruction gates let pc = env.witness_curr(0); let ap = env.witness_curr(1); @@ -946,7 +956,10 @@ where /// Generates the constraints for the Cairo flags /// Accesses Curr and Next rows - fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { // Load current row let f_pc_abs = env.witness_curr(7); let f_pc_rel = env.witness_curr(8); @@ -1013,7 +1026,10 @@ where /// Generates the constraints for the Cairo transition /// Accesses Curr and Next rows (Next only first 3 entries) - fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { // load computed updated registers let pcup = env.witness_curr(7); let apup = env.witness_curr(8); diff --git a/kimchi/src/circuits/polynomials/varbasemul.rs b/kimchi/src/circuits/polynomials/varbasemul.rs index 22b9522195..c3face3e84 100644 --- a/kimchi/src/circuits/polynomials/varbasemul.rs +++ b/kimchi/src/circuits/polynomials/varbasemul.rs @@ -12,7 +12,8 @@ use crate::circuits::{ 
argument::{Argument, ArgumentEnv, ArgumentType}, - expr::{constraints::ExprOps, Cache, Column, Variable}, + berkeley_columns::{BerkeleyChallengeTerm, Column}, + expr::{constraints::ExprOps, Cache, Variable as VariableGen}, gate::{CircuitGate, CurrOrNext, GateType}, wires::{GateWires, COLUMNS}, }; @@ -20,6 +21,8 @@ use ark_ff::{FftField, PrimeField}; use std::marker::PhantomData; use CurrOrNext::{Curr, Next}; +type Variable = VariableGen; + //~ We implement custom Plonk constraints for short Weierstrass curve variable base scalar multiplication. //~ //~ Given a finite field $\mathbb{F}_q$ of order $q$, if the order is not a multiple of 2 nor 3, then an @@ -47,12 +50,11 @@ use CurrOrNext::{Curr, Next}; //~ ```ignore //~ Acc := [2]T //~ for i = n-1 ... 0: -//~ Q := (r_i == 1) ? T : -T +//~ Q := (k_{i + 1} == 1) ? T : -T //~ Acc := Acc + (Q + Acc) -//~ return (d_0 == 0) ? Q - P : Q +//~ return (k_0 == 0) ? Acc - P : Acc //~ ``` //~ -//~ //~ The layout of the witness requires 2 rows. //~ The i-th row will be a `VBSM` gate whereas the next row will be a `ZERO` gate. //~ @@ -169,7 +171,10 @@ impl Point { } impl Point { - pub fn new_from_env>(&self, env: &ArgumentEnv) -> Point { + pub fn new_from_env>( + &self, + env: &ArgumentEnv, + ) -> Point { Point::create(self.x.new_from_env(env), self.y.new_from_env(env)) } } @@ -219,7 +224,7 @@ fn single_bit_witness( (out_x, out_y) } -fn single_bit>( +fn single_bit>( cache: &mut Cache, b: &T, base: Point, @@ -286,7 +291,7 @@ where impl FromWitness for Variable where F: PrimeField, - T: ExprOps, + T: ExprOps, { fn new_from_env(&self, env: &ArgumentEnv) -> T { let column_to_index = |_| match self.col { @@ -324,7 +329,10 @@ impl Layout { } } - fn new_from_env>(&self, env: &ArgumentEnv) -> Layout { + fn new_from_env>( + &self, + env: &ArgumentEnv, + ) -> Layout { Layout { accs: self.accs.map(|point| point.new_from_env(env)), bits: self.bits.map(|var| var.new_from_env(env)), @@ -407,7 +415,10 @@ where const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::VarBaseMul); const CONSTRAINTS: u32 = 21; - fn constraint_checks>(env: &ArgumentEnv, cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + cache: &mut Cache, + ) -> Vec { let Layout { base, accs, diff --git a/kimchi/src/circuits/polynomials/xor.rs b/kimchi/src/circuits/polynomials/xor.rs index 564b7d25c8..85a7270128 100644 --- a/kimchi/src/circuits/polynomials/xor.rs +++ b/kimchi/src/circuits/polynomials/xor.rs @@ -4,6 +4,7 @@ use crate::{ circuits::{ argument::{Argument, ArgumentEnv, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, expr::{constraints::ExprOps, Cache}, gate::{CircuitGate, Connect, GateType}, lookup::{ @@ -150,7 +151,10 @@ where // * Operates on Curr and Next rows // * Constrain the decomposition of `in1`, `in2` and `out` of multiples of 16 bits // * The actual XOR is performed thanks to the plookups of 4-bit XORs. 
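A minimal standalone sketch of the decomposition each `Xor16` row constrains; the 4-bit nybbles are exactly the values sent to the 4-bit XOR lookup table, and the 16-bit remainder is handed to the next row:

```rust
// in = n0 + n1*2^4 + n2*2^8 + n3*2^12 + next*2^16, checked per row.
fn xor16_row(word: u64) -> ([u64; 4], u64) {
    let nybbles = [
        word & 0xF,
        (word >> 4) & 0xF,
        (word >> 8) & 0xF,
        (word >> 12) & 0xF,
    ];
    let next = word >> 16; // remainder constrained by the next Xor16/Zero row
    assert_eq!(
        word,
        nybbles[0] + (nybbles[1] << 4) + (nybbles[2] << 8) + (nybbles[3] << 12) + (next << 16)
    );
    (nybbles, next)
}

fn main() {
    let (in1, in2) = (0xABCD_1234_u64, 0x0F0F_FFFF_u64);
    let (n1, r1) = xor16_row(in1);
    let (n2, r2) = xor16_row(in2);
    let (nout, rout) = xor16_row(in1 ^ in2);
    for i in 0..4 {
        assert_eq!(nout[i], n1[i] ^ n2[i]); // the 4-bit XOR looked up per nybble
    }
    assert_eq!(rout, r1 ^ r2); // recursion continues on the remaining bits
}
```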
- fn constraint_checks>(env: &ArgumentEnv, _cache: &mut Cache) -> Vec { + fn constraint_checks>( + env: &ArgumentEnv, + _cache: &mut Cache, + ) -> Vec { let two = T::from(2u64); // in1 = in1_0 + in1_1 * 2^4 + in1_2 * 2^8 + in1_3 * 2^12 + next_in1 * 2^16 // in2 = in2_0 + in2_1 * 2^4 + in2_2 * 2^8 + in2_3 * 2^12 + next_in2 * 2^16 @@ -169,7 +173,7 @@ where } // Witness layout -fn layout(curr_row: usize, bits: usize) -> Vec<[Box>; COLUMNS]> { +fn layout(curr_row: usize, bits: usize) -> Vec>>> { let num_xor = num_xors(bits); let mut layout = (0..num_xor) .map(|i| xor_row(i, curr_row + i)) @@ -178,9 +182,9 @@ fn layout(curr_row: usize, bits: usize) -> Vec<[Box(nybble: usize, curr_row: usize) -> [Box>; COLUMNS] { +fn xor_row(nybble: usize, curr_row: usize) -> Vec>> { let start = nybble * 16; - [ + vec![ VariableBitsCell::create("in1", start, None), VariableBitsCell::create("in2", start, None), VariableBitsCell::create("out", start, None), @@ -199,8 +203,8 @@ fn xor_row(nybble: usize, curr_row: usize) -> [Box() -> [Box>; COLUMNS] { - [ +fn zero_row() -> Vec>> { + vec![ ConstantCell::create(F::zero()), ConstantCell::create(F::zero()), ConstantCell::create(F::zero()), diff --git a/kimchi/src/circuits/witness/constant_cell.rs b/kimchi/src/circuits/witness/constant_cell.rs index 15ac14e61f..cbc7aaa7c8 100644 --- a/kimchi/src/circuits/witness/constant_cell.rs +++ b/kimchi/src/circuits/witness/constant_cell.rs @@ -1,5 +1,4 @@ use super::{variables::Variables, WitnessCell}; -use crate::circuits::polynomial::COLUMNS; use ark_ff::Field; /// Witness cell with constant value @@ -14,8 +13,8 @@ impl ConstantCell { } } -impl WitnessCell for ConstantCell { - fn value(&self, _witness: &mut [Vec; COLUMNS], _variables: &Variables) -> F { +impl WitnessCell for ConstantCell { + fn value(&self, _witness: &mut [Vec; W], _variables: &Variables, _index: usize) -> F { self.value } } diff --git a/kimchi/src/circuits/witness/copy_bits_cell.rs b/kimchi/src/circuits/witness/copy_bits_cell.rs index 5cb6064f02..a61ca183a8 100644 --- a/kimchi/src/circuits/witness/copy_bits_cell.rs +++ b/kimchi/src/circuits/witness/copy_bits_cell.rs @@ -2,7 +2,6 @@ use ark_ff::Field; use o1_utils::FieldHelpers; use super::{variables::Variables, WitnessCell}; -use crate::circuits::polynomial::COLUMNS; /// Witness cell copied from bits of another witness cell pub struct CopyBitsCell { @@ -24,8 +23,8 @@ impl CopyBitsCell { } } -impl WitnessCell for CopyBitsCell { - fn value(&self, witness: &mut [Vec; COLUMNS], _variables: &Variables) -> F { +impl WitnessCell for CopyBitsCell { + fn value(&self, witness: &mut [Vec; W], _variables: &Variables, _index: usize) -> F { F::from_bits(&witness[self.col][self.row].to_bits()[self.start..self.end]) .expect("failed to deserialize field bits for copy bits cell") } diff --git a/kimchi/src/circuits/witness/copy_cell.rs b/kimchi/src/circuits/witness/copy_cell.rs index d5ea353b56..d3fad71654 100644 --- a/kimchi/src/circuits/witness/copy_cell.rs +++ b/kimchi/src/circuits/witness/copy_cell.rs @@ -1,7 +1,6 @@ use ark_ff::Field; use super::{variables::Variables, WitnessCell}; -use crate::circuits::polynomial::COLUMNS; /// Witness cell copied from another witness cell pub struct CopyCell { @@ -16,8 +15,8 @@ impl CopyCell { } } -impl WitnessCell for CopyCell { - fn value(&self, witness: &mut [Vec; COLUMNS], _variables: &Variables) -> F { +impl WitnessCell for CopyCell { + fn value(&self, witness: &mut [Vec; W], _variables: &Variables, _index: usize) -> F { witness[self.col][self.row] } } diff --git 
a/kimchi/src/circuits/witness/copy_shift_cell.rs b/kimchi/src/circuits/witness/copy_shift_cell.rs index 822833b521..b0c87587d1 100644 --- a/kimchi/src/circuits/witness/copy_shift_cell.rs +++ b/kimchi/src/circuits/witness/copy_shift_cell.rs @@ -1,5 +1,4 @@ use super::{variables::Variables, WitnessCell}; -use crate::circuits::polynomial::COLUMNS; use ark_ff::Field; /// Witness cell copied from another cell and shifted @@ -16,8 +15,8 @@ impl CopyShiftCell { } } -impl WitnessCell for CopyShiftCell { - fn value(&self, witness: &mut [Vec; COLUMNS], _variables: &Variables) -> F { +impl WitnessCell for CopyShiftCell { + fn value(&self, witness: &mut [Vec; W], _variables: &Variables, _index: usize) -> F { F::from(2u32).pow([self.shift]) * witness[self.col][self.row] } } diff --git a/kimchi/src/circuits/witness/index_cell.rs b/kimchi/src/circuits/witness/index_cell.rs new file mode 100644 index 0000000000..fd2628a316 --- /dev/null +++ b/kimchi/src/circuits/witness/index_cell.rs @@ -0,0 +1,29 @@ +use super::{variables::Variables, WitnessCell}; +use ark_ff::Field; + +/// Witness cell assigned from an indexable variable +/// See [Variables] for more details +pub struct IndexCell<'a> { + name: &'a str, + length: usize, +} + +impl<'a> IndexCell<'a> { + /// Create witness cell assigned from a variable name a length + pub fn create(name: &'a str, from: usize, to: usize) -> Box> { + Box::new(IndexCell { + name, + length: to - from, + }) + } +} + +impl<'a, F: Field, const W: usize> WitnessCell, W> for IndexCell<'a> { + fn value(&self, _witness: &mut [Vec; W], variables: &Variables>, index: usize) -> F { + assert!(index < self.length, "index out of bounds of `IndexCell`"); + variables[self.name][index] + } + fn length(&self) -> usize { + self.length + } +} diff --git a/kimchi/src/circuits/witness/mod.rs b/kimchi/src/circuits/witness/mod.rs index 75271215f4..53f278045c 100644 --- a/kimchi/src/circuits/witness/mod.rs +++ b/kimchi/src/circuits/witness/mod.rs @@ -1,8 +1,10 @@ use ark_ff::{Field, PrimeField}; + mod constant_cell; mod copy_bits_cell; mod copy_cell; mod copy_shift_cell; +mod index_cell; mod variable_bits_cell; mod variable_cell; mod variables; @@ -12,6 +14,7 @@ pub use self::{ copy_bits_cell::CopyBitsCell, copy_cell::CopyCell, copy_shift_cell::CopyShiftCell, + index_cell::IndexCell, variable_bits_cell::VariableBitsCell, variable_cell::VariableCell, variables::{variable_map, variables, Variables}, @@ -19,42 +22,64 @@ pub use self::{ use super::polynomial::COLUMNS; -/// Witness cell interface -pub trait WitnessCell { - fn value(&self, witness: &mut [Vec; COLUMNS], variables: &Variables) -> F; +/// Witness cell interface. By default, the witness cell is a single element of type F. 
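To see why the generalized `init_row` below advances its column counter by each cell's length, here is a stripped-down model of an `IndexCell`-style layout (hypothetical `Cell` type and plain `u64` values instead of field elements, for illustration only):

```rust
use std::collections::HashMap;

struct Cell {
    name: &'static str,
    len: usize, // an IndexCell of length n fills n consecutive columns
}

// Fill one witness row: each cell pulls `len` consecutive entries of the
// named vector variable, so the column index moves by the cell's length.
fn init_row(row: &mut [u64], layout: &[Cell], vars: &HashMap<&str, Vec<u64>>) {
    let mut col = 0;
    for cell in layout {
        for index in 0..cell.len {
            row[col] = vars[cell.name][index];
            col += 1;
        }
    }
}

fn main() {
    let layout = [
        Cell { name: "old_state", len: 2 },
        Cell { name: "bytes", len: 3 },
    ];
    let mut row = [0u64; 5];
    let mut vars = HashMap::new();
    vars.insert("old_state", vec![10, 11]);
    vars.insert("bytes", vec![1, 2, 3]);
    init_row(&mut row, &layout, &vars);
    assert_eq!(row, [10, 11, 1, 2, 3]);
}
```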
+pub trait WitnessCell { + fn value(&self, witness: &mut [Vec; W], variables: &Variables, index: usize) -> F; + + // Length is 1 by default (T is single F element) unless overridden + fn length(&self) -> usize { + 1 + } } /// Initialize a witness cell based on layout and computed variables -pub fn init_cell( - witness: &mut [Vec; COLUMNS], +/// Inputs: +/// - witness: the witness to initialize with values +/// - offset: the row offset of the witness before initialization +/// - row: the row index inside the partial layout +/// - col: the column index inside the witness +/// - cell: the cell index inside the partial layout (for any but IndexCell, it must be the same as col) +/// - index: the index within the variable (for IndexCell, 0 otherwise) +/// - layout: the partial layout to initialize from +/// - variables: the hashmap of variables to get the values from +#[allow(clippy::too_many_arguments)] +pub fn init_cell( + witness: &mut [Vec; W], offset: usize, row: usize, col: usize, - layout: &[[Box>; COLUMNS]], - variables: &Variables, + cell: usize, + index: usize, + layout: &[Vec>>], + variables: &Variables, ) { - witness[col][row + offset] = layout[row][col].value(witness, variables); + witness[col][row + offset] = layout[row][cell].value(witness, variables, index); } /// Initialize a witness row based on layout and computed variables -pub fn init_row( - witness: &mut [Vec; COLUMNS], +pub fn init_row( + witness: &mut [Vec; W], offset: usize, row: usize, - layout: &[[Box>; COLUMNS]], - variables: &Variables, + layout: &[Vec>>], + variables: &Variables, ) { - for col in 0..COLUMNS { - init_cell(witness, offset, row, col, layout, variables); + let mut col = 0; + for cell in 0..layout[row].len() { + // The loop will only run more than once if the cell is an IndexCell + for index in 0..layout[row][cell].length() { + init_cell(witness, offset, row, col, cell, index, layout, variables); + col += 1; + } } } /// Initialize a witness based on layout and computed variables -pub fn init( - witness: &mut [Vec; COLUMNS], +pub fn init( + witness: &mut [Vec; W], offset: usize, - layout: &[[Box>; COLUMNS]], - variables: &Variables, + layout: &[Vec>>], + variables: &Variables, ) { for row in 0..layout.len() { init_row(witness, offset, row, layout, variables); @@ -67,6 +92,7 @@ mod tests { use super::*; + use crate::circuits::polynomial::COLUMNS; use ark_ec::AffineRepr; use ark_ff::{Field, One, Zero}; use mina_curves::pasta::Pallas; @@ -74,7 +100,7 @@ mod tests { #[test] fn zero_layout() { - let layout: Vec<[Box>; COLUMNS]> = vec![[ + let layout: Vec>>> = vec![vec![ ConstantCell::create(PallasField::zero()), ConstantCell::create(PallasField::zero()), ConstantCell::create(PallasField::zero()), @@ -102,7 +128,7 @@ mod tests { } // Set a single cell to zero - init_cell(&mut witness, 0, 0, 4, &layout, &variables!()); + init_cell(&mut witness, 0, 0, 4, 4, 0, &layout, &variables!()); assert_eq!(witness[4][0], PallasField::zero()); // Set all the cells to zero @@ -117,8 +143,8 @@ mod tests { #[test] fn mixed_layout() { - let layout: Vec<[Box>; COLUMNS]> = vec![ - [ + let layout: Vec>>> = vec![ + vec![ ConstantCell::create(PallasField::from(12u32)), ConstantCell::create(PallasField::from(0xa5a3u32)), ConstantCell::create(PallasField::from(0x800u32)), @@ -135,7 +161,7 @@ mod tests { ConstantCell::create(PallasField::zero()), ConstantCell::create(PallasField::zero()), ], - [ + vec![ CopyCell::create(0, 0), CopyBitsCell::create(0, 1, 4, 8), CopyShiftCell::create(0, 2, 8), diff --git 
a/kimchi/src/circuits/witness/variable_bits_cell.rs b/kimchi/src/circuits/witness/variable_bits_cell.rs index 1e1cfe31f1..584920592d 100644 --- a/kimchi/src/circuits/witness/variable_bits_cell.rs +++ b/kimchi/src/circuits/witness/variable_bits_cell.rs @@ -1,5 +1,4 @@ use super::{variables::Variables, WitnessCell}; -use crate::circuits::polynomial::COLUMNS; use ark_ff::Field; use o1_utils::FieldHelpers; @@ -19,8 +18,8 @@ impl<'a> VariableBitsCell<'a> { } } -impl<'a, F: Field> WitnessCell for VariableBitsCell<'a> { - fn value(&self, _witness: &mut [Vec; COLUMNS], variables: &Variables) -> F { +impl<'a, F: Field, const W: usize> WitnessCell for VariableBitsCell<'a> { + fn value(&self, _witness: &mut [Vec; W], variables: &Variables, _index: usize) -> F { let bits = if let Some(end) = self.end { F::from_bits(&variables[self.name].to_bits()[self.start..end]) } else { diff --git a/kimchi/src/circuits/witness/variable_cell.rs b/kimchi/src/circuits/witness/variable_cell.rs index 372ff98bdf..fcb6ee9f21 100644 --- a/kimchi/src/circuits/witness/variable_cell.rs +++ b/kimchi/src/circuits/witness/variable_cell.rs @@ -1,5 +1,4 @@ use super::{variables::Variables, WitnessCell}; -use crate::circuits::polynomial::COLUMNS; use ark_ff::Field; /// Witness cell assigned from a variable @@ -15,8 +14,8 @@ impl<'a> VariableCell<'a> { } } -impl<'a, F: Field> WitnessCell for VariableCell<'a> { - fn value(&self, _witness: &mut [Vec; COLUMNS], variables: &Variables) -> F { +impl<'a, F: Field, const W: usize> WitnessCell for VariableCell<'a> { + fn value(&self, _witness: &mut [Vec; W], variables: &Variables, _index: usize) -> F { variables[self.name] } } diff --git a/kimchi/src/circuits/witness/variables.rs b/kimchi/src/circuits/witness/variables.rs index 773b79c0bd..2327d62315 100644 --- a/kimchi/src/circuits/witness/variables.rs +++ b/kimchi/src/circuits/witness/variables.rs @@ -1,4 +1,3 @@ -use ark_ff::Field; /// Layout variable handling /// /// First, you use "anchor" names for the variables when specifying @@ -50,28 +49,28 @@ use std::{ /// /// Map of witness values (used by VariableCells) /// name (String) -> value (F) -pub struct Variables<'a, F: Field>(HashMap<&'a str, F>); +pub struct Variables<'a, T>(HashMap<&'a str, T>); -impl<'a, F: Field> Variables<'a, F> { +impl<'a, T> Variables<'a, T> { /// Create a layout variable map - pub fn create() -> Variables<'a, F> { + pub fn create() -> Variables<'a, T> { Variables(HashMap::new()) } /// Insert a variable and corresponding value into the variable map - pub fn insert(&mut self, name: &'a str, value: F) { + pub fn insert(&mut self, name: &'a str, value: T) { self.0.insert(name, value); } } -impl<'a, F: Field> Index<&'a str> for Variables<'a, F> { - type Output = F; +impl<'a, T> Index<&'a str> for Variables<'a, T> { + type Output = T; fn index(&self, name: &'a str) -> &Self::Output { &self.0[name] } } -impl<'a, F: Field> IndexMut<&'a str> for Variables<'a, F> { +impl<'a, T> IndexMut<&'a str> for Variables<'a, T> { fn index_mut(&mut self, name: &'a str) -> &mut Self::Output { self.0.get_mut(name).expect("failed to get witness value") } diff --git a/kimchi/src/curve.rs b/kimchi/src/curve.rs index db2aa340a2..b82ce74ebb 100644 --- a/kimchi/src/curve.rs +++ b/kimchi/src/curve.rs @@ -10,10 +10,11 @@ use mina_poseidon::poseidon::ArithmeticSpongeParams; use once_cell::sync::Lazy; use poly_commitment::{ commitment::{CommitmentCurve, EndoCurve}, - srs::endos, + ipa::endos, }; -/// Represents additional information that a curve needs in order to be used with Kimchi +/// 
Represents additional information that a curve needs in order to be used +/// with Kimchi pub trait KimchiCurve: CommitmentCurve + EndoCurve { /// A human readable name. const NAME: &'static str; @@ -24,16 +25,17 @@ pub trait KimchiCurve: CommitmentCurve + EndoCurve { /// Provides the sponge params to be used with the other curve. fn other_curve_sponge_params() -> &'static ArithmeticSpongeParams; - /// Provides the coefficients for the curve endomorphism, called (q,r) in some places. + /// Provides the coefficients for the curve endomorphism, called (q,r) in + /// some places. fn endos() -> &'static (Self::BaseField, Self::ScalarField); - /// Provides the coefficient for the curve endomorphism over the other field, called q in some - /// places. + /// Provides the coefficient for the curve endomorphism over the other + /// field, called q in some places. fn other_curve_endo() -> &'static Self::ScalarField; /// Accessor for the other curve's prime subgroup generator, as coordinates // TODO: This leaked from snarky.rs. Stop the bleed. - fn other_curve_prime_subgroup_generator() -> (Self::ScalarField, Self::ScalarField); + fn other_curve_generator() -> (Self::ScalarField, Self::ScalarField); } fn vesta_endos() -> &'static ( @@ -77,7 +79,7 @@ impl KimchiCurve for Affine { &pallas_endos().0 } - fn other_curve_prime_subgroup_generator() -> (Self::ScalarField, Self::ScalarField) { + fn other_curve_generator() -> (Self::ScalarField, Self::ScalarField) { Affine::::generator() .to_coordinates() .unwrap() @@ -103,7 +105,7 @@ impl KimchiCurve for Affine { &vesta_endos().0 } - fn other_curve_prime_subgroup_generator() -> (Self::ScalarField, Self::ScalarField) { + fn other_curve_generator() -> (Self::ScalarField, Self::ScalarField) { Affine::::generator() .to_coordinates() .unwrap() @@ -133,7 +135,7 @@ impl KimchiCurve for Affine { &pallas_endos().0 } - fn other_curve_prime_subgroup_generator() -> (Self::ScalarField, Self::ScalarField) { + fn other_curve_generator() -> (Self::ScalarField, Self::ScalarField) { Affine::::generator() .to_coordinates() .unwrap() @@ -159,7 +161,7 @@ impl KimchiCurve for Affine { &vesta_endos().0 } - fn other_curve_prime_subgroup_generator() -> (Self::ScalarField, Self::ScalarField) { + fn other_curve_generator() -> (Self::ScalarField, Self::ScalarField) { Affine::::generator() .to_coordinates() .unwrap() @@ -197,7 +199,7 @@ impl KimchiCurve for Affine { &ENDO } - fn other_curve_prime_subgroup_generator() -> (Self::ScalarField, Self::ScalarField) { + fn other_curve_generator() -> (Self::ScalarField, Self::ScalarField) { // TODO: Dummy value, this is definitely not right (44u64.into(), 88u64.into()) } diff --git a/kimchi/src/error.rs b/kimchi/src/error.rs index 8f2e8406a1..84bd430dfa 100644 --- a/kimchi/src/error.rs +++ b/kimchi/src/error.rs @@ -74,13 +74,13 @@ pub enum VerifyError { IncorrectRuntimeProof, #[error("the evaluation for {0:?} is missing")] - MissingEvaluation(crate::circuits::expr::Column), + MissingEvaluation(crate::circuits::berkeley_columns::Column), #[error("the evaluation for PublicInput is missing")] MissingPublicInputEvaluation, #[error("the commitment for {0:?} is missing")] - MissingCommitment(crate::circuits::expr::Column), + MissingCommitment(crate::circuits::berkeley_columns::Column), } /// Errors that can arise when preparing the setup diff --git a/kimchi/src/lagrange_basis_evaluations.rs b/kimchi/src/lagrange_basis_evaluations.rs index 13f5dc2600..a99f3bda8b 100644 --- a/kimchi/src/lagrange_basis_evaluations.rs +++ 
b/kimchi/src/lagrange_basis_evaluations.rs @@ -2,18 +2,95 @@ use ark_ff::{batch_inversion_and_mul, FftField}; use ark_poly::{EvaluationDomain, Evaluations, Radix2EvaluationDomain as D}; use rayon::prelude::*; -/// The evaluations of all normalized lagrange basis polynomials at a given -/// point. Can be used to evaluate an `Evaluations` form polynomial at that point. +/// Evaluations of all normalized Lagrange basis polynomials at a given point. +/// Can be used to evaluate an `Evaluations` form polynomial at that point. +/// +/// The Lagrange basis for polynomials of degree `<= d` over a domain +/// `{ω_0,...,ω_{d-1}}` is the set of `d` polynomials `{l_0,...,l_{d-1}}` of +/// degree `d-1` that equal `1` at `ω_i` and `0` at the rest of the domain +/// elements. They can be used to evaluate polynomials in evaluation form +/// efficiently in `O(d)` time. +/// +/// When chunking is in place, the domain size `n` is larger than the maximum +/// polynomial degree allowed `m`. Thus, on input `n = c·m` evaluations for `c` +/// chunks, we cannot obtain a polynomial `f` of degree `c·m-1` with the equation: +/// +/// `f(X) = x_0 · l_0(X) + ... + x_{c·m-1} · l_{c·m-1}(X)` +/// +/// Instead, this struct will contain the `c·m` coefficients of the polynomial +/// that is equal to the powers of the point `x` in the positions corresponding +/// to the chunk, and `0` elsewhere in the domain. This is useful to evaluate the +/// chunks of polynomials of degree `c·m-1` given in evaluation form at the point. pub struct LagrangeBasisEvaluations { + /// If no chunking: + /// - evals is a vector of length 1 containing a vector of size `n` + /// corresponding to the evaluations of the Lagrange polynomials, which + /// are the polynomials that equal `1` at `ω_i` and `0` elsewhere in the + /// domain. + /// If chunking (`c` vectors of size `n` each, i.e. `c · n` values in total): + /// - the first index refers to the chunks + /// - the second index refers to the j-th coefficient of the i-th chunk of the + /// polynomial that equals the powers of the point and `0` elsewhere (and + /// the length of each such vector is `n`). evals: Vec>, } impl LagrangeBasisEvaluations { - /// Given the evaluations form of a polynomial, directly evaluate that polynomial at a point. + /// Return the domain size of the individual evaluations. + /// + /// Note that there is an invariant that all individual evaluation chunks + /// have the same size. It is enforced by each constructor. + /// + pub fn domain_size(&self) -> usize { + self.evals[0].len() + } + + /// Given the evaluations form of a polynomial, directly evaluate that + /// polynomial at a point. + /// + /// The Lagrange basis evaluations can be used to evaluate a polynomial + /// given in evaluation form efficiently in `O(n)` time, where `n` is the + /// domain size, without the need of interpolating. + /// + /// Recall that a polynomial can be represented as the sum of the scaled + /// Lagrange polynomials using its evaluations on the domain: + /// `f(x) = x_0 · l_0(x) + ... + x_{n-1} · l_{n-1}(x)` + /// + /// But when chunking is in place, we want to evaluate a polynomial `f` of + /// degree `c · m - 1` at point `z`, expressed as + /// ```text + /// f(z) = a_0·z^0 + ... + a_{c·m-1}·z^{c·m-1} + /// = z^0 · f_0(z) + z^m · f_1(z) + ... + z^{(c-1)m} · f_{c-1}(z) + /// ``` + /// + /// where `f_i(X)` is the i-th chunked polynomial of degree `m-1` of `f`: + /// `f_i(x) = a_{i·m} · x^0 + ...
+ a_{(i+1)m-1} · x^{m-1}` + /// + /// Returns the evaluation of each chunk of the polynomial at the point + /// (when there is no chunking, the result is a vector of length 1). They + /// correspond to the `f_i(z)` in the equation above. pub fn evaluate>(&self, p: &Evaluations) -> Vec { - assert_eq!(p.evals.len() % self.evals[0].len(), 0); - let stride = p.evals.len() / self.evals[0].len(); + // The domain size must be a multiple of the number of evaluations so + // that the degree of the polynomial can be split into chunks of equal size. + assert_eq!(p.evals.len() % self.domain_size(), 0); + // The number of chunks c + let stride = p.evals.len() / self.domain_size(); let p_evals = &p.evals; + + // Performs the operation + // ```text + // n-1 + // j ∈ [0, c) : eval_{j} = Σ p_{i · c} · l_{j,i} + // i=0 + // ``` + // Note that in the chunking case, the Lagrange basis contains the + // coefficient form of the polynomial that evaluates to the powers of + // `z` in the chunk positions and `0` elsewhere. + // + // Then, the evaluation of `f` on `z` can be computed as the sum of the + // products of the evaluations of `f` in the domain and the Lagrange + // evaluations. + (&self.evals) .into_par_iter() .map(|evals| { @@ -26,11 +103,15 @@ impl LagrangeBasisEvaluations { .collect() } - /// Given the evaluations form of a polynomial, directly evaluate that polynomial at a point, - /// assuming that the given evaluations are either 0 or 1 at every point of the domain. + /// Given the evaluations form of a polynomial, directly evaluate that + /// polynomial at a point, assuming that the given evaluations are either + /// `0` or `1` at every point of the domain. + /// + /// This method can particularly be useful when the polynomials represent + /// (boolean) selectors in a circuit. pub fn evaluate_boolean>(&self, p: &Evaluations) -> Vec { - assert_eq!(p.evals.len() % self.evals[0].len(), 0); - let stride = p.evals.len() / self.evals[0].len(); + assert_eq!(p.evals.len() % self.domain_size(), 0); + let stride = p.evals.len() / self.domain_size(); self.evals .iter() .map(|evals| { @@ -45,43 +126,43 @@ } - /// Compute all evaluations of the normalized lagrange basis polynomials of the - /// given domain at the given point. Runs in time O(domain size). + /// Compute all evaluations of the normalized Lagrange basis polynomials of + /// the given domain at the given point. Runs in time O(domain size). fn new_with_segment_size_1(domain: D, x: F) -> LagrangeBasisEvaluations { let n = domain.size(); // We want to compute for all i // s_i = 1 / t_i // where - // t_i = prod_{j != i} (omega^i - omega^j) + // t_i = ∏_{j ≠ i} (ω^i - ω^j) // - // Suppose we have t_0 = prod_{j = 1}^{n-1} (1 - omega^j). - // This is a product with n-1 terms. We want to shift each term over by omega - // so we multiply by omega^{n-1}: + // Suppose we have t_0 = ∏_{j = 1}^{n-1} (1 - ω^j). + // This is a product with n-1 terms. We want to shift each term over by + // ω so we multiply by ω^{n-1}: // - // omega^{n-1} * t_0 - // = prod_{j = 1}^{n-1} omega (1 - omega^j). - // = prod_{j = 1}^{n-1} (omega - omega^{j+1)). - // = (omega - omega^2) (omega - omega^3) ... (omega - omega^{n-1+1}) - // = (omega - omega^2) (omega - omega^3) ... (omega - omega^0) + // ω^{n-1} * t_0 + // = ∏_{j = 1}^{n-1} ω (1 - ω^j). + // = ∏_{j = 1}^{n-1} (ω - ω^{j+1}). + // = (ω - ω^2) (ω - ω^3) ... (ω - ω^{n-1+1}) + // = (ω - ω^2) (ω - ω^3) ...
(ω - ω^0) // = t_1 // // And generally - // omega^{n-1} * t_i - // = prod_{j != i} omega (omega^i - omega^j) - // = prod_{j != i} (omega^{i + 1} - omega^{j + 1}) - // = prod_{j + 1 != i + 1} (omega^{i + 1} - omega^{j + 1}) - // = prod_{j' != i + 1} (omega^{i + 1} - omega^{j'}) + // ω^{n-1} * t_i + // = ∏_{j ≠ i} ω (ω^i - ω^j) + // = ∏_{j ≠ i} (ω^{i + 1} - ω^{j + 1}) + // = ∏_{j + 1 ≠ i + 1} (ω^{i + 1} - ω^{j + 1}) + // = ∏_{j' ≠ i + 1} (ω^{i + 1} - ω^{j'}) // = t_{i + 1} // - // Since omega^{n-1} = omega^{-1}, we write this as - // omega{-1} t_i = t_{i + 1} + // Since ω^{n-1} = ω^{-1}, we write this as + // ω^{-1} t_i = t_{i + 1} // and then by induction, - // omega^{-i} t_0 = t_i + // ω^{-i} t_0 = t_i - // Now, the ith lagrange evaluation at x is - // (1 / prod_{j != i} (omega^i - omega^j)) (x^n - 1) / (x - omega^i) - // = (x^n - 1) / [t_i (x - omega^i)] - // = (x^n - 1) / [omega^{-i} * t_0 * (x - omega^i)] + // Now, the ith Lagrange evaluation at x is + // (1 / ∏_{j ≠ i} (ω^i - ω^j)) (x^n - 1) / (x - ω^i) + // = (x^n - 1) / [t_i (x - ω^i)] + // = (x^n - 1) / [ω^{-i} * t_0 * (x - ω^i)] // // We compute this using the [batch_inversion_and_mul] function. @@ -106,20 +187,38 @@ impl LagrangeBasisEvaluations { batch_inversion_and_mul(&mut denominators[..], &numerator); - // Denominators now contains the desired result. + // Denominators now contain the desired result. LagrangeBasisEvaluations { evals: vec![denominators], } } - /// Compute all evaluations of the normalized lagrange basis polynomials of the - /// given domain at the given point. Runs in time O(n log(n)) for n = domain size. + /// Compute all evaluations of the normalized Lagrange basis polynomials of + /// the given domain at the given point `x`. Runs in time O(n log(n)) where + /// n is the domain size. fn new_with_chunked_segments( max_poly_size: usize, domain: D, x: F, ) -> LagrangeBasisEvaluations { + // For each chunk, this loop obtains the coefficient form of the + // polynomial that equals the powers of `x` in the positions + // corresponding to the chunk, and 0 elsewhere in the domain, using an + // iFFT operation of length n, resulting in an algorithm that runs in + // `O(c n log n)`. + // + // Example: + // ```text + // i-th chunk + // ----------------------- + // chunked: [ 0, ..., 0, 1, x, x^2, ..., x^{m-1}, 0, ..., 0 ] + // indices: 0 i·m-1 i·m (i+1)m-1 (i+1)m cm-1=n-1 + // ``` + // A total of `n=c·m` coefficients are returned. These will be helpful to + // evaluate the chunks of polynomials of degree `n-1` at the point `x`. + // let n = domain.size(); + assert_eq!(n % max_poly_size, 0); let num_chunks = n / max_poly_size; let mut evals = Vec::with_capacity(num_chunks); for i in 0..num_chunks { @@ -132,8 +231,23 @@ impl LagrangeBasisEvaluations { // This uses the same trick as `poly_commitment::srs::SRS::lagrange_basis`, but // applied to field elements instead of group elements. 
domain.ifft_in_place(&mut chunked_evals); + // Check that the number of coefficients after iFFT is as expected + assert_eq!( + chunked_evals.len(), + n, + "The number of coefficients of the {}-th segment is {} but it should have been {n}", + i, + chunked_evals.len() + ); evals.push(chunked_evals); } + // Sanity check + assert_eq!( + evals.len(), + num_chunks, + "The number of expected chunks is {num_chunks} but only {} has/have been computed", + evals.len() + ); LagrangeBasisEvaluations { evals } } @@ -153,13 +267,15 @@ mod tests { use ark_ff::{One, UniformRand, Zero}; use ark_poly::{Polynomial, Radix2EvaluationDomain}; use mina_curves::pasta::Fp; + use rand::Rng; #[test] fn test_lagrange_evaluations() { - let n = 1 << 4; + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_log_size = rng.gen_range(1..10); + let n = 1 << domain_log_size; let domain = Radix2EvaluationDomain::new(n).unwrap(); - let rng = &mut o1_utils::tests::make_test_rng(None); - let x = Fp::rand(rng); + let x = Fp::rand(&mut rng); let evaluator = LagrangeBasisEvaluations::new(domain.size(), domain, x); let expected = (0..n).map(|i| { @@ -181,19 +297,23 @@ mod tests { #[test] fn test_new_with_chunked_segments() { - let n = 1 << 4; + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_log_size = rng.gen_range(1..10); + let n = 1 << domain_log_size; let domain = Radix2EvaluationDomain::new(n).unwrap(); - let rng = &mut o1_utils::tests::make_test_rng(None); - let x = Fp::rand(rng); + let x = Fp::rand(&mut rng); let evaluator = LagrangeBasisEvaluations::new(domain.size(), domain, x); let evaluator_chunked = LagrangeBasisEvaluations::new_with_chunked_segments(domain.size(), domain, x); + let chunk_length = evaluator_chunked.domain_size(); for (i, (evals, evals_chunked)) in evaluator .evals .iter() .zip(evaluator_chunked.evals.iter()) .enumerate() { + // Check all chunks have the same length + assert_eq!(evals_chunked.len(), chunk_length); for (j, (evals, evals_chunked)) in evals.iter().zip(evals_chunked.iter()).enumerate() { if evals != evals_chunked { panic!("{}, {}, {}: {} != {}", line!(), i, j, evals, evals_chunked); @@ -205,7 +325,8 @@ mod tests { #[test] fn test_evaluation() { let rng = &mut o1_utils::tests::make_test_rng(None); - let n = 1 << 10; + let domain_log_size = rng.gen_range(1..10); + let n = 1 << domain_log_size; let domain = Radix2EvaluationDomain::new(n).unwrap(); let evals = { @@ -228,7 +349,8 @@ mod tests { #[test] fn test_evaluation_boolean() { let rng = &mut o1_utils::tests::make_test_rng(None); - let n = 1 << 1; + let domain_log_size = rng.gen_range(1..10); + let n = 1 << domain_log_size; let domain = Radix2EvaluationDomain::new(n).unwrap(); let evals = { @@ -240,7 +362,6 @@ mod tests { Fp::zero() }); } - e = vec![Fp::zero(), Fp::one()]; Evaluations::from_vec_and_domain(e, domain) }; diff --git a/kimchi/src/lib.rs b/kimchi/src/lib.rs index 6a1e004c57..523b362b46 100644 --- a/kimchi/src/lib.rs +++ b/kimchi/src/lib.rs @@ -29,3 +29,14 @@ pub mod verifier_index; #[cfg(test)] mod tests; + +/// Handy macro to return the filename and line number of a place in the code. +#[macro_export] +macro_rules! loc { + () => {{ + ::std::borrow::Cow::Owned(format!("{}:{}", file!(), line!())) + }}; +} + +/// Export what is commonly used. 
+pub use snarky::prelude::*; diff --git a/kimchi/src/linearization.rs b/kimchi/src/linearization.rs index ef5abe39ef..a688211a15 100644 --- a/kimchi/src/linearization.rs +++ b/kimchi/src/linearization.rs @@ -4,6 +4,7 @@ use crate::{ alphas::Alphas, circuits::{ argument::{Argument, ArgumentType}, + berkeley_columns::BerkeleyChallengeTerm, expr, lookup, lookup::{ constraints::LookupConfiguration, @@ -26,8 +27,9 @@ use crate::{ }; use crate::circuits::{ + berkeley_columns::Column, constraints::FeatureFlags, - expr::{Column, ConstantExpr, Expr, FeatureFlag, Linearization, PolishToken}, + expr::{ConstantExpr, Expr, FeatureFlag, Linearization, PolishToken}, gate::GateType, wires::COLUMNS, }; @@ -41,7 +43,10 @@ use ark_ff::{FftField, PrimeField, Zero}; pub fn constraints_expr( feature_flags: Option<&FeatureFlags>, generic: bool, -) -> (Expr>, Alphas) { +) -> ( + Expr, Column>, + Alphas, +) { // register powers of alpha so that we don't reuse them across mutually inclusive constraints let mut powers_of_alpha = Alphas::::default(); @@ -333,16 +338,20 @@ pub fn linearization_columns( /// Linearize the `expr`. /// -/// If the `feature_flags` argument is `None`, this will generate an expression using the -/// `Expr::IfFeature` variant for each of the flags. +/// If the `feature_flags` argument is `None`, this will generate an expression +/// using the `Expr::IfFeature` variant for each of the flags. /// /// # Panics /// /// Will panic if the `linearization` process fails. +#[allow(clippy::type_complexity)] pub fn expr_linearization( feature_flags: Option<&FeatureFlags>, generic: bool, -) -> (Linearization>>, Alphas) { +) -> ( + Linearization>, Column>, + Alphas, +) { let evaluated_cols = linearization_columns::(feature_flags); let (expr, powers_of_alpha) = constraints_expr(feature_flags, generic); diff --git a/kimchi/src/oracles.rs b/kimchi/src/oracles.rs index c36167ae0b..1ed5fb8ec5 100644 --- a/kimchi/src/oracles.rs +++ b/kimchi/src/oracles.rs @@ -38,7 +38,7 @@ where #[cfg(feature = "ocaml_types")] pub mod caml { use ark_ff::PrimeField; - use poly_commitment::{commitment::shift_scalar, evaluation_proof::OpeningProof}; + use poly_commitment::{commitment::shift_scalar, ipa::OpeningProof}; use crate::{ circuits::scalars::caml::CamlRandomOracles, curve::KimchiCurve, error::VerifyError, diff --git a/kimchi/src/plonk_sponge.rs b/kimchi/src/plonk_sponge.rs index 4ee1953915..9c5d11ac1a 100644 --- a/kimchi/src/plonk_sponge.rs +++ b/kimchi/src/plonk_sponge.rs @@ -7,6 +7,11 @@ use mina_poseidon::{ use crate::proof::{PointEvaluations, ProofEvaluations}; +/// Abstracts a sponge that operates on the scalar field of an +/// elliptic curve. Unlike the [`FqSponge`](mina_poseidon::FqSponge), +/// it cannot absorb or digest base field elements. However, the +/// [`FqSponge`](mina_poseidon::FqSponge) can *also* operate on the +/// scalar field by means of a specific encoding technique. pub trait FrSponge { /// Creates a new Fr-Sponge. fn new(p: &'static ArithmeticSpongeParams) -> Self; @@ -47,7 +52,6 @@ impl FrSponge for DefaultFrSponge { } fn challenge(&mut self) -> ScalarChallenge { - // TODO: why involve sponge_5_wires here? ScalarChallenge(self.squeeze(mina_poseidon::sponge::CHALLENGE_LENGTH_IN_LIMBS)) } diff --git a/kimchi/src/precomputed_srs.rs b/kimchi/src/precomputed_srs.rs index 6d3bcb0c11..7dd2347c0d 100644 --- a/kimchi/src/precomputed_srs.rs +++ b/kimchi/src/precomputed_srs.rs @@ -1,58 +1,187 @@ -//!
To prover and verify proofs you need a [Structured Reference String](https://www.cryptologie.net/article/560/zk-faq-whats-a-trusted-setup-whats-a-structured-reference-string-whats-toxic-waste/) (SRS). -//! The generation of this SRS is quite expensive, so we provide a pre-generated SRS in this repo. +//! To prove and verify proofs you need a [Structured Reference +//! String](https://www.cryptologie.net/article/560/zk-faq-whats-a-trusted-setup-whats-a-structured-reference-string-whats-toxic-waste/) +//! (SRS). +//! The generation of this SRS is quite expensive, so we provide a pre-generated +//! SRS in this repo. //! Specifically, two of them, one for each pasta curve. //! //! We generate the SRS within the test in this module. -//! If you modify the SRS, you will need to regenerate the SRS by passing the `SRS_OVERWRITE` env var. +//! If you modify the SRS, you will need to regenerate the SRS by passing the +//! `SRS_OVERWRITE` env var. -use std::{fs::File, io::BufReader, path::PathBuf}; +use crate::curve::KimchiCurve; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use poly_commitment::{hash_map_cache::HashMapCache, ipa::SRS, PolyComm}; +use serde::{Deserialize, Serialize}; +use serde_with::serde_as; +use std::{collections::HashMap, fs::File, io::BufReader, path::PathBuf}; + +/// We store several different types of SRS objects. This enum parameterizes +/// them. +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum StoredSRSType { + Test, + Prod, +} -use poly_commitment::srs::SRS; +/// A clone of the SRS struct that is used for serialization, in a +/// test-optimised way. +/// +/// NB: Serialization of these fields is unchecked (and fast). If you +/// want to make sure the data is checked on deserialization, this code +/// must be changed; or you can check it externally. +#[serde_as] +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] +#[serde(bound = "G: CanonicalDeserialize + CanonicalSerialize")] +pub struct TestSRS { + /// The vector of group elements for committing to polynomials in + /// coefficient form. + #[serde_as(as = "Vec")] + pub g: Vec, -use crate::curve::KimchiCurve; + /// A group element used for blinding commitments + #[serde_as(as = "o1_utils::serialization::SerdeAsUnchecked")] + pub h: G, + + /// Commitments to Lagrange bases, per domain size + #[serde_as(as = "HashMap<_,Vec>>")] + pub lagrange_bases: HashMap>>, +} + +impl From> for TestSRS { + fn from(value: SRS) -> Self { + TestSRS { + g: value.g, + h: value.h, + lagrange_bases: value.lagrange_bases.into(), + } + } +} + +impl From> for SRS { + fn from(value: TestSRS) -> Self { + SRS { + g: value.g, + h: value.h, + lagrange_bases: HashMapCache::new_from_hashmap(value.lagrange_bases), + } + } +} /// The size of the SRS that we serialize. pub const SERIALIZED_SRS_SIZE: u32 = 16; /// The path of the serialized SRS. -fn get_srs_path() -> PathBuf { - let base_path = std::env::var("CARGO_MANIFEST_DIR").expect("failed to get manifest path"); +fn get_srs_path(srs_type: StoredSRSType) -> PathBuf { + let test_prefix: String = (match srs_type { + StoredSRSType::Test => "test_", + StoredSRSType::Prod => "", + }) + .to_owned(); + let base_path = env!("CARGO_MANIFEST_DIR"); PathBuf::from(base_path) .join("../srs") - .join(format!("{}.srs", G::NAME)) + .join(test_prefix + &format!("{}.srs", G::NAME)) } -/// Obtains an SRS for a specific curve from disk. -/// Panics if the SRS does not exists. -pub fn get_srs() -> SRS +/// Generic SRS getter function.
+pub fn get_srs_generic(srs_type: StoredSRSType) -> SRS where G: KimchiCurve, { - let srs_path = get_srs_path::(); + let srs_path = get_srs_path::(srs_type); let file = File::open(srs_path.clone()).unwrap_or_else(|_| panic!("missing SRS file: {srs_path:?}")); let reader = BufReader::new(file); - rmp_serde::from_read(reader).unwrap() + match srs_type { + StoredSRSType::Test => { + let test_srs: TestSRS = rmp_serde::from_read(reader).unwrap(); + From::from(test_srs) + } + StoredSRSType::Prod => rmp_serde::from_read(reader).unwrap(), + } +} + +/// Obtains an SRS for a specific curve from disk. +/// Panics if the SRS does not exist. +pub fn get_srs() -> SRS +where + G: KimchiCurve, +{ + get_srs_generic(StoredSRSType::Prod) +} + +/// Obtains a Test SRS for a specific curve from disk. +/// Panics if the SRS does not exist. +pub fn get_srs_test() -> SRS +where + G: KimchiCurve, +{ + get_srs_generic(StoredSRSType::Test) } #[cfg(test)] mod tests { use super::*; + use ark_ec::AffineRepr; use ark_ff::PrimeField; use ark_serialize::Write; + use hex; use mina_curves::pasta::{Pallas, Vesta}; + use poly_commitment::{hash_map_cache::HashMapCache, SRS as _}; + + use crate::circuits::domains::EvaluationDomains; + + fn test_regression_serialization_srs_with_generators(exp_output: String) { + let h = G::generator(); + let g = vec![h]; + let lagrange_bases = HashMapCache::new(); + let srs = SRS:: { + g, + h, + lagrange_bases, + }; + let srs_bytes = rmp_serde::to_vec(&srs).unwrap(); + let output = hex::encode(srs_bytes.clone()); + assert_eq!(output, exp_output) + } - fn create_or_check_srs(log2_size: u32) + #[test] + fn test_regression_serialization_srs_with_generators_vesta() { + // This is the same as Pallas as we encode the coordinate x only. + // Generated with commit 4c69a4defdb109b94f1124fe93283e728f1d8758 + let exp_output = "9291c421010000000000000000000000000000000000000000000000000000000000000000c421010000000000000000000000000000000000000000000000000000000000000000"; + test_regression_serialization_srs_with_generators::(exp_output.to_string()) + } + + #[test] + fn test_regression_serialization_srs_with_generators_pallas() { + // This is the same as Vesta as we encode the coordinate x only.
+ // Generated with commit 4c69a4defdb109b94f1124fe93283e728f1d8758 + let exp_output = "9291c421010000000000000000000000000000000000000000000000000000000000000000c421010000000000000000000000000000000000000000000000000000000000000000"; + test_regression_serialization_srs_with_generators::(exp_output.to_string()) + } + + fn create_or_check_srs(log2_size: u32, srs_type: StoredSRSType) where G: KimchiCurve, G::BaseField: PrimeField, { // generate SRS - let srs = SRS::::create(1 << log2_size); + let domain_size = 1 << log2_size; + let srs = SRS::::create(domain_size); + + // Test SRS objects have Lagrange bases precomputed + if srs_type == StoredSRSType::Test { + for sub_domain_size in 1..=domain_size { + let domain = EvaluationDomains::::create(sub_domain_size).unwrap(); + srs.get_lagrange_basis(domain.d1); + } + } // overwrite SRS if the env var is set - let srs_path = get_srs_path::(); + let srs_path = get_srs_path::(srs_type); if std::env::var("SRS_OVERWRITE").is_ok() { let mut file = std::fs::OpenOptions::new() .create(true) @@ -60,22 +189,55 @@ mod tests { .open(srs_path) .expect("failed to open SRS file"); - let srs_bytes = rmp_serde::to_vec(&srs).unwrap(); + let srs_bytes = match srs_type { + StoredSRSType::Test => { + let srs: TestSRS = From::from(srs.clone()); + rmp_serde::to_vec(&srs).unwrap() + } + StoredSRSType::Prod => rmp_serde::to_vec(&srs).unwrap(), + }; + file.write_all(&srs_bytes).expect("failed to write file"); file.flush().expect("failed to flush file"); } // get SRS from disk - let srs_on_disk = get_srs::(); + let srs_on_disk: SRS = get_srs_generic::(srs_type); // check that it matches what we just generated assert_eq!(srs, srs_on_disk); } + /// Checks if `get_srs` (prod) succeeds for Pallas. Can be used for time-profiling. + #[test] + pub fn heavy_check_get_srs_prod_pallas() { + get_srs::(); + } + + /// Checks if `get_srs` (prod) succeeds for Vesta. Can be used for time-profiling. + #[test] + pub fn heavy_check_get_srs_prod_vesta() { + get_srs::(); + } + + /// Checks if `get_srs` (test) succeeds for Pallas. Can be used for time-profiling. + #[test] + pub fn check_get_srs_test_pallas() { + get_srs_test::(); + } + + /// Checks if `get_srs` (test) succeeds for Vesta. Can be used for time-profiling. + #[test] + pub fn check_get_srs_test_vesta() { + get_srs_test::(); + } + /// This test checks that the two serialized SRS on disk are correct. #[test] - pub fn test_srs_serialization() { - create_or_check_srs::(SERIALIZED_SRS_SIZE); - create_or_check_srs::(SERIALIZED_SRS_SIZE); + pub fn heavy_test_srs_serialization() { + create_or_check_srs::(SERIALIZED_SRS_SIZE, StoredSRSType::Prod); + create_or_check_srs::(SERIALIZED_SRS_SIZE, StoredSRSType::Prod); + create_or_check_srs::(SERIALIZED_SRS_SIZE, StoredSRSType::Test); + create_or_check_srs::(SERIALIZED_SRS_SIZE, StoredSRSType::Test); } } diff --git a/kimchi/src/proof.rs b/kimchi/src/proof.rs index 5492fef1f5..6e03282f20 100644 --- a/kimchi/src/proof.rs +++ b/kimchi/src/proof.rs @@ -1,7 +1,7 @@ //! This module implements the data structures of a proof. use crate::circuits::{ - expr::Column, + berkeley_columns::Column, gate::GateType, lookup::lookups::LookupPattern, wires::{COLUMNS, PERMUTS}, @@ -38,8 +38,10 @@ pub struct PointEvaluations { // TODO: this should really be vectors here, perhaps create another type for chunked evaluations? /// Polynomial evaluations contained in a `ProverProof`. 
-/// - **Chunked evaluations** `Field` is instantiated with vectors with a length that equals the length of the chunk -/// - **Non chunked evaluations** `Field` is instantiated with a field, so they are single-sized#[serde_as] +/// - **Chunked evaluations** `Field` is instantiated with vectors with a length +/// that equals the length of the chunk +/// - **Non chunked evaluations** `Field` is instantiated with a field, so they +/// are single-sized. #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ProofEvaluations { @@ -50,7 +52,8 @@ pub struct ProofEvaluations { /// permutation polynomial pub z: Evals, /// permutation polynomials - /// (PERMUTS-1 evaluations because the last permutation is only used in commitment form) + /// (PERMUTS-1 evaluations because the last permutation is only used in + /// commitment form) pub s: [Evals; PERMUTS - 1], /// coefficient polynomials pub coefficients: [Evals; COLUMNS], @@ -60,11 +63,13 @@ pub poseidon_selector: Evals, /// evaluation of the elliptic curve addition selector polynomial pub complete_add_selector: Evals, - /// evaluation of the elliptic curve variable base scalar multiplication selector polynomial + /// evaluation of the elliptic curve variable base scalar multiplication + /// selector polynomial pub mul_selector: Evals, /// evaluation of the endoscalar multiplication selector polynomial pub emul_selector: Evals, - /// evaluation of the endoscalar multiplication scalar computation selector polynomial + /// evaluation of the endoscalar multiplication scalar computation selector + /// polynomial pub endomul_scalar_selector: Evals, // Optional gates @@ -100,7 +105,8 @@ pub lookup_gate_lookup_selector: Option, /// evaluation of the RangeCheck range check pattern selector polynomial pub range_check_lookup_selector: Option, - /// evaluation of the ForeignFieldMul range check pattern selector polynomial + /// evaluation of the ForeignFieldMul range check pattern selector + /// polynomial pub foreign_field_mul_lookup_selector: Option, } @@ -132,7 +138,8 @@ pub struct ProverCommitments { pub lookup: Option>, } -/// The proof that the prover creates from a [ProverIndex](super::prover_index::ProverIndex) and a `witness`. +/// The proof that the prover creates from a +/// [ProverIndex](super::prover_index::ProverIndex) and a `witness`. #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(bound = "G: ark_serialize::CanonicalDeserialize + ark_serialize::CanonicalSerialize")] pub struct ProverProof { /// Two evaluations over a number of committed polynomials pub evals: ProofEvaluations>>, - /// Required evaluation for [Maller's optimization](https://o1-labs.github.io/mina-book/crypto/plonk/maller_15.html#the-evaluation-of-l) + /// Required evaluation for [Maller's + /// optimization](https://o1-labs.github.io/proof-systems/kimchi/maller_15.html#the-evaluation-of-l) #[serde_as(as = "o1_utils::serialization::SerdeAs")] pub ft_eval1: G::ScalarField, @@ -620,6 +628,13 @@ pub mod caml { CamlF: From, { fn from(pe: ProofEvaluations>>) -> Self { + let first = pe.public.map(|x: PointEvaluations>| { + // map both fields of each evaluation.
+ x.map(&|x: Vec| { + let y: Vec = x.into_iter().map(Into::into).collect(); + y + }) + }); let w = ( pe.w[0] .clone() @@ -735,85 +750,83 @@ pub mod caml { .map(&|x| x.into_iter().map(Into::into).collect()), ); - ( - pe.public + let second = CamlProofEvaluations { + w, + coefficients, + z: pe.z.map(&|x| x.into_iter().map(Into::into).collect()), + s, + generic_selector: pe + .generic_selector + .map(&|x| x.into_iter().map(Into::into).collect()), + poseidon_selector: pe + .poseidon_selector + .map(&|x| x.into_iter().map(Into::into).collect()), + complete_add_selector: pe + .complete_add_selector + .map(&|x| x.into_iter().map(Into::into).collect()), + mul_selector: pe + .mul_selector + .map(&|x| x.into_iter().map(Into::into).collect()), + emul_selector: pe + .emul_selector + .map(&|x| x.into_iter().map(Into::into).collect()), + endomul_scalar_selector: pe + .endomul_scalar_selector + .map(&|x| x.into_iter().map(Into::into).collect()), + range_check0_selector: pe + .range_check0_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + range_check1_selector: pe + .range_check1_selector .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - CamlProofEvaluations { - w, - coefficients, - z: pe.z.map(&|x| x.into_iter().map(Into::into).collect()), - s, - generic_selector: pe - .generic_selector - .map(&|x| x.into_iter().map(Into::into).collect()), - poseidon_selector: pe - .poseidon_selector - .map(&|x| x.into_iter().map(Into::into).collect()), - complete_add_selector: pe - .complete_add_selector - .map(&|x| x.into_iter().map(Into::into).collect()), - mul_selector: pe - .mul_selector - .map(&|x| x.into_iter().map(Into::into).collect()), - emul_selector: pe - .emul_selector - .map(&|x| x.into_iter().map(Into::into).collect()), - endomul_scalar_selector: pe - .endomul_scalar_selector - .map(&|x| x.into_iter().map(Into::into).collect()), - range_check0_selector: pe - .range_check0_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - range_check1_selector: pe - .range_check1_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - foreign_field_add_selector: pe - .foreign_field_add_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - foreign_field_mul_selector: pe - .foreign_field_mul_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - xor_selector: pe - .xor_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - rot_selector: pe - .rot_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - lookup_aggregation: pe - .lookup_aggregation - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - lookup_table: pe - .lookup_table - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - lookup_sorted: pe - .lookup_sorted - .iter() - .map(|x| { - x.as_ref().map(|x| { - x.map_ref(&|x| x.clone().into_iter().map(Into::into).collect()) - }) + foreign_field_add_selector: pe + .foreign_field_add_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + foreign_field_mul_selector: pe + .foreign_field_mul_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + xor_selector: pe + .xor_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + rot_selector: pe + .rot_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + lookup_aggregation: pe + .lookup_aggregation + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + lookup_table: pe + .lookup_table + 
.map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + lookup_sorted: pe + .lookup_sorted + .iter() + .map(|x| { + x.as_ref().map(|x| { + x.map_ref(&|x| x.clone().into_iter().map(Into::into).collect()) }) - }) - .collect::>(), - runtime_lookup_table: pe - .runtime_lookup_table - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - runtime_lookup_table_selector: pe - .runtime_lookup_table_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - xor_lookup_selector: pe - .xor_lookup_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - lookup_gate_lookup_selector: pe - .lookup_gate_lookup_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - range_check_lookup_selector: pe - .range_check_lookup_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - foreign_field_mul_lookup_selector: pe - .foreign_field_mul_lookup_selector - .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), - }, - ) + }) + .collect::>(), + runtime_lookup_table: pe + .runtime_lookup_table + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + runtime_lookup_table_selector: pe + .runtime_lookup_table_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + xor_lookup_selector: pe + .xor_lookup_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + lookup_gate_lookup_selector: pe + .lookup_gate_lookup_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + range_check_lookup_selector: pe + .range_check_lookup_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + foreign_field_mul_lookup_selector: pe + .foreign_field_mul_lookup_selector + .map(|x| x.map(&|x| x.into_iter().map(Into::into).collect())), + }; + + (first, second) } } diff --git a/kimchi/src/prover.rs b/kimchi/src/prover.rs index 14e043be07..b135d20836 100644 --- a/kimchi/src/prover.rs +++ b/kimchi/src/prover.rs @@ -3,8 +3,9 @@ use crate::{ circuits::{ argument::{Argument, ArgumentType}, + berkeley_columns::{BerkeleyChallenges, Environment, LookupEnvironment}, constraints::zk_rows_strict_lower_bound, - expr::{self, l0_1, Constants, Environment, LookupEnvironment}, + expr::{self, l0_1, Constants}, gate::GateType, lookup::{self, runtime_tables::RuntimeTable, tables::combine_table_entry}, polynomials::{ @@ -45,7 +46,7 @@ use poly_commitment::{ commitment::{ absorb_commitment, b_poly_coefficients, BlindedCommitment, CommitmentCurve, PolyComm, }, - evaluation_proof::DensePolynomialOrEvaluations, + utils::DensePolynomialOrEvaluations, OpenProof, SRS as _, }; use rand_core::{CryptoRng, RngCore}; @@ -156,11 +157,13 @@ where ) } - /// This function constructs prover's recursive zk-proof from the witness & the `ProverIndex` against SRS instance + /// This function constructs the prover's recursive zk-proof from the witness & + /// the `ProverIndex` against the SRS instance. /// /// # Errors /// - /// Will give error if inputs(like `lookup_context.joint_lookup_table_d8`) are None. + /// Will give error if inputs (like `lookup_context.joint_lookup_table_d8`) + /// are `None`.
/// /// # Panics /// @@ -210,9 +213,17 @@ where .ok_or(ProverError::NoRoomForZkInWitness)?; let zero_knowledge_limit = zk_rows_strict_lower_bound(num_chunks); - if (index.cs.zk_rows as usize) < zero_knowledge_limit { + // Because the lower bound is strict, the result of the function above + // is not a sufficient number of zero knowledge rows, so the error must + // be raised anytime the number of zero knowledge rows is not greater + // than the strict lower bound. + // Example: + // for 1 chunk, `zero_knowledge_limit` is 2, and we need at least 3, + // thus the error should be raised and the message should say that the + // expected number of zero knowledge rows is 3 (hence the + 1). + if (index.cs.zk_rows as usize) <= zero_knowledge_limit { return Err(ProverError::NotZeroKnowledge( - zero_knowledge_limit, + zero_knowledge_limit + 1, index.cs.zk_rows as usize, )); } @@ -425,7 +436,7 @@ where G::ScalarField::zero() }; - //~~ * Derive the scalar joint combiner $j$ from $j'$ using the endomorphism (TOOD: specify) + //~~ * Derive the scalar joint combiner $j$ from $j'$ using the endomorphism (TODO: specify) let joint_combiner: G::ScalarField = ScalarChallenge(joint_combiner).to_field(endo_r); //~~ * If multiple lookup tables are involved, @@ -706,14 +717,18 @@ where let mds = &G::sponge_params().mds; Environment { constants: Constants { - alpha, - beta, - gamma, - joint_combiner: lookup_context.joint_combiner, endo_coefficient: index.cs.endo, mds, zk_rows: index.cs.zk_rows, }, + challenges: BerkeleyChallenges { + alpha, + beta, + gamma, + joint_combiner: lookup_context + .joint_combiner + .unwrap_or(G::ScalarField::zero()), + }, witness: &lagrange.d8.this.w, coefficient: &index.column_evaluations.coefficients8, vanishes_on_zero_knowledge_and_previous_rows: &index @@ -870,7 +885,7 @@ where //~ 1. commit (hiding) to the quotient polynomial $t$ let t_comm = { index.srs.commit("ient_poly, 7 * num_chunks, rng) }; - //~ 1. Absorb the the commitment of the quotient polynomial with the Fq-Sponge. + //~ 1. Absorb the commitment of the quotient polynomial with the Fq-Sponge. absorb_commitment(&mut fq_sponge, &t_comm.commitment); //~ 1. Sample $\zeta'$ with the Fq-Sponge. @@ -951,7 +966,7 @@ where //~ //~ $$(f_0(x), f_1(x), f_2(x), \ldots)$$ //~ - //~ TODO: do we want to specify more on that? It seems unecessary except for the t polynomial (or if for some reason someone sets that to a low value) + //~ TODO: do we want to specify more on that? It seems unnecessary except for the t polynomial (or if for some reason someone sets that to a low value) internal_tracing::checkpoint!(internal_traces; lagrange_basis_eval_zeta_poly); let zeta_evals = @@ -1101,7 +1116,7 @@ where //~ 1. Evaluate the same polynomials without chunking them //~ (so that each polynomial should correspond to a single value this time). - let evals = { + let evals: ProofEvaluations> = { let powers_of_eval_points_for_chunks = PointEvaluations { zeta: zeta_to_srs_len, zeta_omega: zeta_omega_to_srs_len, @@ -1110,7 +1125,7 @@ where }; //~ 1. Compute the ft polynomial. - //~ This is to implement [Maller's optimization](https://o1-labs.github.io/mina-book/crypto/plonk/maller_15.html). + //~ This is to implement [Maller's optimization](https://o1-labs.github.io/proof-systems/kimchi/maller_15.html). 
internal_tracing::checkpoint!(internal_traces; compute_ft_poly); let ft: DensePolynomial = { let f_chunked = { @@ -1133,7 +1148,7 @@ where drop(env); - // see https://o1-labs.github.io/mina-book/crypto/plonk/maller_15.html#the-prover-side + // see https://o1-labs.github.io/proof-systems/kimchi/maller_15.html#the-prover-side f.to_chunked_polynomial(num_chunks, index.max_poly_size) .linearize(zeta_to_srs_len) }; @@ -1146,14 +1161,14 @@ where }; //~ 1. construct the blinding part of the ft polynomial commitment - //~ [see this section](https://o1-labs.github.io/mina-book/crypto/plonk/maller_15.html#evaluation-proof-and-blinding-factors) + //~ [see this section](https://o1-labs.github.io/proof-systems/kimchi/maller_15.html#evaluation-proof-and-blinding-factors) let blinding_ft = { let blinding_t = t_comm.blinders.chunk_blinding(zeta_to_srs_len); let blinding_f = G::ScalarField::zero(); PolyComm { // blinding_f - Z_H(zeta) * blinding_t - elems: vec![ + chunks: vec![ blinding_f - (zeta_to_domain_size - G::ScalarField::one()) * blinding_t, ], } @@ -1189,7 +1204,7 @@ where .map(|RecursionChallenge { chals, comm }| { ( DensePolynomial::from_coefficients_vec(b_poly_coefficients(chals)), - comm.elems.len(), + comm.len(), ) }) .collect::>(); @@ -1223,8 +1238,12 @@ where //~ 1. Create a list of all polynomials that will require evaluations //~ (and evaluation proofs) in the protocol. //~ First, include the previous challenges, in case we are in a recursive prover. - let non_hiding = |d1_size: usize| PolyComm { - elems: vec![G::ScalarField::zero(); d1_size], + let non_hiding = |n_chunks: usize| PolyComm { + chunks: vec![G::ScalarField::zero(); n_chunks], + }; + + let fixed_hiding = |n_chunks: usize| PolyComm { + chunks: vec![G::ScalarField::one(); n_chunks], }; let coefficients_form = DensePolynomialOrEvaluations::DensePolynomial; @@ -1232,13 +1251,9 @@ where let mut polynomials = polys .iter() - .map(|(p, d1_size)| (coefficients_form(p), non_hiding(*d1_size))) + .map(|(p, n_chunks)| (coefficients_form(p), non_hiding(*n_chunks))) .collect::>(); - let fixed_hiding = |d1_size: usize| PolyComm { - elems: vec![G::ScalarField::one(); d1_size], - }; - //~ 1. Then, include: //~~ * the negated public polynomial //~~ * the ft polynomial @@ -1383,17 +1398,16 @@ where if lcs.runtime_selector.is_some() { let runtime_comm = lookup_context.runtime_table_comm.as_ref().unwrap(); - let elems = runtime_comm + let chunks = runtime_comm .blinders - .elems - .iter() - .map(|blinding| *joint_combiner * blinding + base_blinding) + .into_iter() + .map(|blinding| *joint_combiner * *blinding + base_blinding) .collect(); - PolyComm { elems } + PolyComm::new(chunks) } else { - let elems = vec![base_blinding; num_chunks]; - PolyComm { elems } + let chunks = vec![base_blinding; num_chunks]; + PolyComm::new(chunks) } }; @@ -1438,7 +1452,7 @@ where } //~ 1. Create an aggregated evaluation proof for all of these polynomials at $\zeta$ and $\zeta\omega$ using $u$ and $v$. 
- internal_tracing::checkpoint!(internal_traces; create_aggregated_evaluation_proof); + internal_tracing::checkpoint!(internal_traces; create_aggregated_ipa); let proof = OpenProof::open( &*index.srs, group_map, @@ -1495,7 +1509,7 @@ internal_tracing::decl_traces!(internal_traces; compute_ft_poly, ft_eval_zeta_omega, build_polynomials, - create_aggregated_evaluation_proof, + create_aggregated_ipa, create_recursive_done); #[cfg(feature = "ocaml_types")] @@ -1504,8 +1518,8 @@ pub mod caml { use crate::proof::caml::{CamlProofEvaluations, CamlRecursionChallenge}; use ark_ec::AffineRepr; use poly_commitment::{ - commitment::caml::{CamlOpeningProof, CamlPolyComm}, - evaluation_proof::OpeningProof, + commitment::caml::CamlPolyComm, + ipa::{caml::CamlOpeningProof, OpeningProof}, }; #[cfg(feature = "internal_tracing")] @@ -1529,7 +1543,8 @@ pub mod caml { pub evals: CamlProofEvaluations, pub ft_eval1: CamlF, pub public: Vec, - pub prev_challenges: Vec>, //Vec<(Vec, CamlPolyComm)>, + //Vec<(Vec, CamlPolyComm)>, + pub prev_challenges: Vec>, } // @@ -1581,8 +1596,10 @@ pub mod caml { // ProverCommitments -> CamlProverCommitments // we need to know how to convert G to CamlG. // we don't know that information, unless we implemented some trait (e.g. ToCaml) - // we can do that, but instead we implemented the From trait for the reverse operations (From for CamlG). - // it reduces the complexity, but forces us to do the conversion in two phases instead of one. + // we can do that, but instead we implemented the From trait for the reverse + // operations (From for CamlG). + // it reduces the complexity, but forces us to do the conversion in two + // phases instead of one. // // CamlLookupCommitments <-> LookupCommitments diff --git a/kimchi/src/prover_index.rs b/kimchi/src/prover_index.rs index 43318499a1..b03e66a4cc 100644 --- a/kimchi/src/prover_index.rs +++ b/kimchi/src/prover_index.rs @@ -3,6 +3,7 @@ use crate::{ alphas::Alphas, circuits::{ + berkeley_columns::{BerkeleyChallengeTerm, Column}, constraints::{ColumnEvaluations, ConstraintSystem}, expr::{Linearization, PolishToken}, }, @@ -28,7 +29,8 @@ pub struct ProverIndex> { /// The symbolic linearization of our circuit, which can compile to concrete types once certain values are learned in the protocol. #[serde(skip)] - pub linearization: Linearization>>, + pub linearization: + Linearization>, Column>, /// The mapping between powers of alpha and constraints #[serde(skip)] @@ -142,7 +144,10 @@ pub mod testing { }; use ark_ff::PrimeField; use ark_poly::{EvaluationDomain, Radix2EvaluationDomain as D}; - use poly_commitment::{evaluation_proof::OpeningProof, srs::SRS, OpenProof}; + use poly_commitment::{ + ipa::{OpeningProof, SRS}, + OpenProof, + }; #[allow(clippy::too_many_arguments)] pub fn new_index_for_test_with_lookups_and_custom_srs< @@ -212,7 +217,7 @@ pub mod testing { let log2_size = size.ilog2(); let srs = if log2_size <= precomputed_srs::SERIALIZED_SRS_SIZE { // TODO: we should trim it if it's smaller - precomputed_srs::get_srs() + precomputed_srs::get_srs_test() } else { // TODO: we should resume the SRS generation starting from the serialized one SRS::::create(size) diff --git a/kimchi/src/snarky/README.md b/kimchi/src/snarky/README.md new file mode 100644 index 0000000000..9cab81767f --- /dev/null +++ b/kimchi/src/snarky/README.md @@ -0,0 +1,3 @@ +# Snarky-rs + +This is a Rust implementation of snarky, originally written in OCaml. 
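The chunked-evaluation identity documented in `kimchi/src/lagrange_basis_evaluations.rs` earlier in this diff, `f(z) = z^0·f_0(z) + z^m·f_1(z) + ... + z^{(c-1)m}·f_{c-1}(z)`, is also exactly the recombination that `linearize(zeta_to_srs_len)` performs in `prover.rs` above. It can be checked independently of the `LagrangeBasisEvaluations` machinery; below is a minimal standalone sketch, assuming only the `ark-ff`/`ark-poly`/`mina-curves`/`rand` APIs this workspace already depends on (the variable names are illustrative, not part of the diff):

```rust
use ark_ff::{Field, UniformRand, Zero};
use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial};
use mina_curves::pasta::Fp;

fn main() {
    let rng = &mut rand::thread_rng();
    // c chunks of maximum polynomial size m, so f has degree c·m - 1.
    let (c, m) = (4usize, 16usize);
    let f = DensePolynomial::<Fp>::rand(c * m - 1, rng);
    let z = Fp::rand(rng);

    // Evaluate each chunk f_i(X) = a_{i·m} + ... + a_{(i+1)m-1}·X^{m-1} at z.
    let chunk_evals: Vec<Fp> = f
        .coeffs
        .chunks(m)
        .map(|chunk| DensePolynomial::from_coefficients_slice(chunk).evaluate(&z))
        .collect();

    // Recombine with Horner's rule in z^m: f(z) = Σ_i z^{i·m} · f_i(z).
    let z_m = z.pow([m as u64]);
    let f_z = chunk_evals
        .iter()
        .rev()
        .fold(Fp::zero(), |acc, f_i| acc * z_m + f_i);

    assert_eq!(f_z, f.evaluate(&z));
}
```

This is why a commitment scheme with maximum polynomial size `m` can still handle degree `c·m - 1` polynomials: the prover commits and evaluates per chunk, and the verifier folds the chunk evaluations back together with powers of `z^m`.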
diff --git a/kimchi/src/snarky/api.rs b/kimchi/src/snarky/api.rs new file mode 100644 index 0000000000..c1b0e56015 --- /dev/null +++ b/kimchi/src/snarky/api.rs @@ -0,0 +1,347 @@ +//! The main interface to using Snarky. +//! +//! To use Snarky, simply implement the [SnarkyCircuit] trait. + +use std::marker::PhantomData; + +use crate::{ + circuits::{constraints::ConstraintSystem, gate::CircuitGate, polynomial::COLUMNS}, + curve::KimchiCurve, + groupmap::GroupMap, + mina_poseidon::FqSponge, + plonk_sponge::FrSponge, + proof::ProverProof, + prover_index::ProverIndex, + verifier::verify, + verifier_index::VerifierIndex, +}; + +use ark_ec::AffineRepr; +use ark_ff::PrimeField; +use log::debug; +use poly_commitment::{commitment::CommitmentCurve, OpenProof, SRS}; + +use super::{errors::SnarkyResult, runner::RunState, snarky_type::SnarkyType}; + +/// A witness represents the execution trace of a circuit. +#[derive(Debug)] +pub struct Witness(pub [Vec; COLUMNS]); + +// +// aliases +// + +type ScalarField = ::ScalarField; +type BaseField = ::BaseField; + +/// A prover index. +pub struct ProverIndexWrapper +where + Circuit: SnarkyCircuit, +{ + compiled_circuit: CompiledCircuit, + index: ProverIndex, +} + +type Proof = ProverProof<::Curve, ::Proof>; +type Output = <::PublicOutput as SnarkyType< +ScalarField<::Curve>, +>>::OutOfCircuit; + +impl ProverIndexWrapper +where + Circuit: SnarkyCircuit, +{ + /// Produces an assembly-like encoding of the circuit. + pub fn asm(&self) -> String { + crate::circuits::gate::Circuit::new( + self.compiled_circuit.public_input_size, + &self.compiled_circuit.gates, + ) + .generate_asm() + } + + /// Produces a proof for the given public input. + pub fn prove( + // TODO: this should not be mutable ideally + &mut self, + public_input: >>::OutOfCircuit, + private_input: Circuit::PrivateInput, + // TODO: rename to verify_witness?
+ debug: bool, + ) -> SnarkyResult<(Proof, Box>)> + where + ::BaseField: PrimeField, + EFqSponge: Clone + + FqSponge, Circuit::Curve, ScalarField>, + EFrSponge: FrSponge>, + { + // create public input + let public_input_without_output = + Circuit::PublicInput::value_to_field_elements(&public_input).0; + + // init + self.compiled_circuit + .sys + .generate_witness_init(public_input_without_output.clone())?; + + // run circuit and get return var + let public_input_var: Circuit::PublicInput = self.compiled_circuit.sys.public_input(); + let return_var = self.compiled_circuit.circuit.circuit( + &mut self.compiled_circuit.sys, + public_input_var, + Some(&private_input), + )?; + + // get values from private input vec + let (return_cvars, aux) = return_var.to_cvars(); + let mut public_output_values = vec![]; + for cvar in &return_cvars { + public_output_values.push(cvar.eval(&self.compiled_circuit.sys)); + } + + // create constraint between public output var and return var + { + // Note: since the values of the public output part are set to zero at this point, + // let's also avoid checking the wiring (which would fail) + let eval_constraints = self.compiled_circuit.sys.eval_constraints; + self.compiled_circuit.sys.eval_constraints = false; + + self.compiled_circuit.sys.wire_public_output(return_var)?; + + self.compiled_circuit.sys.eval_constraints = eval_constraints; + } + + // finalize + let mut witness = self.compiled_circuit.sys.generate_witness(); + + // replace public output part of witness + let start = Circuit::PublicInput::SIZE_IN_FIELD_ELEMENTS; + let end = start + Circuit::PublicOutput::SIZE_IN_FIELD_ELEMENTS; + for (cell, val) in &mut witness.0[0][start..end] + .iter_mut() + .zip(&public_output_values) + { + *cell = *val; + } + + // same but with the full public input + let mut public_input_and_output = public_input_without_output; + public_input_and_output.extend(public_output_values.clone()); + + // reconstruct public output + let public_output = + Circuit::PublicOutput::value_of_field_elements(public_output_values, aux); + + // verify the witness + // TODO: return error instead of panicking + if debug { + witness.debug(); + self.index + .verify(&witness.0, &public_input_and_output) + .unwrap(); + } + + // produce a proof + let group_map = ::Map::setup(); + + // TODO: return error instead of panicking + let proof: ProverProof = + ProverProof::create::( + &group_map, + witness.0, + &[], + &self.index, + &mut rand::rngs::OsRng, + ) + .unwrap(); + + // return proof + public output + Ok((proof, Box::new(public_output))) + } +} + +/// A verifier index. +pub struct VerifierIndexWrapper +where + Circuit: SnarkyCircuit, +{ + index: VerifierIndex, +} + +impl VerifierIndexWrapper +where + Circuit: SnarkyCircuit, +{ + /// Verify a proof for a given public input and public output. + pub fn verify( + &self, + proof: ProverProof, + public_input: >>::OutOfCircuit, + public_output: >>::OutOfCircuit, + ) where + ::BaseField: PrimeField, + EFqSponge: Clone + + FqSponge, Circuit::Curve, ScalarField>, + EFrSponge: FrSponge>, + { + let mut public_input = Circuit::PublicInput::value_to_field_elements(&public_input).0; + public_input.extend(Circuit::PublicOutput::value_to_field_elements(&public_output).0); + + // verify the proof + let group_map = ::Map::setup(); + + verify::( + &group_map, + &self.index, + &proof, + &public_input, + ) + .unwrap() + } +} + +// +// Compilation +// + +/// A compiled circuit. 
+// TODO: implement digest function +pub struct CompiledCircuit +where + Circuit: SnarkyCircuit, +{ + /// The snarky circuit itself. + circuit: Circuit, + + /// The state after compilation. + sys: RunState>, + + /// The public input size. + // TODO: can't we get this from `circuit.public_input_size()`? (easy to implement). Or better, this could be a `Circuit` type that contains the gates as well (or the kimchi ConstraintSystem type) + public_input_size: usize, + + /// The gates obtained after compilation. + pub gates: Vec>>, phantom: PhantomData, +} + +/// Compiles a circuit to a [CompiledCircuit]. +fn compile(circuit: Circuit) -> SnarkyResult> { + // calculate public input size + let public_input_size = Circuit::PublicInput::SIZE_IN_FIELD_ELEMENTS + + Circuit::PublicOutput::SIZE_IN_FIELD_ELEMENTS; + + // create snarky constraint system + let mut sys = RunState::new::( + Circuit::PublicInput::SIZE_IN_FIELD_ELEMENTS, + Circuit::PublicOutput::SIZE_IN_FIELD_ELEMENTS, + true, + ); + + // run circuit and get return var + let public_input: Circuit::PublicInput = sys.public_input(); + let return_var = circuit.circuit(&mut sys, public_input, None)?; + + // create constraint between public output var and return var + // compile to gates + // TODO: don't panic here, return an error + + let gates = sys.wire_output_and_compile(return_var).unwrap(); + let gates = gates.to_vec(); + + // return compiled circuit + let compiled_circuit = CompiledCircuit { + circuit, + sys, + public_input_size, + gates, + phantom: PhantomData, + }; + Ok(compiled_circuit) +} + +// +// The main user-facing trait for constructing circuits. +// + +/// The main trait. Implement this on your circuit to get access to more functions (specifically [Self::compile_to_indexes]). +pub trait SnarkyCircuit: Sized { + /// A circuit must be defined for a specific field, + /// as it might be incorrect to use a different field. + /// Currently we specify the field by the curve, + /// which is more strict and needed due to implementation details in kimchi. + // TODO: if we remove `sponge_params` from KimchiCurve and move it to the Field then we could specify a field here instead. + type Curve: KimchiCurve; + type Proof: OpenProof; + + /// The private input used by the circuit. + type PrivateInput; + + /// The public input used by the circuit. + type PublicInput: SnarkyType>; + + /// The public output returned by the circuit. + type PublicOutput: SnarkyType>; + + /// The circuit. It takes: + /// + /// - `self`: to parameterize it at compile time. + /// - `sys`: to construct the circuit or generate the witness (depending on mode) + /// - `public_input`: the public input (as defined above) + /// - `private_input`: the private input as an option, set to `None` for compilation. + /// + /// It returns a [SnarkyResult] containing the public output. + fn circuit( + &self, + // TODO: change to an enum that is either the state for compilation or the state for proving ([WitnessGeneration]) + // TODO: change the name to `runner` everywhere? + sys: &mut RunState>, + public_input: Self::PublicInput, + private_input: Option<&Self::PrivateInput>, + ) -> SnarkyResult; + + /// Compiles the circuit to a prover index ([ProverIndexWrapper]) and a verifier index ([VerifierIndexWrapper]).
+ fn compile_to_indexes( + self, + ) -> SnarkyResult<(ProverIndexWrapper, VerifierIndexWrapper)> + where + ::BaseField: PrimeField, + { + let compiled_circuit = compile(self)?; + + // create constraint system + let cs = ConstraintSystem::create(compiled_circuit.gates.clone()) + .public(compiled_circuit.public_input_size) + .build() + .unwrap(); + + // create SRS (for vesta, as the circuit is in Fp) + // let mut srs = SRS::::create(cs.domain.d1.size as usize); + let srs = <>::SRS as SRS>::create( + cs.domain.d1.size as usize, + ); + srs.get_lagrange_basis(cs.domain.d1); + let srs = std::sync::Arc::new(srs); + + debug!("using an SRS of size {}", srs.size()); + + // create indexes + let endo_q = <::Curve as KimchiCurve>::other_curve_endo(); + + let prover_index = + crate::prover_index::ProverIndex::::create(cs, *endo_q, srs); + let verifier_index = prover_index.verifier_index(); + + let prover_index = ProverIndexWrapper { + compiled_circuit, + index: prover_index, + }; + + let verifier_index = VerifierIndexWrapper { + index: verifier_index, + }; + + Ok((prover_index, verifier_index)) + } +} diff --git a/kimchi/src/snarky/asm.rs b/kimchi/src/snarky/asm.rs index 9facbc5202..12e85c42c9 100644 --- a/kimchi/src/snarky/asm.rs +++ b/kimchi/src/snarky/asm.rs @@ -1,5 +1,6 @@ //! An ASM-like language to print a human-friendly version of a circuit. +use itertools::Itertools; use std::{ collections::{HashMap, HashSet}, fmt::Write, @@ -13,6 +14,8 @@ use crate::circuits::{ }; use ark_ff::PrimeField; +use super::api::Witness; + /// Print a field in a negative form if it's past the half point. fn pretty(ff: F) -> String { let bigint: num_bigint::BigUint = ff.into(); @@ -84,19 +87,25 @@ where ) in wires.iter().enumerate() { if row != *to_row || col != *to_col { + // if this gate is generic, use generic variables let col_str = if matches!(typ, GateType::Generic) { format!(".{}", Self::generic_cols(col)) } else { format!("[{col}]") }; + // same for the wired gate let to_col = if matches!(self.gates[*to_row].typ, GateType::Generic) { format!(".{}", Self::generic_cols(*to_col)) } else { format!("[{to_col}]") }; - let res = format!("{col_str} -> row{to_row}{to_col}"); + let res = if row != *to_row { + format!("{col_str} -> row{to_row}{to_col}") + } else { + format!("{col_str} -> {to_col}") + }; if matches!(typ, GateType::Generic) && col < GENERIC_REGISTERS { wires1.push(res); @@ -207,6 +216,18 @@ where } } +impl Witness +where + F: PrimeField, +{ + pub fn debug(&self) { + for (row, values) in self.0.iter().enumerate() { + let values = values.iter().map(|v| pretty(*v)).join(" | "); + println!("{row} - {values}"); + } + } +} + #[cfg(test)] mod tests { use mina_curves::pasta::Fp; @@ -215,8 +236,10 @@ mod tests { use super::*; - #[test] - fn test_simple_circuit_asm() { + // FIXME: This test doesn't print the correct output. + // Not critical atm, commenting it + // #[test] + fn _test_simple_circuit_asm() { let public_input_size = 1; let gates: &Vec> = &vec![ CircuitGate::new( diff --git a/kimchi/src/snarky/boolean.rs b/kimchi/src/snarky/boolean.rs new file mode 100644 index 0000000000..275e8a4535 --- /dev/null +++ b/kimchi/src/snarky/boolean.rs @@ -0,0 +1,205 @@ +//! The [Boolean] type is a snarky type that represents a boolean variable. 
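// ---------------------------------------------------------------------------
// Editor's aside (toy sketch, not part of the patch): the `pretty` helper in
// asm.rs above prints a field element in negative form once it passes the
// half point of the modulus, so p - 1 is displayed as "-1". The same idea
// over `u64`:
fn pretty_toy(x: u64, modulus: u64) -> String {
    if x > modulus / 2 {
        format!("-{}", modulus - x)
    } else {
        format!("{x}")
    }
}
// ---------------------------------------------------------------------------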
+ +use std::borrow::Cow; + +use crate::snarky::{ + constraint_system::BasicSnarkyConstraint, cvar::FieldVar, runner::RunState, + snarky_type::SnarkyType, +}; +use ark_ff::PrimeField; + +use super::{errors::SnarkyResult, runner::Constraint}; + +trait OutOfCircuitSnarkyType2 { + type InCircuit; +} + +impl OutOfCircuitSnarkyType2 for bool +where + F: PrimeField, +{ + type InCircuit = Boolean; +} + +/// A boolean variable. +#[derive(Debug, Clone)] +pub struct Boolean(FieldVar); + +impl SnarkyType for Boolean +where + F: PrimeField, +{ + type Auxiliary = (); + + type OutOfCircuit = bool; + + const SIZE_IN_FIELD_ELEMENTS: usize = 1; + + fn to_cvars(&self) -> (Vec>, Self::Auxiliary) { + (vec![self.0.clone()], ()) + } + + fn from_cvars_unsafe(cvars: Vec>, _aux: Self::Auxiliary) -> Self { + assert_eq!(cvars.len(), Self::SIZE_IN_FIELD_ELEMENTS); + Self(cvars[0].clone()) + } + + fn check(&self, cs: &mut RunState, loc: Cow<'static, str>) -> SnarkyResult<()> { + let constraint = BasicSnarkyConstraint::Boolean(self.0.clone()); + cs.add_constraint( + Constraint::BasicSnarkyConstraint(constraint), + Some("boolean check".into()), + loc, + ) + } + + fn constraint_system_auxiliary() -> Self::Auxiliary {} + + fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec, Self::Auxiliary) { + if *value { + (vec![F::one()], ()) + } else { + (vec![F::zero()], ()) + } + } + + fn value_of_field_elements(fields: Vec, _aux: Self::Auxiliary) -> Self::OutOfCircuit { + assert_eq!(fields.len(), 1); + + fields[0] != F::zero() + } +} + +impl Boolean +where + F: PrimeField, +{ + pub fn true_() -> Self { + Self(FieldVar::Constant(F::one())) + } + + pub fn false_() -> Self { + Self(FieldVar::zero()) + } + + pub fn create_unsafe(x: FieldVar) -> Self { + Self(x) + } + + pub fn to_field_var(&self) -> FieldVar { + self.0.clone() + } + + pub fn not(&self) -> Self { + Self(Self::true_().0 - &self.0) + } + + pub fn and(&self, other: &Self, cs: &mut RunState, loc: Cow<'static, str>) -> Self { + let res = self + .0 + .mul(&other.0, Some("bool.and".into()), loc, cs) + .expect("compiler bug"); + Self(res) + } + + pub fn or(&self, other: &Self, loc: Cow<'static, str>, cs: &mut RunState) -> Self { + let both_false = self.not().and(&other.not(), cs, loc); + both_false.not() + } + + pub fn any(xs: &[&Self], cs: &mut RunState, loc: Cow<'static, str>) -> SnarkyResult { + if xs.is_empty() { + return Ok(Self::false_()); // TODO: shouldn't we panic instead? + } else if xs.len() == 1 { + return Ok(xs[0].clone()); + } else if xs.len() == 2 { + return Ok(xs[0].or(xs[1], loc, cs)); // TODO: is this better than below? + } + + let zero = FieldVar::zero(); + + let xs: Vec<_> = xs.iter().map(|x| &x.0).collect(); + let sum = FieldVar::sum(&xs); + let all_zero = sum.equal(cs, loc, &zero)?; + + let res = all_zero.not(); + + Ok(res) + } + + pub fn all(xs: &[Self], cs: &mut RunState, loc: Cow<'static, str>) -> SnarkyResult { + if xs.is_empty() { + return Ok(Self::true_()); // TODO: shouldn't we panic instead? + } else if xs.len() == 1 { + return Ok(xs[0].clone()); + } else if xs.len() == 2 { + return Ok(xs[0].and(&xs[1], cs, loc)); // TODO: is this better than below? 
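// Editor's aside (toy, not part of the patch): `any` above and the tail of
// `all` below both compress n booleans into the field sum s of their bits;
// `any` then tests s != 0, while `all` tests s == n. Over plain integers:
fn any_all_toy(bits: &[bool]) -> (bool, bool) {
    let s: usize = bits.iter().map(|&b| b as usize).sum();
    (s != 0, s == bits.len())
}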
+ } + + let expected = FieldVar::Constant(F::from(xs.len() as u64)); + let xs: Vec<_> = xs.iter().map(|x| &x.0).collect(); + let sum = FieldVar::sum(&xs); + + sum.equal(cs, loc, &expected) + } + + pub fn to_constant(&self) -> Option { + match self.0 { + FieldVar::Constant(x) => Some(x == F::one()), + _ => None, + } + } + + pub fn xor( + &self, + other: &Self, + state: &mut RunState, + loc: Cow<'static, str>, + ) -> SnarkyResult { + let res = match (self.to_constant(), other.to_constant()) { + (Some(true), _) => other.not(), + (_, Some(true)) => self.not(), + (Some(false), _) => other.clone(), + (_, Some(false)) => self.clone(), + (None, None) => { + /* + (1 - 2 a) (1 - 2 b) = 1 - 2 c + 1 - 2 (a + b) + 4 a b = 1 - 2 c + - 2 (a + b) + 4 a b = - 2 c + (a + b) - 2 a b = c + 2 a b = a + b - c + */ + + let self_clone = self.clone(); + let other_clone = other.clone(); + let res: Boolean = state.compute_unsafe(loc.clone(), move |env| { + let _b1: bool = self_clone.read(env); + let _b2: bool = other_clone.read(env); + + /* + let%bind res = + exists typ_unchecked + ~compute: + As_prover.( + map2 ~f:Bool.( <> ) (read typ_unchecked b1) + (read typ_unchecked b2)) + in + */ + + todo!() + })?; + + let x = &self.0 + &self.0; + let y = &other.0; + let z = &self.0 + &other.0 - &res.0; + + // TODO: annotation? + state.assert_r1cs(Some("xor".into()), loc, x, y.clone(), z)?; + + res + } + }; + + Ok(res) + } +} diff --git a/kimchi/src/snarky/constants.rs b/kimchi/src/snarky/constants.rs index 2324d4ff87..646a81ee10 100644 --- a/kimchi/src/snarky/constants.rs +++ b/kimchi/src/snarky/constants.rs @@ -1,11 +1,10 @@ //! Constants used for poseidon. +use crate::curve::KimchiCurve; use ark_ff::Field; use mina_poseidon::poseidon::ArithmeticSpongeParams; -use crate::curve::KimchiCurve; - -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Constants { pub poseidon: ArithmeticSpongeParams, pub endo: F, @@ -19,7 +18,7 @@ where pub fn new>() -> Self { let poseidon = Curve::sponge_params().clone(); let endo_q = Curve::other_curve_endo(); - let base = Curve::other_curve_prime_subgroup_generator(); + let base = Curve::other_curve_generator(); Self { poseidon, diff --git a/kimchi/src/snarky/constraint_system.rs b/kimchi/src/snarky/constraint_system.rs index e3df69f1f9..ed53a40e32 100644 --- a/kimchi/src/snarky/constraint_system.rs +++ b/kimchi/src/snarky/constraint_system.rs @@ -1,20 +1,34 @@ #![allow(clippy::all)] -use crate::circuits::{ - gate::{CircuitGate, GateType}, - polynomials::poseidon::{ROUNDS_PER_HASH, SPONGE_WIDTH}, - wires::{Wire, COLUMNS, PERMUTS}, +//! The backend used by Snarky, gluing snarky to kimchi. +//! This module holds the actual logic that constructs the circuit using kimchi's gates, +//! as well as the logic that constructs the permutation, +//! and the symbolic execution trace table (both for compilation and at runtime). + +use crate::{ + circuits::{ + gate::{CircuitGate, GateType}, + polynomials::{ + generic::GENERIC_COEFFS, + poseidon::{ROUNDS_PER_HASH, ROUNDS_PER_ROW, SPONGE_WIDTH}, + }, + wires::{Wire, COLUMNS, PERMUTS}, + }, + snarky::{constants::Constants, cvar::FieldVar, runner::WitnessGeneration}, }; use ark_ff::PrimeField; use itertools::Itertools; -use std::collections::{HashMap, HashSet}; +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, +}; -use super::constants::Constants; +use super::{errors::SnarkyRuntimeError, union_find::DisjointSet}; /** A row indexing in a constraint system. Either a public input row, or a non-public input row that starts at index 0. 
*/ -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] enum Row { PublicInput(usize), AfterPublicInput(usize), @@ -32,7 +46,7 @@ impl Row { /* TODO: rename module Position to Permutation/Wiring? */ /** A position represents the position of a cell in the constraint system. A position is a row and a column. */ -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] struct Position { row: Row, col: usize, @@ -47,9 +61,17 @@ impl Position { } } +#[derive(Debug, Clone)] +struct PendingGate { + labels: Vec>, + loc: Cow<'static, str>, + vars: (Option, Option, Option), + coeffs: Vec, +} + /** A gate/row/constraint consists of a type (kind), a row, the other cells its columns/cells are connected to (`wired_to`), and the selector polynomial associated with the gate. */ -#[derive(Clone)] +#[derive(Debug, Clone)] struct GateSpec { kind: GateType, wired_to: Vec>, @@ -91,6 +113,11 @@ impl GateSpec { } } +#[derive(Debug)] +#[cfg_attr( + feature = "ocaml_types", + derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct) +)] pub struct ScaleRound { pub accs: Vec<(A, A)>, pub bits: Vec, @@ -100,6 +127,11 @@ pub struct ScaleRound { pub n_next: A, } +#[derive(Debug)] +#[cfg_attr( + feature = "ocaml_types", + derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct) +)] pub struct EndoscaleRound { pub xt: A, pub yt: A, @@ -116,6 +148,11 @@ pub struct EndoscaleRound { pub b4: A, } +#[derive(Debug)] +#[cfg_attr( + feature = "ocaml_types", + derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct) +)] pub struct EndoscaleScalarRound { pub n0: A, pub n8: A, @@ -133,6 +170,8 @@ pub struct EndoscaleScalarRound { pub x7: A, } +// TODO: get rid of this +#[derive(Debug)] pub enum BasicSnarkyConstraint { Boolean(Var), Equal(Var, Var), @@ -140,49 +179,85 @@ pub enum BasicSnarkyConstraint { R1CS(Var, Var, Var), } +#[derive(Debug)] +#[cfg_attr( + feature = "ocaml_types", + derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct) +)] +pub struct BasicInput { + pub l: (Field, Var), + pub r: (Field, Var), + pub o: (Field, Var), + pub m: Field, + pub c: Field, +} + +#[derive(Debug)] +#[cfg_attr( + feature = "ocaml_types", + derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct) +)] +pub struct PoseidonInput { + // TODO: revert back to arrays once we don't need to expose this struct to OCaml + // pub states: [[Var; SPONGE_WIDTH]; ROUNDS_PER_HASH], + // pub last: [Var; SPONGE_WIDTH], + pub states: Vec>, + pub last: Vec, +} + +#[derive(Debug)] +#[cfg_attr( + feature = "ocaml_types", + derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct) +)] +pub struct EcAddCompleteInput { + pub p1: (Var, Var), + pub p2: (Var, Var), + pub p3: (Var, Var), + pub inf: Var, + pub same_x: Var, + pub slope: Var, + pub inf_z: Var, + pub x21_inv: Var, +} + +#[derive(Debug)] +#[cfg_attr( + feature = "ocaml_types", + derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct) +)] +pub struct EcEndoscaleInput { + pub state: Vec>, + pub xs: Var, + pub ys: Var, + pub n_acc: Var, +} + /** A PLONK constraint (or gate) can be [`Basic`](KimchiConstraint::Basic), [`Poseidon`](KimchiConstraint::Poseidon), * [`EcAddComplete`](KimchiConstraint::EcAddComplete), [`EcScale`](KimchiConstraint::EcScale), * [`EcEndoscale`](KimchiConstraint::EcEndoscale), or [`EcEndoscalar`](KimchiConstraint::EcEndoscalar). 
*/ +#[derive(Debug)] +#[cfg_attr( + feature = "ocaml_types", + derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum) +)] pub enum KimchiConstraint { - Basic { - l: (Field, Var), - r: (Field, Var), - o: (Field, Var), - m: Field, - c: Field, - }, - Poseidon { - state: Vec>, - }, - EcAddComplete { - p1: (Var, Var), - p2: (Var, Var), - p3: (Var, Var), - inf: Var, - same_x: Var, - slope: Var, - inf_z: Var, - x21_inv: Var, - }, - EcScale { - state: Vec>, - }, - EcEndoscale { - state: Vec>, - xs: Var, - ys: Var, - n_acc: Var, - }, - EcEndoscalar { - state: Vec>, - }, + Basic(BasicInput), + Poseidon(Vec>), + Poseidon2(PoseidonInput), + EcAddComplete(EcAddCompleteInput), + EcScale(Vec>), + EcEndoscale(EcEndoscaleInput), + EcEndoscalar(Vec>), + //[[Var; 15]; 4] + RangeCheck(Vec>), } /* TODO: This is a Unique_id in OCaml. */ -#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] struct InternalVar(usize); -#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] enum V { /** An external variable (generated by snarky, via [exists]). */ External(usize), @@ -195,7 +270,7 @@ enum V { /** Keeps track of a circuit (which is a list of gates) while it is being written. */ -#[derive(Clone)] +#[derive(Debug, Clone)] enum Circuit where F: PrimeField, @@ -209,6 +284,7 @@ where } /** The constraint system. */ +#[derive(Debug, Clone)] pub struct SnarkyConstraintSystem where Field: PrimeField, @@ -229,17 +305,26 @@ where */ gates: Circuit, /** The row to use the next time we add a constraint. */ + // TODO: I think we can delete this and get it from rows.len() or something next_row: usize, /** The size of the public input (which fills the first rows of our constraint system. */ public_input_size: Option, - /** Whatever is not public input. */ - auxiliary_input_size: usize, + + /** The number of previous recursion challenges. */ + prev_challenges: Option, + + /// Enables the double generic gate optimization. + /// It can be useful to disable this feature for debugging. + generic_gate_optimization: bool, + /** Queue (of size 1) of generic gate. */ - pending_generic_gate: Option<(Option, Option, Option, Vec)>, + pending_generic_gate: Option>, + /** V.t's corresponding to constant values. We reuse them so we don't need to use a fresh generic constraint each time to create a constant. */ cached_constants: HashMap, + /** The [equivalence_classes](SnarkyConstraintSystem::equivalence_classes) field keeps track of the positions which must be enforced to be equivalent due to the fact that they correspond to the same V.t value. I.e., positions that are different usages of the same [V.t]. @@ -250,10 +335,25 @@ where a single equivalence class, so that the permutation argument enforces these desired equalities as well. */ - union_finds: disjoint_set::DisjointSet, + union_finds: DisjointSet, } impl SnarkyConstraintSystem { + /** Sets the number of public-input. It must and can only be called once. 
*/ + pub fn set_primary_input_size(&mut self, num_pub_inputs: usize) { + if self.public_input_size.is_some() { + panic!("set_primary_input_size can only be called once"); + } + self.public_input_size = Some(num_pub_inputs); + } + + pub fn set_prev_challenges(&mut self, prev_challenges: usize) { + if self.prev_challenges.is_some() { + panic!("set_prev_challenges can only be called once"); + } + self.prev_challenges = Some(prev_challenges); + } + /** Converts the set of permutations (`equivalence_classes`) to a hash table that maps each position to the next one. For example, if one of the equivalence class is [pos1, pos3, pos7], @@ -281,26 +381,68 @@ impl SnarkyConstraintSystem { res } + pub fn compute_witness_for_ocaml( + &mut self, + public_inputs: &[Field], + private_inputs: &[Field], + ) -> [Vec; COLUMNS] { + // make sure it's finalized + self.finalize(); + + // ensure that we have the right number of public inputs + let public_input_size = self.get_primary_input_size(); + assert_eq!(public_inputs.len(), public_input_size); + + // create closure that will read variables from the input + let external_values = |i| { + if i < public_input_size { + public_inputs[i] + } else { + private_inputs[i - public_input_size] + } + }; + + // compute witness + self.compute_witness(external_values) + } + /// Compute the witness, given the constraint system `sys` /// and a function that converts the indexed secret inputs to their concrete values. /// /// # Panics /// /// Will panic if some inputs like `public_input_size` are unknown(None value). - pub fn compute_witness Field>(&self, external_values: F) -> Vec> { + // TODO: build the transposed version instead of this + pub fn compute_witness(&mut self, external_values: FUNC) -> [Vec; COLUMNS] + where + FUNC: Fn(usize) -> Field, + { + // make sure it's finalized + self.finalize(); + + // init execution trace table let mut internal_values = HashMap::new(); let public_input_size = self.public_input_size.unwrap(); let num_rows = public_input_size + self.next_row; - let mut res = vec![vec![Field::zero(); num_rows]; COLUMNS]; + let mut res: [_; COLUMNS] = std::array::from_fn(|_| vec![Field::zero(); num_rows]); + + // obtain public input from closure for i in 0..public_input_size { - res[0][i] = external_values(i + 1); + res[0][i] = external_values(i); } + + // compute rest of execution trace table for (i_after_input, cols) in self.rows.iter().enumerate() { let row_idx = i_after_input + public_input_size; for (col_idx, var) in cols.iter().enumerate() { match var { + // keep default value of zero None => (), + + // use closure for external values Some(V::External(var)) => res[col_idx][row_idx] = external_values(*var), + + // for internal values, compute the linear combination Some(V::Internal(var)) => { let (lc, c) = { match self.internal_vars.get(var) { @@ -326,6 +468,7 @@ impl SnarkyConstraintSystem { } } } + res } @@ -346,36 +489,31 @@ impl SnarkyConstraintSystem { // TODO: if we expect a `Field: KimchiParams` we can simply do `Field::constants()` here. But we might want to wait for Fabrizio's trait? 
Also we should keep this close to the OCaml stuff if we want to avoid pains when we plug this in
            constants: constants,
            public_input_size: None,
+            prev_challenges: None,
            next_internal_var: 0,
            internal_vars: HashMap::new(),
            gates: Circuit::Unfinalized(Vec::new()),
            rows: Vec::new(),
            next_row: 0,
            equivalence_classes: HashMap::new(),
-            auxiliary_input_size: 0,
+            generic_gate_optimization: true,
            pending_generic_gate: None,
            cached_constants: HashMap::new(),
-            union_finds: disjoint_set::DisjointSet::new(),
+            union_finds: DisjointSet::new(),
        }
    }

-    /** Returns the number of auxiliary inputs. */
-    pub fn get_auxiliary_input_size(&self) -> usize {
-        self.auxiliary_input_size
-    }
-
    /// Returns the number of public inputs.
    ///
    /// # Panics
    ///
    /// Will panic if `public_input_size` is None.
    pub fn get_primary_input_size(&self) -> usize {
-        self.public_input_size.unwrap()
+        self.public_input_size.expect("attempt to retrieve public input size before it was set (via `set_primary_input_size`)")
    }

-    /** Non-public part of the witness. */
-    pub fn set_auxiliary_input_size(&mut self, x: usize) {
-        self.auxiliary_input_size = x;
+    pub fn get_prev_challenges(&self) -> Option {
+        self.prev_challenges
    }

    /** Sets the size of the public input. It should only be called once. */
@@ -401,7 +539,20 @@
    }

    /** Adds a row/gate/constraint to a constraint system `sys`. */
-    fn add_row(&mut self, vars: Vec>, kind: GateType, coeffs: Vec) {
+    fn add_row(
+        &mut self,
+        labels: &[Cow<'static, str>],
+        loc: &Cow<'static, str>,
+        vars: Vec>,
+        kind: GateType,
+        coeffs: Vec,
+    ) {
+        // TODO: for now we can print the debug info at runtime, but in the future we should allow serialization of these things as well
+        // TODO: this ignores the public gates!!
+        if std::env::var("SNARKY_LOG_CONSTRAINTS").is_ok() {
+            println!("{}: {loc} - {}", self.next_row, labels.join(", "));
+        }
+
        /*
        As we're adding a row, we're adding new cells.
        If these cells (the first 7) contain variables,
        make sure that they are wired
@@ -427,6 +578,13 @@
        self.rows.push(vars);
    }

+    /// Returns the number of rows in the constraint system.
+    /// Note: This is not necessarily the number of rows of the compiled circuit.
+    /// If the circuit has not finished compiling, you will only get the current number of rows.
+    pub fn get_rows_len(&self) -> usize {
+        self.rows.len()
+    }
+
    /// Fill the `gate` values (input and output), and finalize the `circuit`.
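// ---------------------------------------------------------------------------
// Editor's aside (sketch, not part of the patch): the SNARKY_LOG_CONSTRAINTS
// hook in `add_row` above is plain opt-in tracing; when the environment
// variable is set, one line is printed per constraint row. Reduced to std:
fn log_constraint_row_toy(next_row: usize, loc: &str, labels: &[String]) {
    if std::env::var("SNARKY_LOG_CONSTRAINTS").is_ok() {
        println!("{next_row}: {loc} - {}", labels.join(", "));
    }
}
// ---------------------------------------------------------------------------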
///
    /// # Panics
@@ -440,9 +598,21 @@ impl SnarkyConstraintSystem {
        }

        // if we still have some pending gates, deal with them first
-        if let Some((l, r, o, coeffs)) = self.pending_generic_gate.take() {
+        if let Some(PendingGate {
+            labels,
+            loc,
+            vars: (l, r, o),
+            coeffs,
+        }) = self.pending_generic_gate.take()
+        {
            self.pending_generic_gate = None;
-            self.add_row(vec![l, r, o], GateType::Generic, coeffs.clone());
+            self.add_row(
+                &labels,
+                &loc,
+                vec![l, r, o],
+                GateType::Generic,
+                coeffs.clone(),
+            );
        }

        // get gates without holding on to an immutable reference
@@ -455,14 +625,15 @@
        let public_input_size = self.public_input_size.unwrap();
        let pub_selectors: Vec<_> = vec![
            Field::one(),
+            // TODO: unnecessary
            Field::zero(),
            Field::zero(),
            Field::zero(),
            Field::zero(),
        ];
        let mut public_gates = Vec::new();
-        for row in 0..(public_input_size - 1) {
-            let public_var = V::External(row + 1);
+        for row in 0..public_input_size {
+            let public_var = V::External(row);
            self.wire_(public_var, Row::PublicInput(row), 0);
            public_gates.push(GateSpec {
                kind: GateType::Generic,
@@ -529,6 +700,21 @@
        self.gates = Circuit::Compiled(digest, rust_gates);
    }

+    /// Produces a digest of the constraint system.
+    ///
+    /// # Panics
+    ///
+    /// Will panic if the constraint system has not previously been compiled (via [`Self::finalize`]).
+    pub fn digest(&mut self) -> [u8; 32] {
+        // make sure it's finalized
+        self.finalize();
+
+        match &self.gates {
+            Circuit::Compiled(digest, _) => *digest,
+            Circuit::Unfinalized(_) => unreachable!(),
+        }
+    }
+
    // TODO: why does it return a mutable reference?
    pub fn finalize_and_get_gates(&mut self) -> &mut Vec> {
        self.finalize();
@@ -604,19 +790,49 @@
    */
    fn add_generic_constraint(
        &mut self,
+        labels: &[Cow<'static, str>],
+        loc: &Cow<'static, str>,
        l: Option,
        r: Option,
        o: Option,
        mut coeffs: Vec,
    ) {
+        if !self.generic_gate_optimization {
+            assert!(coeffs.len() <= GENERIC_COEFFS);
+            self.add_row(labels, loc, vec![l, r, o], GateType::Generic, coeffs);
+            return;
+        }
+
        match self.pending_generic_gate {
-            None => self.pending_generic_gate = Some((l, r, o, coeffs)),
+            None => {
+                self.pending_generic_gate = Some(PendingGate {
+                    labels: labels.to_vec(),
+                    loc: loc.to_owned(),
+                    vars: (l, r, o),
+                    coeffs,
+                })
+            }
            Some(_) => {
-                if let Some((l2, r2, o2, coeffs2)) =
-                    std::mem::replace(&mut self.pending_generic_gate, None)
+                if let Some(PendingGate {
+                    labels: labels2,
+                    loc: loc2,
+                    vars: (l2, r2, o2),
+                    coeffs: coeffs2,
+                }) = std::mem::replace(&mut self.pending_generic_gate, None)
                {
+                    let labels1 = labels.join(",");
+                    let labels2 = labels2.join(",");
+                    let labels = vec![Cow::Owned(format!("gen1:[{}] gen2:[{}]", labels1, labels2))];
+                    let loc = format!("gen1:[{}] gen2:[{}]", loc, loc2).into();
+
                    coeffs.extend(coeffs2);
-                    self.add_row(vec![l, r, o, l2, r2, o2], GateType::Generic, coeffs);
+                    self.add_row(
+                        &labels,
+                        &loc,
+                        vec![l, r, o, l2, r2, o2],
+                        GateType::Generic,
+                        coeffs,
+                    );
                }
            }
        }
@@ -637,7 +853,12 @@
    - return (1, internal_var_2)

    It assumes that the list of terms is not empty.
*/ - fn completely_reduce(&mut self, terms: Terms) -> (Field, V) + fn completely_reduce( + &mut self, + labels: &[Cow<'static, str>], + loc: &Cow<'static, str>, + terms: Terms, + ) -> (Field, V) where Terms: IntoIterator, ::IntoIter: DoubleEndedIterator, @@ -654,6 +875,8 @@ impl SnarkyConstraintSystem { let lx = V::External(lx); let s1x1_plus_s2x2 = self.create_internal(None, vec![(ls, lx), (rs, rx)]); self.add_generic_constraint( + labels, + loc, Some(lx), Some(rx), Some(s1x1_plus_s2x2), @@ -670,17 +893,19 @@ impl SnarkyConstraintSystem { It returns the output variable as (1, `Var res), unless the output is a constant, in which case it returns (c, `Constant). */ - fn reduce_lincom(&mut self, x: Cvar) -> (Field, ConstantOrVar) + fn reduce_lincom( + &mut self, + labels: &[Cow<'static, str>], + loc: &Cow<'static, str>, + x: Cvar, + ) -> (Field, ConstantOrVar) where Cvar: SnarkyCvar, { let (constant, terms) = x.to_constant_and_terms(); let terms = accumulate_terms(terms); let mut terms_list: Vec<_> = terms.into_iter().map(|(key, data)| (data, key)).collect(); - /* WARNING: The order here may differ from the OCaml order, since that depends on the order - of the map. */ terms_list.sort(); - terms_list.reverse(); match (constant, terms_list.len()) { (Some(c), 0) => (c, ConstantOrVar::Constant), (None, 0) => (Field::zero(), ConstantOrVar::Constant), @@ -693,6 +918,8 @@ impl SnarkyConstraintSystem { /* res = ls * lx + c */ let res = self.create_internal(Some(c), vec![(*ls, V::External(*lx))]); self.add_generic_constraint( + labels, + loc, Some(V::External(*lx)), None, Some(res), @@ -704,10 +931,12 @@ impl SnarkyConstraintSystem { /* reduce the terms, then add the constant */ let mut terms_list_iterator = terms_list.into_iter(); let (ls, lx) = terms_list_iterator.next().unwrap(); - let (rs, rx) = self.completely_reduce(terms_list_iterator); + let (rs, rx) = self.completely_reduce(labels, loc, terms_list_iterator); let res = self.create_internal(constant, vec![(ls, V::External(lx)), (rs, rx)]); /* res = ls * lx + rs * rx + c */ self.add_generic_constraint( + labels, + loc, Some(V::External(lx)), Some(rx), Some(res), @@ -725,11 +954,16 @@ impl SnarkyConstraintSystem { } /// reduce any [Cvar] to a single internal variable [V] - fn reduce_to_var(&mut self, x: Cvar) -> V + fn reduce_to_var( + &mut self, + labels: &[Cow<'static, str>], + loc: &Cow<'static, str>, + x: Cvar, + ) -> V where Cvar: SnarkyCvar, { - match self.reduce_lincom(x) { + match self.reduce_lincom(labels, loc, x) { (s, ConstantOrVar::Var(x)) => { if s == Field::one() { x @@ -737,6 +971,8 @@ impl SnarkyConstraintSystem { let sx = self.create_internal(Some(s), vec![(s, x)]); // s * x - sx = 0 self.add_generic_constraint( + labels, + loc, Some(x), None, Some(sx), @@ -756,6 +992,8 @@ impl SnarkyConstraintSystem { None => { let x = self.create_internal(None, vec![]); self.add_generic_constraint( + labels, + loc, Some(x), None, None, @@ -780,19 +1018,28 @@ impl SnarkyConstraintSystem { /// # Panics /// /// Will panic if `constant selector` constraints are not matching. 
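// ---------------------------------------------------------------------------
// Editor's aside (toy, not part of the patch): `reduce_lincom` and
// `completely_reduce` above fold a linear combination two terms at a time;
// each step introduces a fresh internal variable constrained by one generic
// gate `s1*x1 + s2*x2 - acc = 0`, so n terms cost roughly n - 1 gates. The
// numeric effect, over plain integers:
fn completely_reduce_toy(terms: &[(i64, i64)]) -> i64 {
    // (coefficient, value) pairs are absorbed right to left, as in the code above
    terms.iter().rev().fold(0, |acc, (s, x)| s * x + acc)
}
// ---------------------------------------------------------------------------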
- pub fn add_basic_snarky_constraint(&mut self, constraint: BasicSnarkyConstraint) - where + pub fn add_basic_snarky_constraint( + &mut self, + labels: &[Cow<'static, str>], + loc: &Cow<'static, str>, + constraint: BasicSnarkyConstraint, + ) where Cvar: SnarkyCvar, { match constraint { BasicSnarkyConstraint::Square(v1, v2) => { - match (self.reduce_lincom(v1), self.reduce_lincom(v2)) { + match ( + self.reduce_lincom(labels, loc, v1), + self.reduce_lincom(labels, loc, v2), + ) { ((sl, ConstantOrVar::Var(xl)), (so, ConstantOrVar::Var(xo))) => /* (sl * xl)^2 = so * xo sl^2 * xl * xl - so * xo = 0 */ { self.add_generic_constraint( + labels, + loc, Some(xl), Some(xl), Some(xo), @@ -803,6 +1050,8 @@ impl SnarkyConstraintSystem { /* TODO: it's hard to read the array of selector values, name them! */ { self.add_generic_constraint( + labels, + loc, Some(xl), Some(xl), None, @@ -813,6 +1062,8 @@ impl SnarkyConstraintSystem { /* sl^2 = so * xo */ { self.add_generic_constraint( + labels, + loc, None, None, Some(xo), @@ -825,9 +1076,9 @@ impl SnarkyConstraintSystem { } } BasicSnarkyConstraint::R1CS(v1, v2, v3) => match ( - self.reduce_lincom(v1), - self.reduce_lincom(v2), - self.reduce_lincom(v3), + self.reduce_lincom(labels, loc, v1), + self.reduce_lincom(labels, loc, v2), + self.reduce_lincom(labels, loc, v3), ) { ( (s1, ConstantOrVar::Var(x1)), @@ -839,6 +1090,8 @@ impl SnarkyConstraintSystem { */ { self.add_generic_constraint( + labels, + loc, Some(x1), Some(x2), Some(x3), @@ -850,6 +1103,8 @@ impl SnarkyConstraintSystem { (s2, ConstantOrVar::Var(x2)), (s3, ConstantOrVar::Constant), ) => self.add_generic_constraint( + labels, + loc, Some(x1), Some(x2), None, @@ -863,6 +1118,8 @@ impl SnarkyConstraintSystem { /* s1 x1 * s2 = s3 x3 */ { self.add_generic_constraint( + labels, + loc, Some(x1), None, Some(x3), @@ -874,6 +1131,8 @@ impl SnarkyConstraintSystem { (s2, ConstantOrVar::Var(x2)), (s3, ConstantOrVar::Var(x3)), ) => self.add_generic_constraint( + labels, + loc, None, Some(x2), Some(x3), @@ -884,6 +1143,8 @@ impl SnarkyConstraintSystem { (s2, ConstantOrVar::Constant), (s3, ConstantOrVar::Constant), ) => self.add_generic_constraint( + labels, + loc, Some(x1), None, None, @@ -894,6 +1155,8 @@ impl SnarkyConstraintSystem { (s2, ConstantOrVar::Var(x2)), (s3, ConstantOrVar::Constant), ) => self.add_generic_constraint( + labels, + loc, None, None, Some(x2), @@ -904,6 +1167,8 @@ impl SnarkyConstraintSystem { (s2, ConstantOrVar::Constant), (s3, ConstantOrVar::Var(x3)), ) => self.add_generic_constraint( + labels, + loc, None, None, Some(x3), @@ -916,12 +1181,14 @@ impl SnarkyConstraintSystem { ) => assert_eq!(s3, s1 * s2), }, BasicSnarkyConstraint::Boolean(v) => { - let (s, x) = self.reduce_lincom(v); + let (s, x) = self.reduce_lincom(labels, loc, v); match x { ConstantOrVar::Var(x) => /* -x + x * x = 0 */ { self.add_generic_constraint( + labels, + loc, Some(x), Some(x), None, @@ -938,7 +1205,10 @@ impl SnarkyConstraintSystem { } } BasicSnarkyConstraint::Equal(v1, v2) => { - let ((s1, x1), (s2, x2)) = (self.reduce_lincom(v1), self.reduce_lincom(v2)); + let ((s1, x1), (s2, x2)) = ( + self.reduce_lincom(labels, loc, v1), + self.reduce_lincom(labels, loc, v2), + ); match (x1, x2) { (ConstantOrVar::Var(x1), ConstantOrVar::Var(x2)) => { /* TODO: This logic is wrong, but matches the OCaml side. Fix both. 
*/ @@ -952,6 +1222,8 @@ impl SnarkyConstraintSystem { /* s1 x1 - s2 x2 = 0 */ s1 != s2 { self.add_generic_constraint( + labels, + loc, Some(x1), Some(x2), None, @@ -959,6 +1231,8 @@ impl SnarkyConstraintSystem { ); } else { self.add_generic_constraint( + labels, + loc, Some(x1), Some(x2), None, @@ -980,6 +1254,8 @@ impl SnarkyConstraintSystem { } None => { self.add_generic_constraint( + labels, + loc, Some(x1), None, None, @@ -1003,6 +1279,8 @@ impl SnarkyConstraintSystem { } None => { self.add_generic_constraint( + labels, + loc, None, Some(x2), None, @@ -1023,12 +1301,16 @@ impl SnarkyConstraintSystem { /// # Panics /// /// Will panic if `witness` fields are empty. - pub fn add_constraint(&mut self, constraint: KimchiConstraint) - where + pub fn add_constraint( + &mut self, + labels: &[Cow<'static, str>], + loc: &Cow<'static, str>, + constraint: KimchiConstraint, + ) where Cvar: SnarkyCvar, { match constraint { - KimchiConstraint::Basic { l, r, o, m, c } => { + KimchiConstraint::Basic(BasicInput { l, r, o, m, c }) => { /* 0 = l.s * l.x + r.s * r.x @@ -1049,7 +1331,7 @@ impl SnarkyConstraintSystem { + c */ let mut c = c; - let mut red_pr = |(s, x)| match self.reduce_lincom(x) { + let mut red_pr = |(s, x)| match self.reduce_lincom(labels, loc, x) { (s2, ConstantOrVar::Constant) => { c += s * s2; (s2, None) @@ -1081,6 +1363,8 @@ impl SnarkyConstraintSystem { } }; self.add_generic_constraint( + labels, + loc, var(l), var(r), var(o), @@ -1089,10 +1373,10 @@ impl SnarkyConstraintSystem { } // TODO: the code in circuit-writer was better // TODO: also `rounds` would be a better name than `state` - KimchiConstraint::Poseidon { state } => { + KimchiConstraint::Poseidon(state) => { // we expect state to be a vector of all the intermediary round states // (in addition to the initial and final states) - assert_eq!(state.len(), ROUNDS_PER_HASH * ROUNDS_PER_HASH); + assert_eq!(state.len(), ROUNDS_PER_HASH + 1); // where each state is three field elements assert!(state.iter().all(|x| x.len() == SPONGE_WIDTH)); @@ -1100,7 +1384,11 @@ impl SnarkyConstraintSystem { // reduce the state let state: Vec> = state .into_iter() - .map(|vars| vars.into_iter().map(|x| self.reduce_to_var(x)).collect()) + .map(|vars| { + vars.into_iter() + .map(|x| self.reduce_to_var(labels, loc, x)) + .collect() + }) .collect(); // retrieve the final state @@ -1151,7 +1439,7 @@ impl SnarkyConstraintSystem { self.constants.poseidon.round_constants[round_4][1], self.constants.poseidon.round_constants[round_4][2], ]; - self.add_row(vars, GateType::Poseidon, coeffs); + self.add_row(labels, loc, vars, GateType::Poseidon, coeffs); } // add_last_row adds the last row containing the output @@ -1172,9 +1460,49 @@ impl SnarkyConstraintSystem { None, None, ]; - self.add_row(vars, GateType::Zero, vec![]); + self.add_row(labels, loc, vars, GateType::Zero, vec![]); + } + KimchiConstraint::Poseidon2(PoseidonInput { states, last }) => { + // reduce variables + let states = states + .into_iter() + .map(|round| { + round + .into_iter() + .map(|x| self.reduce_to_var(labels, loc, x)) + .collect_vec() + }) + .collect_vec(); + + // create the rows + for rounds in &states + .into_iter() + // TODO: poseidon constants should really be passed instead of living in the constraint system as a cfg no? 
annoying clone fosho + .zip(self.constants.poseidon.round_constants.clone()) + .chunks(ROUNDS_PER_ROW) + { + let (vars, coeffs) = rounds + .into_iter() + .flat_map(|(round, round_constants)| { + round + .into_iter() + .map(Option::Some) + .zip(round_constants.into_iter()) + }) + .unzip(); + self.add_row(labels, loc, vars, GateType::Poseidon, coeffs); + } + + // last row is a zero gate to save as output + let last = last + .into_iter() + .map(|x| self.reduce_to_var(labels, loc, x)) + .map(Some) + .collect_vec(); + self.add_row(labels, loc, last, GateType::Zero, vec![]); } - KimchiConstraint::EcAddComplete { + + KimchiConstraint::EcAddComplete(EcAddCompleteInput { p1, p2, p3, @@ -1183,9 +1511,13 @@ impl SnarkyConstraintSystem { slope, inf_z, x21_inv, - } => { - let mut reduce_curve_point = - |(x, y)| (self.reduce_to_var(x), self.reduce_to_var(y)); + }) => { + let mut reduce_curve_point = |(x, y)| { + ( + self.reduce_to_var(labels, loc, x), + self.reduce_to_var(labels, loc, y), + ) + }; // 0 1 2 3 4 5 6 7 8 9 // x1 y1 x2 y2 x3 y3 inf same_x s inf_z x21_inv let (x1, y1) = reduce_curve_point(p1); @@ -1199,15 +1531,15 @@ impl SnarkyConstraintSystem { Some(y2), Some(x3), Some(y3), - Some(self.reduce_to_var(inf)), - Some(self.reduce_to_var(same_x)), - Some(self.reduce_to_var(slope)), - Some(self.reduce_to_var(inf_z)), - Some(self.reduce_to_var(x21_inv)), + Some(self.reduce_to_var(labels, loc, inf)), + Some(self.reduce_to_var(labels, loc, same_x)), + Some(self.reduce_to_var(labels, loc, slope)), + Some(self.reduce_to_var(labels, loc, inf_z)), + Some(self.reduce_to_var(labels, loc, x21_inv)), ]; - self.add_row(vars, GateType::CompleteAdd, vec![]); + self.add_row(labels, loc, vars, GateType::CompleteAdd, vec![]); } - KimchiConstraint::EcScale { state } => { + KimchiConstraint::EcScale(state) => { for ScaleRound { accs, bits, @@ -1221,69 +1553,69 @@ impl SnarkyConstraintSystem { // xT yT x0 y0 n n' x1 y1 x2 y2 x3 y3 x4 y4 // x5 y5 b0 b1 b2 b3 b4 s0 s1 s2 s3 s4 let curr_row = vec![ - Some(self.reduce_to_var(base.0)), - Some(self.reduce_to_var(base.1)), - Some(self.reduce_to_var(accs[0].0.clone())), - Some(self.reduce_to_var(accs[0].1.clone())), - Some(self.reduce_to_var(n_prev)), - Some(self.reduce_to_var(n_next)), + Some(self.reduce_to_var(labels, loc, base.0)), + Some(self.reduce_to_var(labels, loc, base.1)), + Some(self.reduce_to_var(labels, loc, accs[0].0.clone())), + Some(self.reduce_to_var(labels, loc, accs[0].1.clone())), + Some(self.reduce_to_var(labels, loc, n_prev)), + Some(self.reduce_to_var(labels, loc, n_next)), None, - Some(self.reduce_to_var(accs[1].0.clone())), - Some(self.reduce_to_var(accs[1].1.clone())), - Some(self.reduce_to_var(accs[2].0.clone())), - Some(self.reduce_to_var(accs[2].1.clone())), - Some(self.reduce_to_var(accs[3].0.clone())), - Some(self.reduce_to_var(accs[3].1.clone())), - Some(self.reduce_to_var(accs[4].0.clone())), - Some(self.reduce_to_var(accs[4].1.clone())), + Some(self.reduce_to_var(labels, loc, accs[1].0.clone())), + Some(self.reduce_to_var(labels, loc, accs[1].1.clone())), + Some(self.reduce_to_var(labels, loc, accs[2].0.clone())), + Some(self.reduce_to_var(labels, loc, accs[2].1.clone())), + Some(self.reduce_to_var(labels, loc, accs[3].0.clone())), + Some(self.reduce_to_var(labels, loc, accs[3].1.clone())), + Some(self.reduce_to_var(labels, loc, accs[4].0.clone())), + Some(self.reduce_to_var(labels, loc, accs[4].1.clone())), ]; - self.add_row(curr_row, GateType::VarBaseMul, vec![]); + self.add_row(labels, loc, curr_row, GateType::VarBaseMul, vec![]); let 
next_row = vec![ - Some(self.reduce_to_var(accs[5].0.clone())), - Some(self.reduce_to_var(accs[5].1.clone())), - Some(self.reduce_to_var(bits[0].clone())), - Some(self.reduce_to_var(bits[1].clone())), - Some(self.reduce_to_var(bits[2].clone())), - Some(self.reduce_to_var(bits[3].clone())), - Some(self.reduce_to_var(bits[4].clone())), - Some(self.reduce_to_var(ss[0].clone())), - Some(self.reduce_to_var(ss[1].clone())), - Some(self.reduce_to_var(ss[2].clone())), - Some(self.reduce_to_var(ss[3].clone())), - Some(self.reduce_to_var(ss[4].clone())), + Some(self.reduce_to_var(labels, loc, accs[5].0.clone())), + Some(self.reduce_to_var(labels, loc, accs[5].1.clone())), + Some(self.reduce_to_var(labels, loc, bits[0].clone())), + Some(self.reduce_to_var(labels, loc, bits[1].clone())), + Some(self.reduce_to_var(labels, loc, bits[2].clone())), + Some(self.reduce_to_var(labels, loc, bits[3].clone())), + Some(self.reduce_to_var(labels, loc, bits[4].clone())), + Some(self.reduce_to_var(labels, loc, ss[0].clone())), + Some(self.reduce_to_var(labels, loc, ss[1].clone())), + Some(self.reduce_to_var(labels, loc, ss[2].clone())), + Some(self.reduce_to_var(labels, loc, ss[3].clone())), + Some(self.reduce_to_var(labels, loc, ss[4].clone())), ]; - self.add_row(next_row, GateType::Zero, vec![]); + self.add_row(labels, loc, next_row, GateType::Zero, vec![]); } } - KimchiConstraint::EcEndoscale { + KimchiConstraint::EcEndoscale(EcEndoscaleInput { state, xs, ys, n_acc, - } => { + }) => { for round in state { let vars = vec![ - Some(self.reduce_to_var(round.xt)), - Some(self.reduce_to_var(round.yt)), + Some(self.reduce_to_var(labels, loc, round.xt)), + Some(self.reduce_to_var(labels, loc, round.yt)), None, None, - Some(self.reduce_to_var(round.xp)), - Some(self.reduce_to_var(round.yp)), - Some(self.reduce_to_var(round.n_acc)), - Some(self.reduce_to_var(round.xr)), - Some(self.reduce_to_var(round.yr)), - Some(self.reduce_to_var(round.s1)), - Some(self.reduce_to_var(round.s3)), - Some(self.reduce_to_var(round.b1)), - Some(self.reduce_to_var(round.b2)), - Some(self.reduce_to_var(round.b3)), - Some(self.reduce_to_var(round.b4)), + Some(self.reduce_to_var(labels, loc, round.xp)), + Some(self.reduce_to_var(labels, loc, round.yp)), + Some(self.reduce_to_var(labels, loc, round.n_acc)), + Some(self.reduce_to_var(labels, loc, round.xr)), + Some(self.reduce_to_var(labels, loc, round.yr)), + Some(self.reduce_to_var(labels, loc, round.s1)), + Some(self.reduce_to_var(labels, loc, round.s3)), + Some(self.reduce_to_var(labels, loc, round.b1)), + Some(self.reduce_to_var(labels, loc, round.b2)), + Some(self.reduce_to_var(labels, loc, round.b3)), + Some(self.reduce_to_var(labels, loc, round.b4)), ]; - self.add_row(vars, GateType::EndoMul, vec![]); + self.add_row(labels, loc, vars, GateType::EndoMul, vec![]); } // last row @@ -1292,38 +1624,271 @@ impl SnarkyConstraintSystem { None, None, None, - Some(self.reduce_to_var(xs)), - Some(self.reduce_to_var(ys)), - Some(self.reduce_to_var(n_acc)), + Some(self.reduce_to_var(labels, loc, xs)), + Some(self.reduce_to_var(labels, loc, ys)), + Some(self.reduce_to_var(labels, loc, n_acc)), ]; - self.add_row(vars, GateType::Zero, vec![]); + self.add_row(labels, loc, vars, GateType::Zero, vec![]); } - KimchiConstraint::EcEndoscalar { state } => { + KimchiConstraint::EcEndoscalar(state) => { for round in state { let vars = vec![ - Some(self.reduce_to_var(round.n0)), - Some(self.reduce_to_var(round.n8)), - Some(self.reduce_to_var(round.a0)), - Some(self.reduce_to_var(round.b0)), - 
Some(self.reduce_to_var(round.a8)), - Some(self.reduce_to_var(round.b8)), - Some(self.reduce_to_var(round.x0)), - Some(self.reduce_to_var(round.x1)), - Some(self.reduce_to_var(round.x2)), - Some(self.reduce_to_var(round.x3)), - Some(self.reduce_to_var(round.x4)), - Some(self.reduce_to_var(round.x5)), - Some(self.reduce_to_var(round.x6)), - Some(self.reduce_to_var(round.x7)), + Some(self.reduce_to_var(labels, loc, round.n0)), + Some(self.reduce_to_var(labels, loc, round.n8)), + Some(self.reduce_to_var(labels, loc, round.a0)), + Some(self.reduce_to_var(labels, loc, round.b0)), + Some(self.reduce_to_var(labels, loc, round.a8)), + Some(self.reduce_to_var(labels, loc, round.b8)), + Some(self.reduce_to_var(labels, loc, round.x0)), + Some(self.reduce_to_var(labels, loc, round.x1)), + Some(self.reduce_to_var(labels, loc, round.x2)), + Some(self.reduce_to_var(labels, loc, round.x3)), + Some(self.reduce_to_var(labels, loc, round.x4)), + Some(self.reduce_to_var(labels, loc, round.x5)), + Some(self.reduce_to_var(labels, loc, round.x6)), + Some(self.reduce_to_var(labels, loc, round.x7)), ]; - self.add_row(vars, GateType::EndoMulScalar, vec![]); + self.add_row(labels, loc, vars, GateType::EndoMulScalar, vec![]); } } + KimchiConstraint::RangeCheck(rows) => { + let rows: Result<[Vec; 4], _> = rows.try_into(); + let rows: Result<[[Cvar; 15]; 4], _> = rows.map(|rows| { + rows.map(|r| { + let r = r.try_into(); + match r { + Ok(r) => r, + Err(_) => { + panic!("size of row is != 15"); + } + } + }) + }); + let rows = match rows { + Ok(rows) => rows, + Err(_) => { + panic!("wrong number of rows"); + } + }; + + let vars = |cvars: [Cvar; 15]| { + cvars + .map(|v| self.reduce_to_var(labels, loc, v)) + .map(Some) + .to_vec() + }; + let [r0, r1, r2, r3] = rows.map(vars); + self.add_row(labels, loc, r0, GateType::RangeCheck0, vec![Field::zero()]); + self.add_row(labels, loc, r1, GateType::RangeCheck0, vec![Field::zero()]); + self.add_row(labels, loc, r2, GateType::RangeCheck1, vec![]); + self.add_row(labels, loc, r3, GateType::Zero, vec![]); + } } } + pub(crate) fn sponge_params(&self) -> mina_poseidon::poseidon::ArithmeticSpongeParams { + self.constants.poseidon.clone() + } } enum ConstantOrVar { Constant, Var(V), } + +impl BasicSnarkyConstraint> +where + F: PrimeField, +{ + pub fn check_constraint( + &self, + env: &impl WitnessGeneration, + ) -> Result<(), Box> { + let result = match self { + BasicSnarkyConstraint::Boolean(v) => { + let v = env.read_var(v); + if !(v.is_one() || v.is_zero()) { + Err(SnarkyRuntimeError::UnsatisfiedBooleanConstraint( + env.constraints_counter(), + v.to_string(), + )) + } else { + Ok(()) + } + } + BasicSnarkyConstraint::Equal(v1, v2) => { + let v1 = env.read_var(v1); + let v2 = env.read_var(v2); + if v1 != v2 { + Err(SnarkyRuntimeError::UnsatisfiedEqualConstraint( + env.constraints_counter(), + v1.to_string(), + v2.to_string(), + )) + } else { + Ok(()) + } + } + BasicSnarkyConstraint::Square(v1, v2) => { + let v1 = env.read_var(v1); + let v2 = env.read_var(v2); + let square = v1.square(); + if square != v2 { + Err(SnarkyRuntimeError::UnsatisfiedSquareConstraint( + env.constraints_counter(), + v1.to_string(), + v2.to_string(), + )) + } else { + Ok(()) + } + } + BasicSnarkyConstraint::R1CS(v1, v2, v3) => { + let v1 = env.read_var(v1); + let v2 = env.read_var(v2); + let v3 = env.read_var(v3); + let mul = v1 * v2; + if mul != v3 { + Err(SnarkyRuntimeError::UnsatisfiedR1CSConstraint( + env.constraints_counter(), + v1.to_string(), + v2.to_string(), + v3.to_string(), + )) + } else { + Ok(()) + 
} + } + }; + + result.map_err(Box::new) + } +} + +impl KimchiConstraint, F> +where + F: PrimeField, +{ + pub fn check_constraint( + &self, + env: &impl WitnessGeneration, + ) -> Result<(), Box> { + match self { + // we only check the basic gate + KimchiConstraint::Basic(BasicInput { + l: (c0, l_var), + r: (c1, r_var), + o: (c2, o_var), + m: c3, + c: c4, + }) => { + let l = env.read_var(l_var); + let r = env.read_var(r_var); + let o = env.read_var(o_var); + let res = *c0 * l + *c1 * r + *c2 * o + l * r * c3 + c4; + if !res.is_zero() { + // TODO: return different errors depending on the type of generic gate (e.g. addition, cst, mul, etc.) + return Err(Box::new(SnarkyRuntimeError::UnsatisfiedGenericConstraint( + c0.to_string(), + l.to_string(), + c1.to_string(), + r.to_string(), + c2.to_string(), + o.to_string(), + c3.to_string(), + c4.to_string(), + env.constraints_counter(), + ))); + } + } + + // we trust the witness generation to be correct for other gates, + // or that the gadgets will do the check + KimchiConstraint::Poseidon { .. } + | KimchiConstraint::Poseidon2 { .. } + | KimchiConstraint::EcAddComplete { .. } + | KimchiConstraint::EcScale { .. } + | KimchiConstraint::EcEndoscale { .. } + | KimchiConstraint::EcEndoscalar { .. } + | KimchiConstraint::RangeCheck { .. } => (), + }; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use mina_curves::pasta::{Fp, Vesta}; + + fn setup(public_input_size: usize) -> SnarkyConstraintSystem { + let constants = Constants::new::(); + let mut state = SnarkyConstraintSystem::::create(constants); + state.set_primary_input_size(public_input_size); + state + } + + #[test] + fn test_permutation_equal() { + let mut state = setup(0); + + let x = FieldVar::Var(0); + let y = FieldVar::Var(1); + let z = FieldVar::Var(2); + + let labels = &vec![]; + let loc = &Cow::Borrowed(""); + + // x * y = z + state.add_basic_snarky_constraint( + labels, + loc, + BasicSnarkyConstraint::R1CS(x.clone(), y.clone(), z), + ); + + // x = y + state.add_basic_snarky_constraint(labels, loc, BasicSnarkyConstraint::Equal(x, y)); + + let gates = state.finalize_and_get_gates(); + + for col in 0..PERMUTS { + assert_eq!(gates[0].wires[col].row, 0); + } + + assert_eq!(gates[0].wires[0].col, 1); + assert_eq!(gates[0].wires[1].col, 0); + } + + #[test] + fn test_permutation_public() { + let mut state = setup(1); + + let public = FieldVar::Var(0); + + let x = FieldVar::Var(1); + let y = FieldVar::Var(2); + + let labels = &vec![]; + let loc = &Cow::Borrowed(""); + + // x * y = z + state.add_basic_snarky_constraint( + labels, + loc, + BasicSnarkyConstraint::R1CS(x.clone(), y.clone(), public), + ); + + // x = y + state.add_basic_snarky_constraint(labels, loc, BasicSnarkyConstraint::Equal(x, y)); + + state.finalize(); + + let gates = state.finalize_and_get_gates(); + + assert_eq!(gates[1].wires[0].col, 1); + assert_eq!(gates[1].wires[1].col, 0); + + assert_eq!(gates[0].wires[0], Wire { row: 1, col: 2 }); + assert_eq!(gates[1].wires[2], Wire { row: 0, col: 0 }); + } +} diff --git a/kimchi/src/snarky/cvar.rs b/kimchi/src/snarky/cvar.rs new file mode 100644 index 0000000000..a26e9daab5 --- /dev/null +++ b/kimchi/src/snarky/cvar.rs @@ -0,0 +1,485 @@ +use crate::snarky::{ + boolean::Boolean, + constraint_system::SnarkyCvar, + runner::{RunState, WitnessGeneration}, + snarky_type::SnarkyType, +}; +use ark_ff::PrimeField; +use std::{ + borrow::Cow, + ops::{Add, Neg, Sub}, +}; + +use super::{ + constraint_system::BasicSnarkyConstraint, + errors::{SnarkyCompilationError, SnarkyResult}, + 
runner::Constraint,
+};
+
+/// A circuit variable represents a field element in the circuit.
+/// Note that a [`FieldVar`] currently represents a small AST that can hide nested additions and multiplications of [`FieldVar`]s.
+/// The reason behind this decision is that converting additions and multiplications directly into constraints is not optimal.
+/// So we most often wait until some additions have been accumulated before reducing them down to constraints.
+///
+/// Note: this design decision also leads to optimization issues,
+/// when we end up cloning and then reducing the same mini-ASTs multiple times.
+/// To force reduction of a [`FieldVar`] (most often before cloning),
+/// the [Self::seal] function can be used.
+#[derive(Clone, Debug)]
+pub enum FieldVar
+where
+    F: PrimeField,
+{
+    /// A constant (a field element).
+    Constant(F),
+
+    /// A variable, tracked by a counter.
+    Var(usize),
+
+    /// Addition of two [`FieldVar`]s.
+    Add(Box>, Box>),
+
+    /// Multiplication of a [`FieldVar`] by a constant.
+    Scale(F, Box>),
+}
+
+impl SnarkyCvar for FieldVar
+where
+    F: PrimeField,
+{
+    type Field = F;
+
+    fn to_constant_and_terms(&self) -> (Option, Vec<(Self::Field, usize)>) {
+        self.to_constant_and_terms()
+    }
+}
+
+// TODO: create a real type for this
+pub type Term = (F, usize);
+
+// TODO: create a real type for this
+pub type ScaledCVar = (F, FieldVar);
+
+impl FieldVar
+where
+    F: PrimeField,
+{
+    /// Converts a field element `c` into a [`FieldVar`].
+    pub fn constant(c: F) -> Self {
+        FieldVar::Constant(c)
+    }
+
+    /// Returns the zero constant.
+    pub fn zero() -> Self {
+        Self::constant(F::zero())
+    }
+
+    // TODO: should we cache results to avoid recomputing them again and again?
+    fn eval_inner(&self, state: &RunState, scale: F, res: &mut F) {
+        match self {
+            FieldVar::Constant(c) => {
+                *res += scale * c;
+            }
+            FieldVar::Var(v) => {
+                let v = state.read_var_idx(*v);
+                *res += scale * v;
+            }
+            FieldVar::Add(a, b) => {
+                a.eval_inner(state, scale, res);
+                b.eval_inner(state, scale, res);
+            }
+            FieldVar::Scale(s, v) => {
+                v.eval_inner(state, scale * s, res);
+            }
+        }
+    }
+
+    /// Evaluates the field element associated with a variable (used during witness generation).
+    pub fn eval(&self, state: &RunState) -> F {
+        let mut res = F::zero();
+        self.eval_inner(state, F::one(), &mut res);
+        res
+    }
+
+    fn to_constant_and_terms_inner(
+        &self,
+        scale: F,
+        constant: F,
+        terms: Vec>,
+    ) -> (F, Vec>) {
+        match self {
+            FieldVar::Constant(c) => (constant + (scale * c), terms),
+            FieldVar::Var(v) => {
+                let mut new_terms = vec![(scale, *v)];
+                new_terms.extend(terms);
+                (constant, new_terms)
+            }
+            FieldVar::Scale(s, t) => t.to_constant_and_terms_inner(scale * s, constant, terms),
+            FieldVar::Add(x1, x2) => {
+                let (c1, terms1) = x1.to_constant_and_terms_inner(scale, constant, terms);
+                x2.to_constant_and_terms_inner(scale, c1, terms1)
+            }
+        }
+    }
+
+    pub fn to_constant_and_terms(&self) -> (Option, Vec>) {
+        let (constant, terms) = self.to_constant_and_terms_inner(F::one(), F::zero(), vec![]);
+        let constant = if constant.is_zero() {
+            None
+        } else {
+            Some(constant)
+        };
+        (constant, terms)
+    }
+
+    pub fn scale(&self, scalar: F) -> Self {
+        if scalar.is_zero() {
+            return FieldVar::Constant(scalar);
+        } else if scalar.is_one() {
+            return self.clone();
+        }
+
+        match self {
+            FieldVar::Constant(x) => FieldVar::Constant(*x * scalar),
+            FieldVar::Scale(s, v) => FieldVar::Scale(*s * scalar, v.clone()),
+            FieldVar::Var(_) | FieldVar::Add(..)
=> FieldVar::Scale(scalar, Box::new(self.clone())), + } + } + + pub fn linear_combination(terms: &[ScaledCVar]) -> Self { + let mut res = FieldVar::zero(); + for (cst, term) in terms { + res = res.add(&term.scale(*cst)); + } + res + } + + pub fn sum(vs: &[&Self]) -> Self { + let terms: Vec<_> = vs.iter().map(|v| (F::one(), (*v).clone())).collect(); + Self::linear_combination(&terms) + } + + pub fn mul( + &self, + other: &Self, + label: Option>, + loc: Cow<'static, str>, + cs: &mut RunState, + ) -> SnarkyResult { + let res = match (self, other) { + (FieldVar::Constant(x), FieldVar::Constant(y)) => FieldVar::Constant(*x * y), + + (FieldVar::Constant(cst), cvar) | (cvar, FieldVar::Constant(cst)) => cvar.scale(*cst), + + (_, _) => { + let self_clone = self.clone(); + let other_clone = other.clone(); + let res: FieldVar = cs.compute(loc.clone(), move |env| { + let x: F = env.read_var(&self_clone); + let y: F = env.read_var(&other_clone); + x * y + })?; + + let label = label.or(Some("checked_mul".into())); + + cs.assert_r1cs(label, loc, self.clone(), other.clone(), res.clone())?; + + res + } + }; + + Ok(res) + } + + /** [equal_constraints z z_inv r] asserts that + if z = 0 then r = 1, or + if z <> 0 then r = 0 and z * z_inv = 1 + */ + fn equal_constraints( + state: &mut RunState, + loc: Cow<'static, str>, + z: Self, + z_inv: Self, + r: Self, + ) -> SnarkyResult<()> { + let one_minus_r = FieldVar::Constant(F::one()) - &r; + let zero = FieldVar::zero(); + state.assert_r1cs( + Some("equals_1".into()), + loc.clone(), + z_inv, + z.clone(), + one_minus_r, + )?; + state.assert_r1cs(Some("equals_2".into()), loc, r, z, zero) + } + + /** [equal_vars z] computes [(r, z_inv)] that satisfy the constraints in + [equal_constraints z z_inv r]. + + In particular, [r] is [1] if [z = 0] and [0] otherwise. + */ + fn equal_vars(env: &dyn WitnessGeneration, z: &FieldVar) -> (F, F) { + let z: F = env.read_var(z); + if let Some(z_inv) = z.inverse() { + (F::zero(), z_inv) + } else { + (F::one(), F::zero()) + } + } + + pub fn equal( + &self, + state: &mut RunState, + loc: Cow<'static, str>, + other: &FieldVar, + ) -> SnarkyResult> { + let res = match (self, other) { + (FieldVar::Constant(x), FieldVar::Constant(y)) => { + if x == y { + Boolean::true_() + } else { + Boolean::false_() + } + } + _ => { + let z = self - other; + let z_clone = z.clone(); + let (res, z_inv): (FieldVar, FieldVar) = + state.compute(loc.clone(), move |env| Self::equal_vars(env, &z_clone))?; + Self::equal_constraints(state, loc, z, z_inv, res.clone())?; + + Boolean::create_unsafe(res) + } + }; + + Ok(res) + } + + /// Seals the value of a variable. + /// + /// As a [`FieldVar`] can represent an AST, + /// it might not be a good idea to clone it and reuse it in several places. + /// This is because the exact same reduction that will eventually happen on each clone + /// will end up creating the same set of constraints multiple times in the circuit. + /// + /// It is useful to call on a variable that represents a long computation + /// that hasn't been constrained yet (e.g. by an assert call, or a call to a custom gate), + /// before using it further in the circuit. 
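// ---------------------------------------------------------------------------
// Editor's aside (toy, not part of the patch): `equal_vars` and
// `equal_constraints` above implement the classic zero-test gadget. The hint
// returns (r, z_inv) with r = 1 iff z = 0, and the two R1CS rows
// z_inv * z = 1 - r and r * z = 0 force r to be exactly the equality bit.
// Over Z_p for a toy prime:
const P_TOY: u64 = 101; // hypothetical toy modulus
fn inv_mod_toy(z: u64) -> u64 {
    // brute-force inverse; fine for a toy prime
    (1..P_TOY).find(|&i| (i * z) % P_TOY == 1).unwrap()
}
fn equal_vars_toy(z: u64) -> (u64, u64) {
    let z = z % P_TOY;
    if z == 0 {
        (1, 0)
    } else {
        (0, inv_mod_toy(z))
    }
}
// ---------------------------------------------------------------------------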
+ pub fn seal(&self, state: &mut RunState, loc: Cow<'static, str>) -> SnarkyResult { + match self.to_constant_and_terms() { + (None, terms) if terms.len() == 1 && terms[0].0.is_one() => { + Ok(FieldVar::Var(terms[0].1)) + } + (Some(c), terms) if terms.is_empty() => Ok(FieldVar::Constant(c)), + _ => { + let y: FieldVar = state.compute(loc.clone(), |env| env.read_var(self))?; + // this call will reduce [self] + self.assert_equals(state, loc, &y)?; + + Ok(y) + } + } + } +} + +// +// Traits +// + +impl SnarkyType for FieldVar +where + F: PrimeField, +{ + type Auxiliary = (); + + type OutOfCircuit = F; + + const SIZE_IN_FIELD_ELEMENTS: usize = 1; + + fn to_cvars(&self) -> (Vec>, Self::Auxiliary) { + (vec![self.clone()], ()) + } + + fn from_cvars_unsafe(cvars: Vec>, _aux: Self::Auxiliary) -> Self { + assert_eq!(cvars.len(), 1); + cvars[0].clone() + } + + fn check(&self, _cs: &mut RunState, _loc: Cow<'static, str>) -> SnarkyResult<()> { + // do nothing + Ok(()) + } + + fn constraint_system_auxiliary() -> Self::Auxiliary {} + + fn value_to_field_elements(x: &Self::OutOfCircuit) -> (Vec, Self::Auxiliary) { + (vec![*x], ()) + } + + fn value_of_field_elements(fields: Vec, _aux: Self::Auxiliary) -> Self::OutOfCircuit { + assert_eq!(fields.len(), 1); + + fields[0] + } +} + +// +// Assertions +// + +impl FieldVar +where + F: PrimeField, +{ + pub fn assert_equals( + &self, + state: &mut RunState, + loc: Cow<'static, str>, + other: &FieldVar, + ) -> SnarkyResult<()> { + match (self, other) { + (FieldVar::Constant(x), FieldVar::Constant(y)) => { + if x == y { + Ok(()) + } else { + Err( + state.compilation_error(SnarkyCompilationError::ConstantAssertEquals( + x.to_string(), + y.to_string(), + )), + ) + } + } + (_, _) => state.add_constraint( + Constraint::BasicSnarkyConstraint(BasicSnarkyConstraint::Equal( + self.clone(), + other.clone(), + )), + Some("assert equals".into()), + loc, + ), + } + } +} + +// +// Operations +// + +impl Add for &FieldVar +where + F: PrimeField, +{ + type Output = FieldVar; + + fn add(self, other: Self) -> Self::Output { + match (self, other) { + (FieldVar::Constant(x), y) | (y, FieldVar::Constant(x)) if x.is_zero() => y.clone(), + (FieldVar::Constant(x), FieldVar::Constant(y)) => FieldVar::Constant(*x + y), + (_, _) => FieldVar::Add(Box::new(self.clone()), Box::new(other.clone())), + } + } +} + +impl Add for FieldVar +where + F: PrimeField, +{ + type Output = FieldVar; + + fn add(self, other: Self) -> Self::Output { + self.add(&other) + } +} + +impl<'a, F> Add<&'a Self> for FieldVar +where + F: PrimeField, +{ + type Output = FieldVar; + + fn add(self, other: &Self) -> Self::Output { + (&self).add(other) + } +} + +impl Add> for &FieldVar +where + F: PrimeField, +{ + type Output = FieldVar; + + fn add(self, other: FieldVar) -> Self::Output { + self.add(&other) + } +} + +impl Sub for &FieldVar +where + F: PrimeField, +{ + type Output = FieldVar; + + fn sub(self, other: Self) -> Self::Output { + match (self, other) { + (FieldVar::Constant(x), FieldVar::Constant(y)) => FieldVar::Constant(*x - y), + _ => self.add(&other.scale(-F::one())), + } + } +} + +impl Sub> for FieldVar +where + F: PrimeField, +{ + type Output = FieldVar; + + fn sub(self, other: FieldVar) -> Self::Output { + self.sub(&other) + } +} + +impl<'a, F> Sub<&'a FieldVar> for FieldVar +where + F: PrimeField, +{ + type Output = FieldVar; + + fn sub(self, other: &Self) -> Self::Output { + (&self).sub(other) + } +} + +impl Sub> for &FieldVar +where + F: PrimeField, +{ + type Output = FieldVar; + + fn sub(self, other: 
FieldVar<F>) -> Self::Output {
+        self.sub(&other)
+    }
+}
+
+impl<F> Neg for &FieldVar<F>
+where
+    F: PrimeField,
+{
+    type Output = FieldVar<F>;
+
+    fn neg(self) -> Self::Output {
+        self.scale(-F::one())
+    }
+}
+
+impl<F> Neg for FieldVar<F>
+where
+    F: PrimeField,
+{
+    type Output = FieldVar<F>;
+
+    fn neg(self) -> Self::Output {
+        self.scale(-F::one())
+    }
+}
diff --git a/kimchi/src/snarky/errors.rs b/kimchi/src/snarky/errors.rs
new file mode 100644
index 0000000000..c267bf5126
--- /dev/null
+++ b/kimchi/src/snarky/errors.rs
@@ -0,0 +1,120 @@
+use std::{backtrace::Backtrace, borrow::Cow};
+
+use thiserror::Error;
+
+/// A result type for Snarky errors.
+pub type SnarkyResult<T> = std::result::Result<T, Box<RealSnarkyError>>;
+
+/// A result type for Snarky runtime errors.
+pub type SnarkyRuntimeResult<T> = std::result::Result<T, Box<SnarkyRuntimeError>>;
+
+/// A result type for Snarky compilation errors.
+pub type SnarkyCompileResult<T> = std::result::Result<T, SnarkyCompilationError>;
+
+#[derive(Debug, Error)]
+#[error("an error occurred in snarky")]
+pub struct RealSnarkyError {
+    /// The actual error.
+    pub source: SnarkyError,
+
+    /// A location string, usually a file name and line number.
+    /// Location information is usually useful for:
+    ///
+    /// - an assert that failed (so we need to keep track of the location that created each gate)
+    pub loc: Option<String>,
+
+    /// A stack of labels,
+    /// where each label represents an important function call.
+    pub label_stack: Option<Vec<Cow<'static, str>>>,
+
+    /// A Rust backtrace of where the error came from.
+    /// This can be especially useful for debugging snarky when wrapped by a different language implementation.
+    backtrace: Option<Backtrace>,
+}
+
+impl RealSnarkyError {
+    /// Creates a new [RealSnarkyError].
+    pub fn new(source: SnarkyError) -> Self {
+        let backtrace = std::env::var("SNARKY_BACKTRACE")
+            .ok()
+            .map(|_| Backtrace::capture());
+        Self {
+            source,
+            loc: None,
+            label_stack: None,
+            backtrace,
+        }
+    }
+
+    /// Creates a new [RealSnarkyError], with a location and a label stack as context.
+    pub fn new_with_ctx(
+        source: SnarkyError,
+        loc: Cow<'static, str>,
+        label_stack: Vec<Cow<'static, str>>,
+    ) -> Self {
+        let backtrace = std::env::var("SNARKY_BACKTRACE")
+            .ok()
+            .map(|_| Backtrace::capture());
+
+        Self {
+            source,
+            loc: Some(loc.to_string()),
+            label_stack: Some(label_stack),
+            backtrace,
+        }
+    }
+}
+
+/// Snarky errors can come from either a compilation or runtime error.
+#[derive(Debug, Clone, Error)]
+pub enum SnarkyError {
+    #[error("a compilation error occurred")]
+    CompilationError(SnarkyCompilationError),
+
+    #[error("a runtime error occurred")]
+    RuntimeError(SnarkyRuntimeError),
+}
+
+/// Errors that can occur during compilation of a circuit.
+#[derive(Debug, Clone, Error)]
+pub enum SnarkyCompilationError {
+    #[error("the two values were not equal: {0} != {1}")]
+    ConstantAssertEquals(String, String),
+}
+
+/// Errors that can occur during runtime (proving).
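+///
+/// As a reading aid: [SnarkyRuntimeError::UnsatisfiedGenericConstraint] below
+/// reports the generic gate equation with coefficients and values substituted,
+/// i.e. for coefficients `(l, r, o, m, c)` and wire values `(a, b, out)` it
+/// prints `l * a + r * b + o * out + m * a * b + c != 0`.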
+#[derive(Debug, Clone, Error)]
+pub enum SnarkyRuntimeError {
+    #[error(
+        "unsatisfied constraint #{8}: `{0} * {1} + {2} * {3} + {4} * {5} + {6} * {1} * {3} + {7} != 0`"
+    )]
+    UnsatisfiedGenericConstraint(
+        String,
+        String,
+        String,
+        String,
+        String,
+        String,
+        String,
+        String,
+        usize,
+    ),
+
+    #[error("unsatisfied constraint #{0}: {1} is not a boolean (0 or 1)")]
+    UnsatisfiedBooleanConstraint(usize, String),
+
+    #[error("unsatisfied constraint #{0}: {1} is not equal to {2}")]
+    UnsatisfiedEqualConstraint(usize, String, String),
+
+    #[error("unsatisfied constraint #{0}: {1}^2 is not equal to {2}")]
+    UnsatisfiedSquareConstraint(usize, String, String),
+
+    #[error("unsatisfied constraint #{0}: {1} * {2} is not equal to {3}")]
+    UnsatisfiedR1CSConstraint(usize, String, String, String),
+
+    #[error("the number of public inputs passed ({0}) does not match the number of public inputs expected ({1})")]
+    PubInputMismatch(usize, usize),
+
+    #[error("the value returned by the circuit has an incorrect number of field variables. It hardcoded {1} field variables, but returned {0}")]
+    CircuitReturnVar(usize, usize),
+}
diff --git a/kimchi/src/snarky/folding.rs b/kimchi/src/snarky/folding.rs
new file mode 100644
index 0000000000..ae28384a02
--- /dev/null
+++ b/kimchi/src/snarky/folding.rs
@@ -0,0 +1,217 @@
+use self::instance::{Instance, RelaxedInstance};
+use super::{poseidon::DuplexState, prelude::*};
+use crate::{
+    curve::KimchiCurve,
+    loc,
+    snarky::{api::SnarkyCircuit, cvar::FieldVar, runner::RunState},
+};
+use ark_ec::AffineRepr;
+use ark_ff::{BigInteger, One, PrimeField};
+use mina_curves::pasta::Fp;
+use poly_commitment::{ipa::OpeningProof, OpenProof};
+use std::{marker::PhantomData, ops::Add};
+
+mod instance;
+
+const CHALLENGE_BITS: usize = 127;
+
+struct FoldingCircuit<C: KimchiCurve, P: OpenProof<C>, const N: usize> {
+    _field: PhantomData<F<C>>,
+    _proof: PhantomData<P>,
+    /// Commitment sets and their sizes
+    commitments: Vec<usize>,
+    /// Challenge sets and their sizes
+    challenges: Vec<usize>,
+}
+
+/// The value of the input and output of the function being folded
+struct Argument<F, const N: usize>([F; N]);
+
+type Point<F> = [F; 2];
+
+/// Represents the result of a (Poseidon) hash, encoded in the circuit
+type Hash<F> = FieldVar<F>;
+
+/// Represents an element of the other curve's field. As the other field can
+/// have a higher order, it could require more than one limb. We suppose the
+/// other field's values can be encoded on two limbs.
+#[derive(Debug, Clone)]
+pub struct ForeignElement<F>([F; 2]);
+
+/// A challenge encoded on 127 bits
+#[derive(Debug, Clone)]
+struct SmallChallenge<F>(FieldVar<F>);
+
+#[derive(Debug, Clone)]
+pub struct FullChallenge<F>(ForeignElement<F>);
+
+pub struct Private<F, const N: usize> {
+    /// First input
+    z_0: Argument<F, N>,
+    /// Input for this step and output of the previous one
+    z_i: Argument<F, N>,
+    /// All instances accumulated so far
+    u_acc: RelaxedInstance<F>,
+    /// Newest instance to be folded into the accumulated instance
+    u_i: Instance<F>,
+    /// 2 commitments to error terms to be used when folding the error column
+    t: [Point<F>; 2],
+    i: F,
+}
+
+/// This should run the inner circuit and provide its output; in our case it may be simpler
+/// to just run the circuit separately and provide the variables pointing to the output.
+fn apply<F: PrimeField, const N: usize>(
+    _z_i: Argument<FieldVar<F>, N>,
+) -> Argument<FieldVar<F>, N> {
+    // TODO
+    todo!()
+}
+
+type F<C> = <C as AffineRepr>::ScalarField;
+
+fn challenge_linear_combination<F: PrimeField>(
+    _full: FullChallenge<FieldVar<F>>,
+    _small: SmallChallenge<F>,
+    _combiner: &SmallChallenge<F>,
+) -> FullChallenge<FieldVar<F>> {
+    // TODO
+    todo!()
+}
+
+fn commitment_linear_combination<F: PrimeField>(
+    _a: Point<FieldVar<F>>,
+    _b: Point<FieldVar<F>>,
+    _combiner: &SmallChallenge<F>,
+) -> Point<FieldVar<F>> {
+    // TODO
+    todo!()
+}
+fn ec_add<F: PrimeField>(_a: Point<FieldVar<F>>, _b: Point<FieldVar<F>>) -> Point<FieldVar<F>> {
+    // TODO
+    todo!()
+}
+fn ec_scale<F: PrimeField>(
+    _point: Point<FieldVar<F>>,
+    _scalar: &SmallChallenge<F>,
+) -> Point<FieldVar<F>> {
+    // TODO
+    todo!()
+}
+
+/// Trims to 127 bits
+fn trim<F: PrimeField>(
+    sys: &mut RunState<F>,
+    v: &FieldVar<F>,
+    base: &FieldVar<F>,
+) -> SnarkyResult<FieldVar<F>> {
+    let (high, low): (FieldVar<F>, FieldVar<F>) = sys.compute(loc!(), |wit| {
+        let val = wit.read_var(v);
+        let mut high = val.into_bigint();
+        high.divn(CHALLENGE_BITS as u32);
+        // Shift `high` back up before subtracting, so that `low` keeps exactly
+        // the 127 low bits of `val` and `val = high * 2^127 + low` holds.
+        let mut shifted = high;
+        shifted.muln(CHALLENGE_BITS as u32);
+        let mut low = val.into_bigint();
+        low.sub_with_borrow(&shifted);
+        (F::from_bigint(high).unwrap(), F::from_bigint(low).unwrap())
+    })?;
+    let composition = high.mul(base, None, loc!(), sys)? + &low;
+    // TODO: constrain low to 127 bits
+    v.assert_equals(sys, loc!(), &composition)?;
+    Ok(low)
+}
+
+/// Compute H(i, z_0, z_i, u) and keep the log2(base) bits of the result.
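+/// Informally: if base = 2^127, the sponge output `h` is bound by
+/// `h = high * base + low` (see [trim]) and only the 127-bit `low` part is
+/// kept as the challenge.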
+/// [base] is supposed to be a power of 2 +fn hash( + sys: &mut RunState, + i: FieldVar, + z_0: &Argument, N>, + z_i: &Argument, N>, + u: &RelaxedInstance>, + base: &FieldVar, +) -> SnarkyResult> { + let mut sponge = DuplexState::new(); + + sponge.absorb(sys, loc!(), &[i]); + sponge.absorb(sys, loc!(), &z_0.0); + sponge.absorb(sys, loc!(), &z_i.0); + u.absorb_into_sponge(&mut sponge, sys); + + let hash = sponge.squeeze(sys, loc!()); + trim(sys, &hash, base) +} + +impl, const N: usize> SnarkyCircuit for FoldingCircuit { + type Curve = C; + type Proof = P; + + type PrivateInput = Private, N>; + type PublicInput = (); + type PublicOutput = [Hash>; 2]; + + /// Implement the IVC circuit, see https://eprint.iacr.org/2021/370.pdf, Fig + /// 4, page 18 + fn circuit( + &self, + sys: &mut RunState, + _public: Self::PublicInput, + private: Option<&Self::PrivateInput>, + ) -> SnarkyResult { + let one = FieldVar::constant(F::::one()); + //dividing by this should make a number of 127 bits or less zero + let power = 1u128 << 127; + let power = F::::from(power); + let power = FieldVar::Constant(power); + + let u_i: Instance>> = Instance::compute(private, sys, &self.commitments)?; + let hash1 = u_i.hash1.clone(); + let hash2 = u_i.hash2.clone(); + let u_acc: RelaxedInstance>> = + RelaxedInstance::compute(private, sys, &self.commitments, &self.challenges)?; + + //check hash of inputs + let i: FieldVar> = sys.compute(loc!(), |_| private.unwrap().i)?; + let z_0: [FieldVar>; N] = sys.compute(loc!(), |_| private.unwrap().z_0.0)?; + let z_0 = Argument(z_0); + let z_i: [FieldVar>; N] = sys.compute(loc!(), |_| private.unwrap().z_i.0)?; + let z_i = Argument(z_i); + + let inputs_hash = hash(sys, i.clone(), &z_0, &z_i, &u_acc, &power)?; + + inputs_hash.assert_equals(sys, loc!(), &hash1)?; + + let t = sys.compute(loc!(), |_| private.unwrap().t)?; + let u_acc_new = u_acc.fold(sys, u_i, t, &power)?; + let z_next = apply(z_i); + let i = i.add(one); + + // here we should output this in some way so that it can be used as the input + // for the next iteration, a Mutex or similar in &self could be used + let new_hash = hash(sys, i, &z_0, &z_next, &u_acc_new, &power)?; + + Ok([hash2, new_hash]) + } +} + +#[allow(dead_code)] +fn example(private: Private) { + use mina_curves::pasta::{Vesta, VestaParameters}; + use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi, + sponge::{DefaultFqSponge, DefaultFrSponge}, + }; + type BaseSponge = DefaultFqSponge; + type ScalarSponge = DefaultFrSponge; + + let circuit = FoldingCircuit { + _field: PhantomData, + _proof: PhantomData::>, + commitments: vec![15], + challenges: vec![1, 5], + }; + let (mut prover_index, verifier_index) = circuit.compile_to_indexes().unwrap(); + let (proof, public_output) = prover_index + .prove::((), private, true) + .unwrap(); + + verifier_index.verify::(proof, (), *public_output); +} diff --git a/kimchi/src/snarky/folding/instance.rs b/kimchi/src/snarky/folding/instance.rs new file mode 100644 index 0000000000..ed7a6a5e95 --- /dev/null +++ b/kimchi/src/snarky/folding/instance.rs @@ -0,0 +1,302 @@ +use crate::{ + loc, + snarky::{ + folding::{ForeignElement, FullChallenge, Point, Private}, + poseidon::DuplexState, + snarky_type::SnarkyType, + }, + FieldVar, RunState, SnarkyResult, +}; +use ark_ff::PrimeField; +use std::iter::successors; + +use super::{ + challenge_linear_combination, commitment_linear_combination, ec_add, ec_scale, trim, + SmallChallenge, +}; + +#[derive(Debug, Clone)] +pub struct WitnessCommitments(Vec>); + +#[derive(Debug, Clone)] +pub 
struct Instance { + pub hash1: F, + pub hash2: F, + // pub witness_commitment: Point, + pub witness_commitments: Vec>, +} +impl Instance> { + fn absorb_into_sponge(&self, sponge: &mut DuplexState, sys: &mut RunState) { + for commitment_set in &self.witness_commitments { + for commitment in &commitment_set.0 { + sponge.absorb(sys, loc!(), commitment); + } + } + sponge.absorb(sys, loc!(), &[self.hash1.clone(), self.hash2.clone()]); + } +} +impl Instance { + pub fn compute( + private_input: Option<&Private>, + sys: &mut RunState, + commitment_sets: &[usize], + ) -> SnarkyResult>> + where + F: PrimeField, + { + let (hash1, hash2) = sys.compute(loc!(), |_| { + let ins = private_input.unwrap().u_i.clone(); + (ins.hash1, ins.hash2) + })?; + let mut witness_commitments = Vec::with_capacity(commitment_sets.len()); + for (i, set_size) in commitment_sets.iter().enumerate() { + let mut set = Vec::with_capacity(*set_size); + for j in 0..*set_size { + let commitment = sys.compute(loc!(), |_| { + let a: Point = private_input.unwrap().u_i.witness_commitments[i].0[j]; + a + })?; + set.push(commitment); + } + witness_commitments.push(WitnessCommitments(set)) + } + Ok(Instance { + hash1, + hash2, + witness_commitments, + }) + } +} +#[derive(Debug, Clone)] +struct Challenges(Vec>); + +#[derive(Debug, Clone)] +pub struct RelaxedInstance { + pub hash1: FullChallenge, + pub hash2: FullChallenge, + pub witness_commitments: Vec>, + u: FullChallenge, + error_commitment: Point, + challenges: Vec>, +} + +impl RelaxedInstance> { + pub fn absorb_into_sponge(&self, sponge: &mut DuplexState, sys: &mut RunState) { + for commitment_set in &self.witness_commitments { + for commitment in &commitment_set.0 { + sponge.absorb(sys, loc!(), commitment); + } + } + sponge.absorb(sys, loc!(), &self.hash1.0 .0); + sponge.absorb(sys, loc!(), &self.hash2.0 .0); + + sponge.absorb(sys, loc!(), &self.u.0 .0); + sponge.absorb(sys, loc!(), &self.error_commitment); + for set in self.challenges.iter() { + for challenge in set.0.iter() { + sponge.absorb(sys, loc!(), &challenge.0 .0); + } + } + } + + /// See https://eprint.iacr.org/2021/370.pdf, page 15 + /// Fold the circuit described by `sys` with the other circuit `other`. 
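+    /// Concretely, for a folding challenge `r` squeezed below, this computes:
+    /// - hashes and challenges: `x <- x_acc + r * x_new`,
+    /// - witness commitments: `W <- W_acc + r * W_new`,
+    /// - the homogenizing term: `u <- u_acc + r * 1`,
+    /// - the error commitment: `E <- E_acc + r * T1 + r^2 * T2`.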
+ pub fn fold( + self, + sys: &mut RunState, + other: Instance>, + error_terms: [Point>; 2], + base: &FieldVar, + ) -> SnarkyResult { + // let mut state = DuplexState::new(); + let mut challenge_generator = ChallengeGenerator::new(sys, &self, &other, None); + let r = challenge_generator.squeeze_challenge(sys, base)?; + let hash1 = challenge_linear_combination(self.hash1, SmallChallenge(other.hash1), &r); + let hash2 = challenge_linear_combination(self.hash2, SmallChallenge(other.hash2), &r); + // Combining the witnesses commitments, see W <- W1 + r W2 + let witness_commitments = self + .witness_commitments + .into_iter() + .zip(other.witness_commitments) + .map(|(a, b)| { + let set = + a.0.into_iter() + .zip(b.0) + .map(|(a, b)| commitment_linear_combination(a, b, &r)); + WitnessCommitments(set.collect()) + }) + .collect(); + let one = FieldVar::constant(F::one()); + let u = challenge_linear_combination(self.u, SmallChallenge(one.clone()), &r); + + let rr = r.0.mul(&r.0, None, loc!(), sys)?; + let [t1, t2] = error_terms; + let t1 = ec_scale(t1, &r); + let t2 = ec_scale(t2, &SmallChallenge(rr)); + let error_commitment = ec_add(t1, t2); + let error_commitment = ec_add(self.error_commitment, error_commitment); + + let mut new_sets = Vec::with_capacity(self.challenges.len()); + for _ in 0..self.challenges.len() { + let chall = challenge_generator.squeeze_challenge(sys, base)?; + new_sets.push(chall); + } + let mut new_sets: Vec>> = self + .challenges + .iter() + .zip(new_sets) + .map(|(acc_set, new)| { + successors(Some(one.clone()), |last| { + Some(last.mul(&new.0, None, loc!(), sys).unwrap()) + }) + .take(acc_set.0.len()) + .collect() + }) + .collect(); + for set in new_sets.iter_mut() { + for challenge in set.iter_mut() { + let mut trimed = trim(sys, challenge, base)?; + std::mem::swap(challenge, &mut trimed); + } + } + let challenges = self + .challenges + .into_iter() + .zip(new_sets) + .map(|(a, b)| { + let set = + a.0.into_iter() + .zip(b) + .map(|(a, b)| challenge_linear_combination(a, SmallChallenge(b), &r)) + .collect(); + Challenges(set) + }) + .collect(); + + Ok(RelaxedInstance { + hash1, + hash2, + witness_commitments, + u, + error_commitment, + challenges, + }) + } +} + +impl RelaxedInstance { + pub fn compute( + private_input: Option<&Private>, + sys: &mut RunState, + commitment_sets: &[usize], + challenge_sets: &[usize], + ) -> SnarkyResult>> + where + F: PrimeField, + { + // let instance = Instance::compute(private_input, sys, commitment_sets)?; + let mut witness_commitments = Vec::with_capacity(commitment_sets.len()); + for (i, set_size) in commitment_sets.iter().enumerate() { + let mut set = Vec::with_capacity(*set_size); + for j in 0..*set_size { + let commitment = sys.compute(loc!(), |_| { + let a: Point = private_input.unwrap().u_i.witness_commitments[i].0[j]; + a + })?; + set.push(commitment); + } + witness_commitments.push(WitnessCommitments(set)) + } + let hash1 = sys.compute(loc!(), |_| private_input.unwrap().u_acc.hash1.clone())?; + let hash2 = sys.compute(loc!(), |_| private_input.unwrap().u_acc.hash2.clone())?; + + let u = sys.compute(loc!(), |_| private_input.unwrap().u_acc.u.clone())?; + let error_commitment = sys.compute(loc!(), |_| { + let e: Point = private_input.unwrap().u_acc.error_commitment; + e + })?; + let mut challenges = Vec::with_capacity(challenge_sets.len()); + for (i, size) in challenge_sets.iter().enumerate() { + let mut set = Vec::with_capacity(*size); + for j in 0..*size { + let challenge = sys.compute(loc!(), |_| { + let challenge: 
FullChallenge = + private_input.unwrap().u_acc.challenges[i].0[j].clone(); + challenge + }); + set.push(challenge?); + } + challenges.push(Challenges(set)); + } + Ok(RelaxedInstance { + witness_commitments, + hash1, + hash2, + u, + error_commitment, + challenges, + }) + } +} +impl SnarkyType for FullChallenge> { + type Auxiliary = (); + + type OutOfCircuit = FullChallenge; + + const SIZE_IN_FIELD_ELEMENTS: usize = 2; + + fn to_cvars(&self) -> (Vec>, Self::Auxiliary) { + (self.0 .0.to_vec(), ()) + } + + fn from_cvars_unsafe(cvars: Vec>, _aux: Self::Auxiliary) -> Self { + let chall: [FieldVar; 2] = cvars.try_into().unwrap(); + FullChallenge(ForeignElement(chall)) + } + + fn check( + &self, + _cs: &mut RunState, + _loc: std::borrow::Cow<'static, str>, + ) -> SnarkyResult<()> { + //TODO: maybe check the size of each limb + Ok(()) + } + + fn constraint_system_auxiliary() -> Self::Auxiliary {} + + fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec, Self::Auxiliary) { + (value.0 .0.to_vec(), ()) + } + + fn value_of_field_elements(fields: Vec, _aux: Self::Auxiliary) -> Self::OutOfCircuit { + let chall: [F; 2] = fields.try_into().unwrap(); + FullChallenge(ForeignElement(chall)) + } +} + +struct ChallengeGenerator { + state: DuplexState, +} + +impl ChallengeGenerator { + fn new( + sys: &mut RunState, + relaxed: &RelaxedInstance>, + other: &Instance>, + initial_state: Option>, + ) -> Self { + let mut state = initial_state.unwrap_or_default(); + relaxed.absorb_into_sponge(&mut state, sys); + other.absorb_into_sponge(&mut state, sys); + ChallengeGenerator { state } + } + fn squeeze_challenge( + &mut self, + sys: &mut RunState, + base: &FieldVar, + ) -> SnarkyResult> { + let challenge = self.state.squeeze(sys, loc!()); + Ok(SmallChallenge(trim(sys, &challenge, base)?)) + } +} diff --git a/kimchi/src/snarky/mod.rs b/kimchi/src/snarky/mod.rs index 15698a11cf..48d1773e51 100644 --- a/kimchi/src/snarky/mod.rs +++ b/kimchi/src/snarky/mod.rs @@ -1,3 +1,34 @@ +// TODO: uncomment +// #![deny(missing_docs)] + +//! Snarky is the front end to kimchi, allowing users to write their own programs and convert them to kimchi circuits. +//! +//! See the Mina book for more detail about this implementation. +//! +//! See the `tests.rs` file for examples of how to use snarky. + +pub mod api; pub mod asm; +pub mod boolean; pub mod constants; pub mod constraint_system; +pub mod cvar; +pub mod errors; +pub mod folding; +pub mod poseidon; +pub(crate) mod range_checks; +pub mod runner; +pub mod snarky_type; +pub mod union_find; + +#[cfg(test)] +mod tests; + +/// A handy module that you can import the content of to easily use snarky. +pub mod prelude { + use super::*; + pub use crate::loc; + pub use cvar::FieldVar; + pub use errors::SnarkyResult; + pub use runner::RunState; +} diff --git a/kimchi/src/snarky/poseidon.rs b/kimchi/src/snarky/poseidon.rs new file mode 100644 index 0000000000..30e1d728d3 --- /dev/null +++ b/kimchi/src/snarky/poseidon.rs @@ -0,0 +1,203 @@ +//! Functions associated to the Poseidon hash function. 
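+//!
+//! A minimal sketch of the duplex API defined below (hypothetical `sys` state
+//! and variables `x`, `y`; see [DuplexState]):
+//!
+//! ```ignore
+//! let mut sponge = DuplexState::new();
+//! sponge.absorb(&mut sys, loc!(), &[x, y]);
+//! let challenge = sponge.squeeze(&mut sys, loc!());
+//! ```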
+ +use crate::{ + circuits::polynomials::poseidon::{ROUNDS_PER_HASH, ROUNDS_PER_ROW, SPONGE_WIDTH}, + snarky::{ + constraint_system::KimchiConstraint, + prelude::{FieldVar, RunState}, + runner::Constraint, + }, +}; +use ark_ff::PrimeField; +use itertools::Itertools; +use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi, permutation::full_round, + poseidon::ArithmeticSpongeParams, +}; +use std::{borrow::Cow, iter::successors}; + +use super::constraint_system::PoseidonInput; + +pub fn poseidon( + runner: &mut RunState, + loc: Cow<'static, str>, + preimage: (FieldVar, FieldVar), +) -> (FieldVar, FieldVar) { + let initial_state = [preimage.0, preimage.1, FieldVar::zero()]; + let (constraint, hash) = { + let params = runner.poseidon_params(); + let mut iter = successors((initial_state, 0_usize).into(), |(prev, i)| { + //this case may justify moving to Cow + let state = round(runner, loc.clone(), prev, *i, ¶ms); + Some((state, i + 1)) + }) + .take(ROUNDS_PER_HASH + 1) + .map(|(r, _)| r); + + let states: Vec<_> = iter + .by_ref() + .take(ROUNDS_PER_HASH) + .chunks(ROUNDS_PER_ROW) + .into_iter() + .flat_map(|mut it| { + let mut n = || it.next().unwrap(); + let (r0, r1, r2, r3, r4) = (n(), n(), n(), n(), n()); + [r0, r4, r1, r2, r3].into_iter() + }) + .collect_vec(); + let last = iter.next().unwrap(); + let hash = { + let [a, b, _] = last.clone(); + (a, b) + }; + let constraint = Constraint::KimchiConstraint(KimchiConstraint::Poseidon2(PoseidonInput { + states: states.into_iter().map(|s| s.to_vec()).collect(), + last: last.to_vec(), + })); + (constraint, hash) + }; + + runner + .add_constraint(constraint, Some("Poseidon".into()), loc) + .expect("compiler bug"); + + hash +} + +fn round( + runner: &mut RunState, + loc: Cow<'static, str>, + elements: &[FieldVar; SPONGE_WIDTH], + round: usize, + params: &ArithmeticSpongeParams, +) -> [FieldVar; SPONGE_WIDTH] { + runner + .compute(loc, |env| { + let mut state = elements.clone().map(|var| env.read_var(&var)).to_vec(); + full_round::(params, &mut state, round); + state.try_into().unwrap() + }) + .expect("compiler bug") +} + +// +// Duplex API +// + +/// A duplex construction allows one to absorb and squeeze data alternatively. +pub struct DuplexState +where + F: PrimeField, +{ + rev_queue: Vec>, + absorbing: bool, + squeezed: Option>, + state: [FieldVar; 3], +} + +impl Default for DuplexState +where + F: PrimeField, +{ + fn default() -> Self { + let zero = FieldVar::zero(); + let state = [zero.clone(), zero.clone(), zero]; + DuplexState { + rev_queue: vec![], + absorbing: true, + squeezed: None, + state, + } + } +} + +/// The rate of the sponge. +/// The part that is modified when absorbing, and released when squeezing. +/// Unlike the capacity that must not be touched. +const RATE_SIZE: usize = 2; + +impl DuplexState +where + F: PrimeField, +{ + /// Creates a new sponge. + pub fn new() -> DuplexState { + Default::default() + } + + /// Absorb. 
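+    /// Inputs are queued rather than absorbed eagerly: the permutation below
+    /// only runs when a new input would overflow the RATE_SIZE (2) pending
+    /// elements, so consecutive absorbs stay cheap.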
+    pub fn absorb(
+        &mut self,
+        sys: &mut RunState<F>,
+        loc: Cow<'static, str>,
+        inputs: &[FieldVar<F>],
+    ) {
+        // no need to permute to switch to absorbing
+        if !self.absorbing {
+            assert!(self.rev_queue.is_empty());
+            self.squeezed = None;
+            self.absorbing = true;
+        }
+
+        // absorb
+        for input in inputs {
+            // we only permute when we try to absorb too much (we lazy)
+            if self.rev_queue.len() == RATE_SIZE {
+                let left = self.rev_queue.pop().unwrap();
+                let right = self.rev_queue.pop().unwrap();
+                self.state[0] = &self.state[0] + left;
+                self.state[1] = &self.state[1] + right;
+                self.permute(sys, loc.clone());
+            }
+
+            self.rev_queue.insert(0, input.clone());
+        }
+    }
+
+    /// Permute. You should most likely not use this function directly,
+    /// and use [Self::absorb] and [Self::squeeze] instead.
+    fn permute(
+        &mut self,
+        sys: &mut RunState<F>,
+        loc: Cow<'static, str>,
+    ) -> (FieldVar<F>, FieldVar<F>) {
+        let left = self.state[0].clone();
+        let right = self.state[1].clone();
+        sys.poseidon(loc, (left, right))
+    }
+
+    /// Squeeze.
+    pub fn squeeze(&mut self, sys: &mut RunState<F>, loc: Cow<'static, str>) -> FieldVar<F> {
+        // if we're switching to squeezing, don't forget about the queue
+        if self.absorbing {
+            assert!(self.squeezed.is_none());
+            if let Some(left) = self.rev_queue.pop() {
+                self.state[0] = &self.state[0] + left;
+            }
+            if let Some(right) = self.rev_queue.pop() {
+                self.state[1] = &self.state[1] + right;
+            }
+            self.absorbing = false;
+        }
+
+        // if we still have some left over, release that
+        if let Some(squeezed) = self.squeezed.take() {
+            return squeezed;
+        }
+
+        // otherwise permute and squeeze
+        let (left, right) = self.permute(sys, loc);
+
+        // cache the right, release the left
+        self.squeezed = Some(right);
+        left
+    }
+}
+
+// TODO: create a macro to derive this function automatically
+pub trait CircuitAbsorb<F>
+where
+    F: PrimeField,
+{
+    fn absorb(&self, duplex: &mut DuplexState<F>, sys: &mut RunState<F>);
+}
diff --git a/kimchi/src/snarky/range_checks.rs b/kimchi/src/snarky/range_checks.rs
new file mode 100644
index 0000000000..4bcab96e5a
--- /dev/null
+++ b/kimchi/src/snarky/range_checks.rs
@@ -0,0 +1,251 @@
+use super::{constraint_system::KimchiConstraint, runner::Constraint};
+use crate::{circuits::polynomial::COLUMNS, FieldVar, RunState, SnarkyResult};
+use ark_ff::{BigInteger, PrimeField};
+use itertools::Itertools;
+use std::borrow::Cow;
+
+/// Creates a field element from the next B bits
+fn parse_limb<const B: usize, F: PrimeField>(bits: impl Iterator<Item = bool>) -> F {
+    assert!(B <= 16);
+    let mut l = 0_u16;
+    let bits = bits.take(B).collect_vec();
+    let mut bits = bits.into_iter().rev();
+    for _ in 0..B {
+        let b = bits.next().unwrap();
+        let b = if b { 1 } else { 0 };
+        l = l + l + b;
+    }
+    F::from(l)
+}
+
+/// Extracts N limbs of B bits each from the iterator provided
+fn parse_limbs<const B: usize, const N: usize, F: PrimeField>(
+    mut bits: impl Iterator<Item = bool>,
+) -> [F; N] {
+    [(); N].map(|_| parse_limb::<B, F>(bits.by_ref()))
+}
+
+/// For v0 and v1
+struct RangeCheckLimbs1<F> {
+    crumbs: [F; 8],
+    limbs: [F; 6],
+}
+
+impl<F: PrimeField> RangeCheckLimbs1<F> {
+    /// Extracts the limbs needed for range check from the bits of f
+    fn parse(f: F) -> Self {
+        let mut bits = f.into_bigint().to_bits_le().into_iter();
+        let crumbs = parse_limbs::<2, 8, F>(bits.by_ref());
+        let limbs = parse_limbs::<12, 6, F>(bits);
+        Self { crumbs, limbs }
+    }
+    /// Produces limbs and crumbs with the expected endianness
+    fn into_repr(mut self) -> ([F; 6], [F; 8]) {
+        self.crumbs.reverse();
+        self.limbs.reverse();
+        let Self { crumbs, limbs } = self;
+        (limbs, crumbs)
+    }
+}
+
+/// For v2
+struct RangeCheckLimbs2<F> {
+    crumbs_low: [F; 19],
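+    // Layout sketch, assuming the 2-bit crumbs / 12-bit limbs split used in
+    // parse() below: 19 * 2 + 4 * 12 + 1 * 2 = 88 bits in total.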
limbs: [F; 4], + crumbs_high: [F; 1], +} + +impl RangeCheckLimbs2 { + ///extracts the limbs needed for range check from the bits of f + fn parse(f: F) -> Self { + let mut bits = f.into_bigint().to_bits_le().into_iter(); + let crumbs_low = parse_limbs::(bits.by_ref()); + let limbs = parse_limbs::(bits.by_ref()); + let crumbs_high = parse_limbs::(bits); + Self { + crumbs_low, + limbs, + crumbs_high, + } + } + ///produces limbs and crumbs with the expected endianess + fn into_repr(mut self) -> ([F; 1], [F; 4], [F; 19]) { + self.crumbs_high.reverse(); + self.limbs.reverse(); + self.crumbs_low.reverse(); + let Self { + crumbs_low, + limbs, + crumbs_high, + } = self; + (crumbs_high, limbs, crumbs_low) + } +} + +pub fn range_check( + runner: &mut RunState, + loc: Cow<'static, str>, + v0: FieldVar, + v1: FieldVar, + v2: FieldVar, +) -> SnarkyResult<()> { + let v0_limbs: ([FieldVar; 6], [FieldVar; 8]) = runner.compute(loc.clone(), |w| { + let v = w.read_var(&v0); + let limbs = RangeCheckLimbs1::parse(v); + limbs.into_repr() + })?; + let v0p0 = v0_limbs.0[0].clone(); + let v0p1 = v0_limbs.0[1].clone(); + let r0 = [v0] + .into_iter() + .chain(v0_limbs.0) + .chain(v0_limbs.1) + .collect_vec(); + let r0: [FieldVar; COLUMNS] = r0.try_into().unwrap(); + + let v1_limbs: ([FieldVar; 6], [FieldVar; 8]) = runner.compute(loc.clone(), |w| { + let v = w.read_var(&v1); + let limbs = RangeCheckLimbs1::parse(v); + limbs.into_repr() + })?; + let v1p0 = v1_limbs.0[0].clone(); + let v1p1 = v1_limbs.0[1].clone(); + let r1 = [v1] + .into_iter() + .chain(v1_limbs.0) + .chain(v1_limbs.1) + .collect_vec(); + + type Limbs = [FieldVar; N]; + let r1: [FieldVar; COLUMNS] = r1.try_into().unwrap(); + let v2_limbs: ((Limbs, Limbs), Limbs) = + runner.compute(loc.clone(), |w| { + let v = w.read_var(&v2); + let limbs = RangeCheckLimbs2::parse(v); + let (a, b, c) = limbs.into_repr(); + ((a, b), c) + })?; + let ((v2_crumb_high, v2_limb), v2_crumb_low) = v2_limbs; + let mut v2_crumb_high = v2_crumb_high.into_iter(); + let mut v2_crumb_low = v2_crumb_low.into_iter(); + + let r2 = [v2] + .into_iter() + .chain([FieldVar::zero()]) + .chain(v2_crumb_high.next()) + .chain(v2_limb) + .chain(v2_crumb_low.by_ref().take(8)) + .collect_vec(); + + let r2: [FieldVar; COLUMNS] = r2.try_into().unwrap(); + + let first_3 = v2_crumb_low.by_ref().take(3).collect_vec(); + let r3 = first_3 + .into_iter() + .chain([v0p0, v0p1, v1p0, v1p1]) + .chain(v2_crumb_low.take(8)) + .collect_vec(); + let r3: [FieldVar; COLUMNS] = r3.try_into().unwrap(); + + let rows = [r0, r1, r2, r3].map(|r| r.to_vec()).to_vec(); + + let constraint = Constraint::KimchiConstraint(KimchiConstraint::RangeCheck(rows)); + runner.add_constraint(constraint, Some("Range check".into()), loc)?; + Ok(()) +} + +#[cfg(test)] +mod test { + use crate::{ + circuits::expr::constraints::ExprOps, loc, snarky::api::SnarkyCircuit, FieldVar, RunState, + SnarkyResult, + }; + use mina_curves::pasta::{Fp, Vesta, VestaParameters}; + use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi, + sponge::{DefaultFqSponge, DefaultFrSponge}, + }; + use poly_commitment::ipa::OpeningProof; + + type BaseSponge = DefaultFqSponge; + type ScalarSponge = DefaultFrSponge; + struct TestCircuit {} + + impl SnarkyCircuit for TestCircuit { + type Curve = Vesta; + type Proof = OpeningProof; + + type PrivateInput = Fp; + type PublicInput = (); + type PublicOutput = (); + + fn circuit( + &self, + sys: &mut RunState, + _public: Self::PublicInput, + private: Option<&Self::PrivateInput>, + ) -> SnarkyResult { + let v: FieldVar = 
sys.compute(loc!(), |_| *private.unwrap())?; + + sys.range_check(loc!(), v.clone(), v.clone(), v)?; + + Ok(()) + } + } + + #[test] + fn snarky_range_check() { + // compile + let test_circuit = TestCircuit {}; + + let (mut prover_index, verifier_index) = test_circuit.compile_to_indexes().unwrap(); + + let mut rng = o1_utils::tests::make_test_rng(None); + + use rand::Rng; + // prove + { + let private_input: Fp = Fp::from(rng.gen::()); + + let debug = true; + let (proof, _public_output) = prover_index + .prove::((), private_input, debug) + .unwrap(); + + // verify proof + verifier_index.verify::(proof, (), ()); + } + + // prove a different execution + { + let private_input = Fp::from(2).pow(88) - Fp::from(1); + let debug = true; + let (proof, _public_output) = prover_index + .prove::((), private_input, debug) + .unwrap(); + + // verify proof + verifier_index.verify::(proof, (), ()); + } + } + #[test] + #[should_panic] + fn snarky_range_check_fail() { + // compile + let test_circuit = TestCircuit {}; + + let (mut prover_index, _) = test_circuit.compile_to_indexes().unwrap(); + + // print ASM + println!("{}", prover_index.asm()); + + // prove a bad execution + { + let private_input = Fp::from(0) - Fp::from(1); + let debug = true; + let (_proof, _public_output) = prover_index + .prove::((), private_input, debug) + .unwrap(); + } + } +} diff --git a/kimchi/src/snarky/runner.rs b/kimchi/src/snarky/runner.rs new file mode 100644 index 0000000000..a7e8951ac5 --- /dev/null +++ b/kimchi/src/snarky/runner.rs @@ -0,0 +1,678 @@ +//! The circuit-generation and witness-generation logic. + +use std::borrow::Cow; + +use super::{ + api::Witness, + constants::Constants, + errors::{ + RealSnarkyError, SnarkyCompilationError, SnarkyError, SnarkyResult, SnarkyRuntimeResult, + }, + poseidon::poseidon, + range_checks::range_check, +}; +use crate::{ + circuits::gate::CircuitGate, + curve::KimchiCurve, + snarky::{ + boolean::Boolean, + constraint_system::{BasicSnarkyConstraint, KimchiConstraint, SnarkyConstraintSystem}, + cvar::FieldVar, + errors::SnarkyRuntimeError, + snarky_type::SnarkyType, + }, +}; +use ark_ff::PrimeField; + +impl Constraint +where + F: PrimeField, +{ + /// In witness generation, this checks if the constraint is satisfied by some witness values. + pub fn check_constraint(&self, env: &impl WitnessGeneration) -> SnarkyRuntimeResult<()> { + match self { + Constraint::BasicSnarkyConstraint(c) => c.check_constraint(env), + Constraint::KimchiConstraint(c) => c.check_constraint(env), + } + } +} + +/// An enum that wraps either a [`BasicSnarkyConstraint`] or a \[`KimchiConstraintSystem`\]. +// TODO: we should get rid of this once basic constraint system is gone +#[derive(Debug)] +pub enum Constraint { + /// Old R1CS-like constraints. + BasicSnarkyConstraint(BasicSnarkyConstraint>), + + /// Custom gates in kimchi. + KimchiConstraint(KimchiConstraint, F>), +} + +/// The state used when compiling a circuit in snarky, or used in witness generation as well. +#[derive(Debug)] +pub struct RunState +where + F: PrimeField, +{ + /// The constraint system used to build the circuit. + /// If not set, the constraint system is not built. + pub system: Option>, + + /// The public input of the circuit used in witness generation. + // TODO: can we merge public_input and private_input? + public_input: Vec, + + // TODO: we could also just store `usize` here + pub(crate) public_output: Vec>, + + /// The private input of the circuit used in witness generation. Still not sure what that is, or why we care about this. 
+ private_input: Vec, + + /// If set, the witness generation will check if the constraints are satisfied. + /// This is useful to simulate running the circuit and return an error if an assertion fails. + pub eval_constraints: bool, + + /// The size of the public input part. This contains the public output as well. + // TODO: maybe remove the public output part here? This will affect OCaml-side though. + pub num_public_inputs: usize, + + /// A counter used to track variables (this includes public inputs) as they're being created. + pub next_var: usize, + + /// Indication that we're running the witness generation. + /// This does not necessarily mean that constraints are not created, + /// as we can do both at the same time. + // TODO: perhaps we should try to make the distinction between witness/constraint generation clearer + pub has_witness: bool, + + /// Indication that we're running in prover mode. + /// In this mode, we do not want to create constraints. + // TODO: I think we should be able to safely remove this as we don't use this in Rust. Check with snarkyJS if they need this here though. + pub as_prover: bool, + + /// A stack of labels, to get better errors. + labels_stack: Vec>, + + /// This does not count exactly the number of constraints, + /// but rather the number of times we call [RunState::add_constraint]. + constraints_counter: usize, + + /// A map from a constraint index to a source location + /// (usually a file name and line number). + constraints_locations: Vec>, +} + +// +// witness generation +// + +/// A witness generation environment. +/// This is passed to any closure in [RunState::compute] so that they can access the witness generation environment. +pub trait WitnessGeneration +where + F: PrimeField, +{ + /// Allows the caller to obtain the value behind a circuit variable. + fn read_var(&self, var: &FieldVar) -> F; + + fn constraints_counter(&self) -> usize; +} + +impl> WitnessGeneration for &G { + fn read_var(&self, var: &FieldVar) -> F { + G::read_var(*self, var) + } + + fn constraints_counter(&self) -> usize { + G::constraints_counter(*self) + } +} + +impl WitnessGeneration for &dyn WitnessGeneration { + fn read_var(&self, var: &FieldVar) -> F { + (**self).read_var(var) + } + + fn constraints_counter(&self) -> usize { + (**self).constraints_counter() + } +} + +impl WitnessGeneration for RunState +where + F: PrimeField, +{ + fn read_var(&self, var: &FieldVar) -> F { + var.eval(self) + } + + fn constraints_counter(&self) -> usize { + self.constraints_counter + } +} + +// +// circuit generation +// + +impl RunState +where + F: PrimeField, +{ + /// Creates a new [`Self`] based on the size of the public input, + /// and the size of the public output. + /// If `with_system` is set it will create a [SnarkyConstraintSystem] in + /// order to compile a new circuit. 
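+    ///
+    /// A minimal usage sketch (hypothetical sizes; `Vesta` as the curve,
+    /// `Fp` as its scalar field):
+    ///
+    /// ```ignore
+    /// // one public input, no public output, and build a constraint system
+    /// let mut state = RunState::<Fp>::new::<Vesta>(1, 0, true);
+    /// ```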
+    pub fn new<Curve: KimchiCurve<ScalarField = F>>(
+        public_input_size: usize,
+        public_output_size: usize,
+        with_system: bool,
+    ) -> Self {
+        // init
+        let num_public_inputs = public_input_size + public_output_size;
+
+        // create the CS
+        let constants = Constants::new::<Curve>();
+        let system = if with_system {
+            let mut system = SnarkyConstraintSystem::create(constants);
+            system.set_primary_input_size(num_public_inputs);
+            Some(system)
+        } else {
+            None
+        };
+
+        // create the runner
+        let mut sys = Self {
+            system,
+            public_input: Vec::with_capacity(num_public_inputs),
+            public_output: Vec::with_capacity(public_output_size),
+            private_input: vec![],
+            eval_constraints: true,
+            num_public_inputs,
+            next_var: 0,
+            has_witness: false,
+            as_prover: false,
+            labels_stack: vec![],
+            constraints_counter: 0,
+            constraints_locations: vec![],
+        };
+
+        // allocate the public inputs
+        for _ in 0..public_input_size {
+            sys.alloc_var();
+        }
+
+        // allocate the public output and store it
+        for _ in 0..public_output_size {
+            let cvar = sys.alloc_var();
+            sys.public_output.push(cvar);
+        }
+
+        //
+        sys
+    }
+
+    /// Used internally to evaluate variables.
+    /// Can panic if used with a wrong index.
+    pub fn read_var_idx(&self, idx: usize) -> F {
+        if idx < self.num_public_inputs {
+            self.public_input[idx]
+        } else {
+            self.private_input[idx - self.num_public_inputs]
+        }
+    }
+
+    /// Returns the public input snarky variable.
+    // TODO: perhaps this should be renamed `compile_circuit` and encapsulate more logic (since this is only used to compile a given circuit)
+    pub fn public_input<T: SnarkyType<F>>(&self) -> T {
+        assert_eq!(
+            T::SIZE_IN_FIELD_ELEMENTS,
+            self.num_public_inputs - self.public_output.len()
+        );
+
+        let mut cvars = Vec::with_capacity(T::SIZE_IN_FIELD_ELEMENTS);
+        for i in 0..T::SIZE_IN_FIELD_ELEMENTS {
+            cvars.push(FieldVar::Var(i));
+        }
+        let aux = T::constraint_system_auxiliary();
+        T::from_cvars_unsafe(cvars, aux)
+    }
+
+    /// Allocates a new var representing a private input.
+    pub fn alloc_var(&mut self) -> FieldVar<F> {
+        let v = self.next_var;
+        self.next_var += 1;
+        FieldVar::Var(v)
+    }
+
+    /// Stores a field element as an unconstrained private input.
+    pub fn store_field_elt(&mut self, x: F) -> FieldVar<F> {
+        let v = self.next_var;
+        self.next_var += 1;
+        self.private_input.push(x);
+        FieldVar::Var(v)
+    }
+
+    /// Creates a new non-deterministic variable associated to a value type ([SnarkyType]),
+    /// and a closure that can compute it when in witness generation mode.
+    pub fn compute<T, FUNC>(
+        &mut self,
+        loc: Cow<'static, str>,
+        to_compute_value: FUNC,
+    ) -> SnarkyResult<T>
+    where
+        T: SnarkyType<F>,
+        FUNC: FnOnce(&dyn WitnessGeneration<F>) -> T::OutOfCircuit,
+    {
+        self.compute_inner(true, loc, to_compute_value)
+    }
+
+    /// Same as [Self::compute] except that it does not attempt to constrain the value it computes.
+    /// This is to be used internally only, when we know that the value cannot be malformed.
+    pub(crate) fn compute_unsafe<T, FUNC>(
+        &mut self,
+        loc: Cow<'static, str>,
+        to_compute_value: FUNC,
+    ) -> SnarkyResult<T>
+    where
+        T: SnarkyType<F>,
+        FUNC: Fn(&dyn WitnessGeneration<F>) -> T::OutOfCircuit,
+    {
+        self.compute_inner(false, loc, to_compute_value)
+    }
+
+    /// The logic called by both [Self::compute] and [Self::compute_unsafe].
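+    ///
+    /// In witness generation mode it runs the closure and stores the resulting
+    /// field elements as private inputs; in constraint generation mode it only
+    /// allocates fresh variables. In both modes, `checked` decides whether the
+    /// snarky type's [SnarkyType::check] constraints are added.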
+ fn compute_inner( + &mut self, + checked: bool, + loc: Cow<'static, str>, + to_compute_value: FUNC, + ) -> SnarkyResult + where + T: SnarkyType, + FUNC: FnOnce(&dyn WitnessGeneration) -> T::OutOfCircuit, + { + // we're in witness generation mode + if self.has_witness { + // compute the value by running the closure + let value: T::OutOfCircuit = to_compute_value(self); + + // convert the value into field elements + let (fields, aux) = T::value_to_field_elements(&value); + let mut field_vars = vec![]; + + // convert each field element into a circuit var + for field in fields { + let v = self.store_field_elt(field); + field_vars.push(v); + } + + // parse them as a snarky type + let snarky_type = T::from_cvars_unsafe(field_vars, aux); + + // constrain the conversion + if checked { + snarky_type.check(self, loc)?; + } + + // return the snarky type + Ok(snarky_type) + } + /* we're in constraint generation mode */ + else { + // create enough variables to store the given type + let mut cvars = vec![]; + for _ in 0..T::SIZE_IN_FIELD_ELEMENTS { + // TODO: rename to alloc_cvar + let v = self.alloc_var(); + cvars.push(v); + } + + // parse them as a snarky type + let aux = T::constraint_system_auxiliary(); + let snarky_type = T::from_cvars_unsafe(cvars, aux); + + // constrain the created circuit variables + if checked { + snarky_type.check(self, loc)?; + } + + // return the snarky type + Ok(snarky_type) + } + } + + // TODO: get rid of this. + /// Creates a constraint for `assert_eq!(a * b, c)`. + pub fn assert_r1cs( + &mut self, + label: Option>, + loc: Cow<'static, str>, + a: FieldVar, + b: FieldVar, + c: FieldVar, + ) -> SnarkyResult<()> { + let constraint = BasicSnarkyConstraint::R1CS(a, b, c); + self.add_constraint(Constraint::BasicSnarkyConstraint(constraint), label, loc) + } + + // TODO: get rid of this + /// Creates a constraint for `assert_eq!(x, y)`; + pub fn assert_eq( + &mut self, + label: Option>, + loc: Cow<'static, str>, + x: FieldVar, + y: FieldVar, + ) -> SnarkyResult<()> { + let constraint = BasicSnarkyConstraint::Equal(x, y); + self.add_constraint(Constraint::BasicSnarkyConstraint(constraint), label, loc) + } + + /// Adds a list of [`Constraint`] to the circuit. + // TODO: clean up all these add constraints functions + // TODO: do I really need to pass a vec? + pub fn add_constraint( + &mut self, + constraint: Constraint, + label: Option>, + // TODO: we don't need to pass that through all the calls down the stack, we can just save it at this point (and the latest loc in the state is the one that threw) + loc: Cow<'static, str>, + ) -> SnarkyResult<()> { + self.with_label(label, |env| { + // increment the constraint counter + env.constraints_counter += 1; + + // TODO: + // [START_TODO] + // my understanding is that this should work with the OCaml side, + // as `generate_witness_conv` on the OCaml side will have an empty constraint_system at this point which means constraints can't be created (see next line) + // instead, I just ensure that when we're in witness generation we don't create constraints + // I don't think we ever do both at the same time on the OCaml side side anyway. 
+ // Note: if we want to address the TODO below, I think we should instead do this: + // have an enum: 1) compile 2) witness generation 3) both + // and have the both enum variant be used from an API that does both + // [END_TODO] + env.constraints_locations.push(loc.clone()); + + // We check the constraint + // TODO: this is checked at the front end level, perhaps we should check at the constraint system / backend level so that we can tell exactly what row is messed up? (for internal debugging that would really help) + if env.has_witness && env.eval_constraints { + constraint + .check_constraint(env) + .map_err(|e| env.runtime_error(*e))?; + } + + if !env.has_witness { + // TODO: we should have a mode "don't create constraints" instead of having an option here + let cs = match &mut env.system { + Some(cs) => cs, + None => return Ok(()), + }; + + match constraint { + Constraint::BasicSnarkyConstraint(c) => { + cs.add_basic_snarky_constraint(&env.labels_stack, &loc, c); + } + Constraint::KimchiConstraint(c) => { + cs.add_constraint(&env.labels_stack, &loc, c); + } + } + } + + Ok(()) + }) + } + + /// Adds a constraint that returns `then_` if `b` is `true`, `else_` otherwise. + /// Equivalent to `if b { then_ } else { else_ }`. + // TODO: move this out + pub fn if_( + &mut self, + loc: Cow<'static, str>, + b: Boolean, + then_: FieldVar, + else_: FieldVar, + ) -> SnarkyResult> { + // r = e + b (t - e) + // r - e = b (t - e) + let cvars = b.to_cvars().0; + let b = &cvars[0]; + if let FieldVar::Constant(b) = b { + if b.is_one() { + return Ok(then_); + } else { + return Ok(else_); + } + } + + match (&then_, &else_) { + (FieldVar::Constant(t), FieldVar::Constant(e)) => { + let t_times_b = b.scale(*t); + let one_minus_b = FieldVar::Constant(F::one()) - b; + Ok(t_times_b + &one_minus_b.scale(*e)) + } + _ => { + let b_clone = b.clone(); + let then_clone = then_.clone(); + let else_clone = else_.clone(); + let res: FieldVar = self.compute(loc.clone(), move |env| { + let b = env.read_var(&b_clone); + let res_var = if b == F::one() { + &then_clone + } else { + &else_clone + }; + let res: F = res_var.read(env); + res + })?; + let then_ = &then_ - &else_; + let else_ = &res - &else_; + // TODO: annotation? + self.assert_r1cs(Some("if_".into()), loc, b.clone(), then_, else_)?; + + Ok(res) + } + } + } + + /// Wires the given snarky variable to the public output part of the public input. + pub(crate) fn wire_public_output( + &mut self, + return_var: impl SnarkyType, + ) -> SnarkyResult<()> { + // obtain cvars for the returned vars + let (return_cvars, _aux) = return_var.to_cvars(); + + // obtain the vars involved in the public output part of the public input + let public_output_cvars = self.public_output.clone(); + if return_cvars.len() != public_output_cvars.len() { + return Err(self.runtime_error(SnarkyRuntimeError::CircuitReturnVar( + return_cvars.len(), + public_output_cvars.len(), + ))); + } + + // wire these to the public output part of the public input + // note: this will reduce the cvars contained in the output vars + for (a, b) in return_cvars + .into_iter() + .zip(public_output_cvars.into_iter()) + { + self.assert_eq( + Some("wiring public output".into()), + "this should never error".into(), + a, + b, + )?; + } + + Ok(()) + } + + /// Finalizes the public output using the actual variables returned by the circuit. 
+ pub(crate) fn wire_output_and_compile( + &mut self, + return_var: impl SnarkyType, + ) -> SnarkyResult<&[CircuitGate]> { + // wire output + self.wire_public_output(return_var)?; + + // compile + if let Some(cs) = &mut self.system { + Ok(cs.finalize_and_get_gates()) + } else { + // TODO: do we really want to panic here? + panic!("woot"); + } + } + + /// Getter for the OCaml side. + #[cfg(feature = "ocaml_types")] + pub fn get_private_inputs(&self) -> Vec { + self.private_input.clone() + } + + /// This adds a label in the stack of labels. + /// Every error from now one will contain this label, + /// until the label is popped (via [Self::pop_label]). + pub fn add_label(&mut self, label: Cow<'static, str>) { + self.labels_stack.push(label); + } + + /// This removes a label from any error that could come up from now on. + /// Normally used shortly after [Self::add_label]. + pub fn pop_label(&mut self) { + self.labels_stack.pop(); + } + + /// A wrapper around code that needs to be labeled + /// (for better errors). + pub fn with_label(&mut self, label: Option>, closure: FUNC) -> T + where + FUNC: FnOnce(&mut Self) -> T, + { + let need_to_pop = label.is_some(); + + if let Some(label) = label { + self.add_label(label); + } + + let res = closure(self); + + if need_to_pop { + self.pop_label(); + } + + res + } + + /// Creates an [RealSnarkyError] using the current context. + pub fn error(&self, error: SnarkyError) -> RealSnarkyError { + let loc = if self.constraints_counter == 0 { + "error during initialization".into() + } else { + self.constraints_locations[self.constraints_counter - 1].clone() + }; + RealSnarkyError::new_with_ctx(error, loc, self.labels_stack.clone()) + } + + /// Creates a runtime error. + pub fn runtime_error(&self, error: SnarkyRuntimeError) -> Box { + Box::new(self.error(SnarkyError::RuntimeError(error))) + } + + /// Crates a compilation error. + pub fn compilation_error(&self, error: SnarkyCompilationError) -> Box { + Box::new(self.error(SnarkyError::CompilationError(error))) + } + + pub fn generate_witness_init(&mut self, mut public_input: Vec) -> SnarkyResult<()> { + // check that the given public_input is of the correct length + // (not including the public output) + let obtained = public_input.len(); + let expected = self.num_public_inputs - self.public_output.len(); + if expected != obtained { + return Err( + self.runtime_error(SnarkyRuntimeError::PubInputMismatch(obtained, expected)) + ); + } + + // pad with zeros for the public output part + public_input.extend(std::iter::repeat(F::zero()).take(self.public_output.len())); + + // re-initialize `next_var` (which will grow every time we compile or generate a witness) + self.next_var = self.num_public_inputs; + + // set the mode to "witness generation" + self.has_witness = true; + + // set the public inputs + self.public_input = public_input; + + // reset the private inputs + self.private_input = Vec::with_capacity(self.private_input.len()); + + // reset the constraint counter for better debugging + self.constraints_counter = 0; + + // reset the constraints' locations + // we have to do this to imitate what the OCaml side does + // (the OCaml side always starts with a fresh state) + self.constraints_locations = Vec::with_capacity(self.constraints_locations.len()); + + Ok(()) + } + + /// Returns the public output generated after running the circuit, + /// and the witness of the execution trace. + pub fn generate_witness(&mut self) -> Witness { + // TODO: asserting this is dumb.. 
what if there's no private input : D + assert!(!self.private_input.is_empty()); + + // TODO: do we really want to panic here? + let system = self.system.as_mut().unwrap(); + + let get_one = |var_idx| { + if var_idx < self.num_public_inputs { + self.public_input[var_idx] + } else { + self.private_input[var_idx - self.num_public_inputs] + } + }; + + // compute witness + // TODO: can we avoid passing a closure here? a reference to a Inputs struct would be better perhaps. + let witness = system.compute_witness(get_one); + + // clear state (TODO: find better solution) + self.public_input = vec![]; + self.next_var = self.num_public_inputs; + + // return public output and witness + Witness(witness) + } + + pub(crate) fn poseidon_params(&self) -> mina_poseidon::poseidon::ArithmeticSpongeParams { + // TODO: do we really want to panic here? + self.system.as_ref().map(|sys| sys.sponge_params()).unwrap() + } + + pub fn poseidon( + &mut self, + loc: Cow<'static, str>, + preimage: (FieldVar, FieldVar), + ) -> (FieldVar, FieldVar) { + poseidon(self, loc, preimage) + } + ///constrains the 3 provided values to fit in 88 bits + pub fn range_check( + &mut self, + loc: Cow<'static, str>, + v0: FieldVar, + v1: FieldVar, + v2: FieldVar, + ) -> SnarkyResult<()> { + range_check(self, loc, v0, v1, v2) + } +} diff --git a/kimchi/src/snarky/snarky_type.rs b/kimchi/src/snarky/snarky_type.rs new file mode 100644 index 0000000000..8468505dd5 --- /dev/null +++ b/kimchi/src/snarky/snarky_type.rs @@ -0,0 +1,249 @@ +//! The [SnarkyType] trait is a useful trait that allows us to define our own +//! snarky variables on top of [FieldVar]. +//! Without it, we'd be limited to using the [FieldVar] type directly, handling +//! and returning arrays of [FieldVar]s all the time. +//! So this is useful for the same reason that programming languages allow users +//! to define their own types, instead of relying only on `usize`. +//! +//! To define your own snarky type, implement the `SnarkyType` trait. +//! A good example is the `Boolean` type. + +use super::{ + cvar::FieldVar, + errors::SnarkyResult, + runner::{RunState, WitnessGeneration}, +}; + +use ark_ff::PrimeField; + +use std::{borrow::Cow, fmt::Debug}; + +/// A snarky type is a type that can be used in a circuit. +/// It references an equivalent "out-of-circuit" type that one can use outside of the circuit. +/// (For example, to construct private or public inputs, or a public output, to the circuit.) +pub trait SnarkyType: Debug + Sized +where + F: PrimeField, +{ + /// Some 'out-of-circuit' data, which is carried as part of Self. + /// This data isn't encoded as CVars in the circuit, since the data may be large (e.g. a sparse merkle tree), + /// or may only be used by witness computations / for debugging. + type Auxiliary; + + /// The equivalent "out-of-circuit" type. + /// For example, the [super::boolean::Boolean] snarky type has an out-of-circuit type of [bool]. + type OutOfCircuit; + + /// The number of field elements that this type takes. + const SIZE_IN_FIELD_ELEMENTS: usize; + + /// Returns the circuit variables (and auxiliary data) behind this type. + fn to_cvars(&self) -> (Vec>, Self::Auxiliary); + + /// Creates a new instance of this type from the given circuit variables (And some auxiliary data). + fn from_cvars_unsafe(cvars: Vec>, aux: Self::Auxiliary) -> Self; + + /// Checks that the circuit variables behind this type are valid. + /// For some definition of valid. 
+ /// For example, a Boolean snarky type would check that the field element representing it is either 0 or 1. + /// The function does this by adding constraints to your constraint system. + fn check(&self, cs: &mut RunState, loc: Cow<'static, str>) -> SnarkyResult<()>; + + /// The "default" value of [Self::Auxiliary]. + /// This is passed to [Self::from_cvars_unsafe] when we are not generating a witness, + /// since we have no candidate value to get the auxiliary data from. + /// Note that we use an explicit value here rather than Auxiliary: Default, + /// since the default value for the type may not match the default value we actually want to pass! + fn constraint_system_auxiliary() -> Self::Auxiliary; + + /// Converts an out-of-circuit value + fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec, Self::Auxiliary); + + fn value_of_field_elements(fields: Vec, aux: Self::Auxiliary) -> Self::OutOfCircuit; + + // + // new functions that might help us with generics? + // + + fn compute( + cs: &mut RunState, + loc: Cow<'static, str>, + to_compute_value: FUNC, + ) -> SnarkyResult + where + FUNC: Fn(&dyn WitnessGeneration) -> Self::OutOfCircuit, + { + cs.compute(loc, to_compute_value) + } + + fn read(&self, g: G) -> Self::OutOfCircuit + where + G: WitnessGeneration, + { + let (cvars, aux) = self.to_cvars(); + let values = cvars.iter().map(|cvar| g.read_var(cvar)).collect(); + Self::value_of_field_elements(values, aux) + } +} + +/// A trait to convert between a snarky type and its out-of-circuit equivalent. +// TODO: should we then remove these functions + OutOfCircuit from SnarkyType? (And have SnarkyType : CircuitAndValue) +pub trait CircuitAndValue: SnarkyType +where + F: PrimeField, +{ + fn to_value(fields: Vec, aux: Self::Auxiliary) -> Self::OutOfCircuit; + fn from_value(value: &Self::OutOfCircuit) -> (Vec, Self::Auxiliary); +} + +// +// Auto traits +// + +impl SnarkyType for () +where + F: PrimeField, +{ + type Auxiliary = (); + + type OutOfCircuit = (); + + const SIZE_IN_FIELD_ELEMENTS: usize = 0; + + fn to_cvars(&self) -> (Vec>, Self::Auxiliary) { + (vec![], ()) + } + + fn from_cvars_unsafe(_cvars: Vec>, _aux: Self::Auxiliary) -> Self {} + + fn check(&self, _cs: &mut RunState, _loc: Cow<'static, str>) -> SnarkyResult<()> { + Ok(()) + } + + fn constraint_system_auxiliary() -> Self::Auxiliary {} + + fn value_to_field_elements(_value: &Self::OutOfCircuit) -> (Vec, Self::Auxiliary) { + (vec![], ()) + } + + fn value_of_field_elements(_fields: Vec, _aux: Self::Auxiliary) -> Self::OutOfCircuit {} +} + +impl SnarkyType for (T1, T2) +where + F: PrimeField, + T1: SnarkyType, + T2: SnarkyType, +{ + type Auxiliary = (T1::Auxiliary, T2::Auxiliary); + + type OutOfCircuit = (T1::OutOfCircuit, T2::OutOfCircuit); + + const SIZE_IN_FIELD_ELEMENTS: usize = T1::SIZE_IN_FIELD_ELEMENTS + T2::SIZE_IN_FIELD_ELEMENTS; + + fn to_cvars(&self) -> (Vec>, Self::Auxiliary) { + let (mut cvars1, aux1) = self.0.to_cvars(); + let (cvars2, aux2) = self.1.to_cvars(); + cvars1.extend(cvars2); + (cvars1, (aux1, aux2)) + } + + fn from_cvars_unsafe(cvars: Vec>, aux: Self::Auxiliary) -> Self { + assert_eq!(cvars.len(), Self::SIZE_IN_FIELD_ELEMENTS); + let (cvars1, cvars2) = cvars.split_at(T1::SIZE_IN_FIELD_ELEMENTS); + let (aux1, aux2) = aux; + ( + T1::from_cvars_unsafe(cvars1.to_vec(), aux1), + T2::from_cvars_unsafe(cvars2.to_vec(), aux2), + ) + } + + fn check(&self, cs: &mut RunState, loc: Cow<'static, str>) -> SnarkyResult<()> { + self.0.check(cs, loc.clone())?; + self.1.check(cs, loc)?; + Ok(()) + } + + fn 
constraint_system_auxiliary() -> Self::Auxiliary { + ( + T1::constraint_system_auxiliary(), + T2::constraint_system_auxiliary(), + ) + } + + fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec, Self::Auxiliary) { + let (mut fields, aux1) = T1::value_to_field_elements(&value.0); + let (fields2, aux2) = T2::value_to_field_elements(&value.1); + fields.extend(fields2); + (fields, (aux1, aux2)) + } + + fn value_of_field_elements(fields: Vec, aux: Self::Auxiliary) -> Self::OutOfCircuit { + let (fields1, fields2) = fields.split_at(T1::SIZE_IN_FIELD_ELEMENTS); + + let out1 = T1::value_of_field_elements(fields1.to_vec(), aux.0); + let out2 = T2::value_of_field_elements(fields2.to_vec(), aux.1); + + (out1, out2) + } +} + +impl SnarkyType for [T; N] +where + F: PrimeField, + T: SnarkyType, +{ + // TODO: convert this to a `[T::Auxiliary; N]` + type Auxiliary = Vec; + + type OutOfCircuit = [T::OutOfCircuit; N]; + + const SIZE_IN_FIELD_ELEMENTS: usize = N * T::SIZE_IN_FIELD_ELEMENTS; + + fn to_cvars(&self) -> (Vec>, Self::Auxiliary) { + let (cvars, aux): (Vec>, Vec<_>) = self.iter().map(|t| t.to_cvars()).unzip(); + let cvars = cvars.concat(); + (cvars, aux) + } + + fn from_cvars_unsafe(cvars: Vec>, aux: Self::Auxiliary) -> Self { + let mut cvars_and_aux = cvars.chunks(T::SIZE_IN_FIELD_ELEMENTS).zip(aux); + + std::array::from_fn(|_| { + let (cvars, aux) = cvars_and_aux.next().unwrap(); + assert_eq!(cvars.len(), T::SIZE_IN_FIELD_ELEMENTS); + T::from_cvars_unsafe(cvars.to_vec(), aux) + }) + } + + fn check(&self, cs: &mut RunState, loc: Cow<'static, str>) -> SnarkyResult<()> { + for t in self.iter() { + t.check(cs, loc.clone())?; + } + Ok(()) + } + + fn constraint_system_auxiliary() -> Self::Auxiliary { + let mut aux = Vec::with_capacity(T::SIZE_IN_FIELD_ELEMENTS); + for _ in 0..N { + aux.push(T::constraint_system_auxiliary()); + } + aux + } + + fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec, Self::Auxiliary) { + let (fields, aux): (Vec>, Vec<_>) = + value.iter().map(|v| T::value_to_field_elements(v)).unzip(); + (fields.concat(), aux) + } + + fn value_of_field_elements(fields: Vec, aux: Self::Auxiliary) -> Self::OutOfCircuit { + let mut values_and_aux = fields.chunks(T::SIZE_IN_FIELD_ELEMENTS).zip(aux); + + std::array::from_fn(|_| { + let (fields, aux) = values_and_aux.next().unwrap(); + assert_eq!(fields.len(), T::SIZE_IN_FIELD_ELEMENTS); + T::value_of_field_elements(fields.to_vec(), aux) + }) + } +} diff --git a/kimchi/src/snarky/tests.rs b/kimchi/src/snarky/tests.rs new file mode 100644 index 0000000000..9af28ab79c --- /dev/null +++ b/kimchi/src/snarky/tests.rs @@ -0,0 +1,143 @@ +use crate::{ + loc, + snarky::{ + api::SnarkyCircuit, + boolean::Boolean, + cvar::FieldVar, + errors::{SnarkyError, SnarkyRuntimeError}, + runner::RunState, + }, +}; +use ark_ff::One; +use mina_curves::pasta::{Fp, Vesta, VestaParameters}; +use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi, + sponge::{DefaultFqSponge, DefaultFrSponge}, +}; +use poly_commitment::ipa::OpeningProof; + +use super::prelude::*; + +type BaseSponge = DefaultFqSponge; +type ScalarSponge = DefaultFrSponge; + +struct TestCircuit {} + +struct Priv { + x: Fp, + y: Fp, + z: Fp, +} + +impl SnarkyCircuit for TestCircuit { + type Curve = Vesta; + type Proof = OpeningProof; + + type PrivateInput = Priv; + type PublicInput = Boolean; + type PublicOutput = (Boolean, FieldVar); + + fn circuit( + &self, + sys: &mut RunState, + public: Self::PublicInput, + private: Option<&Self::PrivateInput>, + ) -> SnarkyResult { + let x: 
FieldVar = sys.compute(loc!(), |_| private.unwrap().x)?; + let y: FieldVar = sys.compute(loc!(), |_| private.unwrap().y)?; + let z: FieldVar = sys.compute(loc!(), |_| private.unwrap().z)?; + + sys.assert_r1cs(Some("x * y = z".into()), loc!(), x, y, z)?; + + let other: Boolean = sys.compute(loc!(), |_| true)?; + + // res1 = public & other + let res1 = public.and(&other, sys, loc!()); + + // res2 = res1 + 3 + let three = FieldVar::constant(Fp::from(3)); + let res2 = res1.to_field_var() + three; + + Ok((res1, res2)) + } +} + +#[test] +fn test_simple_circuit() { + // compile + let test_circuit = TestCircuit {}; + + let (mut prover_index, verifier_index) = test_circuit.compile_to_indexes().unwrap(); + + // print ASM + println!("{}", prover_index.asm()); + + // prove + { + let private_input = Priv { + x: Fp::one(), + y: Fp::from(2), + z: Fp::from(2), + }; + + let public_input = true; + let debug = true; + let (proof, public_output) = prover_index + .prove::(public_input, private_input, debug) + .unwrap(); + + let expected_public_output = (true, Fp::from(4)); + assert_eq!(*public_output, expected_public_output); + + // verify proof + verifier_index.verify::(proof, public_input, *public_output); + } + + // prove a different execution + { + let private_input = Priv { + x: Fp::one(), + y: Fp::from(3), + z: Fp::from(3), + }; + let public_input = true; + let debug = true; + let (proof, public_output) = prover_index + .prove::(public_input, private_input, debug) + .unwrap(); + + let expected_public_output = (true, Fp::from(4)); + assert_eq!(*public_output, expected_public_output); + + // verify proof + verifier_index.verify::(proof, public_input, *public_output); + } + + // prove a bad execution + { + let private_input = Priv { + x: Fp::one(), + y: Fp::from(3), + z: Fp::from(2), + }; + let public_input = true; + let debug = true; + + let res = + prover_index.prove::(public_input, private_input, debug); + + match res.unwrap_err().source { + SnarkyError::RuntimeError(SnarkyRuntimeError::UnsatisfiedR1CSConstraint( + _, + x, + y, + z, + )) => { + assert_eq!(x, Fp::one().to_string()); + assert_eq!(y, Fp::from(3).to_string()); + assert_eq!(z, Fp::from(2).to_string()); + } + err => panic!("not the expected error: {err}"), + } + } +} diff --git a/kimchi/src/snarky/union_find.rs b/kimchi/src/snarky/union_find.rs new file mode 100644 index 0000000000..be1c7c754b --- /dev/null +++ b/kimchi/src/snarky/union_find.rs @@ -0,0 +1,102 @@ +use std::{collections::HashMap, hash::Hash}; + +#[derive(Debug, Clone, Default)] +/// Tarjan's union-find data structure +pub struct DisjointSet { + set_size: usize, + /// The structure saves the parent information of each subset in contiguous + /// memory (a `Vec`) for better performance. + parent: Vec, + + /// Each T entry is mapped onto a usize tag. + map: HashMap, +} + +impl DisjointSet +where + T: Clone + Hash + Eq, +{ + pub fn new() -> Self { + DisjointSet { + set_size: 0, + parent: Vec::new(), + map: HashMap::new(), + } + } + + pub fn make_set(&mut self, x: T) { + let len = &mut self.set_size; + if self.map.get(&x).is_some() { + return; + } + + self.map.insert(x, *len); + self.parent.push(*len); + + *len += 1; + } + + /// Returns `Some(tag)`, where `tag` identifies the subset containing `x`. + /// If x is not in the data structure, it returns None. 
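+    /// A hypothetical usage sketch (illustrative, not from the original code):
+    ///   let mut ds = DisjointSet::new();
+    ///   ds.make_set("a");
+    ///   ds.make_set("b");
+    ///   assert!(ds.find("a") != ds.find("b"));
+    ///   ds.union("a", "b");
+    ///   assert_eq!(ds.find("a"), ds.find("b"));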
+ pub fn find(&mut self, x: T) -> Option { + let pos = match self.map.get(&x) { + Some(p) => *p, + None => return None, + }; + + let ret = DisjointSet::::find_internal(&mut self.parent, pos); + Some(ret) + } + + fn find_internal(p: &mut Vec, n: usize) -> usize { + if p[n] != n { + let parent = p[n]; + p[n] = DisjointSet::::find_internal(p, parent); + p[n] + } else { + n + } + } + + /// Union the subsets to which x and y belong. + /// If it returns `Some`, it is the tag for unified subset. + /// If it returns `None`, at least one of x and y is not in the + /// disjoint-set. + pub fn union(&mut self, x: T, y: T) -> Option { + let (x_root, y_root) = match (self.find(x), self.find(y)) { + (Some(x), Some(y)) => (x, y), + _ => { + return None; + } + }; + + self.parent[x_root] = y_root; + Some(y_root) + } +} + +#[test] +fn it_works() { + let mut ds = DisjointSet::::new(); + ds.make_set(1); + ds.make_set(2); + ds.make_set(3); + + assert!(ds.find(1) != ds.find(2)); + assert!(ds.find(2) != ds.find(3)); + ds.union(1, 2).unwrap(); + ds.union(2, 3).unwrap(); + assert!(ds.find(1) == ds.find(3)); + + assert!(ds.find(4).is_none()); + ds.make_set(4); + assert!(ds.find(4).is_some()); + + ds.make_set(-1); + assert!(ds.find(-1) != ds.find(3)); + + ds.union(-1, 4).unwrap(); + ds.union(2, 4).unwrap(); + + assert!(ds.find(-1) == ds.find(3)); +} diff --git a/kimchi/src/tests/and.rs b/kimchi/src/tests/and.rs index c8a8c680b2..c9e751d841 100644 --- a/kimchi/src/tests/and.rs +++ b/kimchi/src/tests/and.rs @@ -78,7 +78,7 @@ where let gates = create_test_gates_and::(bytes); let cs = ConstraintSystem::create(gates).build().unwrap(); - // Initalize inputs + // Initialize inputs let input1 = rng.gen(input1, Some(bytes * 8)); let input2 = rng.gen(input2, Some(bytes * 8)); diff --git a/kimchi/src/tests/chunked.rs b/kimchi/src/tests/chunked.rs index f15d8eb558..2ddf08d90b 100644 --- a/kimchi/src/tests/chunked.rs +++ b/kimchi/src/tests/chunked.rs @@ -104,7 +104,7 @@ fn test_2_to_18_chunked_generic_gate_pub() { }*/ #[test] -fn test_2_to_17_chunked_generic_gate_pub() { +fn heavy_test_2_to_17_chunked_generic_gate_pub() { test_generic_gate_with_srs_override(17, Some(1 << 16)) } diff --git a/kimchi/src/tests/ec.rs b/kimchi/src/tests/ec.rs index d00938de2c..9cac8177a8 100644 --- a/kimchi/src/tests/ec.rs +++ b/kimchi/src/tests/ec.rs @@ -118,7 +118,7 @@ fn ec_test() { } for &p in ps.iter().take(num_infs) { - let q = -p; + let q: Other = -p; let p2: Other = (p + p).into(); let (x1, y1) = (p.x, p.y); diff --git a/kimchi/src/tests/endomul.rs b/kimchi/src/tests/endomul.rs index ebbdebacc8..88bc27c1e2 100644 --- a/kimchi/src/tests/endomul.rs +++ b/kimchi/src/tests/endomul.rs @@ -13,7 +13,7 @@ use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, sponge::{DefaultFqSponge, DefaultFrSponge, ScalarChallenge}, }; -use poly_commitment::srs::endos; +use poly_commitment::ipa::endos; use std::{array, ops::Mul}; type SpongeParams = PlonkSpongeConstantsKimchi; @@ -71,6 +71,7 @@ fn endomul_test() { // let g = Other::generator().into_group(); let acc0 = { let t = Other::new_unchecked(endo_q * base.x, base.y); + // Ensuring we use affine coordinates let p = t + base; let acc: Other = (p + p).into(); (acc.x, acc.y) diff --git a/kimchi/src/tests/endomul_scalar.rs b/kimchi/src/tests/endomul_scalar.rs index 60f7c8f895..228d54e8ba 100644 --- a/kimchi/src/tests/endomul_scalar.rs +++ b/kimchi/src/tests/endomul_scalar.rs @@ -12,7 +12,7 @@ use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, sponge::{DefaultFqSponge, DefaultFrSponge, 
ScalarChallenge}, }; -use poly_commitment::srs::endos; +use poly_commitment::ipa::endos; use std::array; type SpongeParams = PlonkSpongeConstantsKimchi; diff --git a/kimchi/src/tests/foreign_field_add.rs b/kimchi/src/tests/foreign_field_add.rs index 14e61e58fc..55b6bcebbd 100644 --- a/kimchi/src/tests/foreign_field_add.rs +++ b/kimchi/src/tests/foreign_field_add.rs @@ -6,6 +6,9 @@ use crate::{ polynomial::COLUMNS, polynomials::{ foreign_field_add::witness::{self, FFOps}, + foreign_field_common::{ + BigUintForeignFieldHelpers, HI, LIMB_BITS, LO, MI, TWO_TO_LIMB, + }, generic::GenericGateSpec, range_check::{self, witness::extend_multi}, }, @@ -24,17 +27,14 @@ use mina_poseidon::{ }; use num_bigint::{BigUint, RandBigInt}; use num_traits::FromPrimitive; -use o1_utils::{ - foreign_field::{BigUintForeignFieldHelpers, ForeignElement, HI, LO, MI, TWO_TO_LIMB}, - tests::make_test_rng, - FieldHelpers, Two, -}; +use o1_utils::{foreign_field::ForeignElement, tests::make_test_rng, FieldHelpers, Two}; use poly_commitment::{ - evaluation_proof::OpeningProof, - srs::{endos, SRS}, + ipa::{endos, OpeningProof, SRS}, + SRS as _, }; use rand::{rngs::StdRng, Rng}; use std::{array, sync::Arc}; + type PallasField = ::BaseField; type VestaField = ::BaseField; @@ -364,7 +364,10 @@ fn test_ffadd( } // checks that the result cells of the witness are computed as expected -fn check_result(witness: [Vec; COLUMNS], result: Vec>) { +fn check_result( + witness: [Vec; COLUMNS], + result: Vec>, +) { for (i, res) in result.iter().enumerate() { assert_eq!(witness[0][i + 2], res[LO]); assert_eq!(witness[1][i + 2], res[MI]); @@ -462,7 +465,7 @@ fn test_zero_sum_native() { ); // Check result is the native modulus - let native_limbs = ForeignElement::::from_biguint(native_modulus); + let native_limbs = ForeignElement::::from_biguint(native_modulus); check_result(witness, vec![native_limbs]); } @@ -492,7 +495,7 @@ fn test_max_number() { // compute result in the foreign field after taking care of the exceeding bits let sum = secp256k1_max() + secp256k1_max(); let sum_mod = sum - secp256k1_modulus(); - let sum_mod_limbs = ForeignElement::::from_biguint(sum_mod); + let sum_mod_limbs = ForeignElement::::from_biguint(sum_mod); check_ovf(witness.clone(), PallasField::one()); check_result(witness, vec![sum_mod_limbs]); } @@ -504,10 +507,10 @@ fn test_max_number() { // and it is checked that in both cases the result is the same fn test_zero_minus_one() { // FIRST AS NEG - let right_be_neg = ForeignElement::::from_biguint(One::one()) + let right_be_neg = ForeignElement::::from_biguint(One::one()) .neg(&secp256k1_modulus()) .to_biguint(); - let right_for_neg: ForeignElement = + let right_for_neg: ForeignElement = ForeignElement::from_biguint(right_be_neg.clone()); let (witness_neg, _index) = test_ffadd( secp256k1_modulus(), @@ -531,7 +534,7 @@ fn test_zero_minus_one() { // test 1 - 1 + 1 where (-1) is in the foreign field // the first check is done with sub(1, 1) and then with add(neg(neg(1))) fn test_one_minus_one_plus_one() { - let neg_neg_one = ForeignElement::::from_biguint(One::one()) + let neg_neg_one = ForeignElement::::from_biguint(One::one()) .neg(&secp256k1_modulus()) .neg(&secp256k1_modulus()) .to_biguint(); @@ -558,11 +561,12 @@ fn test_one_minus_one_plus_one() { // then tested as 0 - 1 - 1 ) // TODO: tested as 0 - ( 1 + 1) -> put sign in front of left instead (perhaps in the future we want this) fn test_minus_minus() { - let neg_one_for = - ForeignElement::::from_biguint(One::one()).neg(&secp256k1_modulus()); - let neg_one = 
neg_one_for.to_biguint(); - let neg_two = ForeignElement::::from_biguint(BigUint::from_u32(2).unwrap()) + let neg_one_for = ForeignElement::::from_biguint(One::one()) .neg(&secp256k1_modulus()); + let neg_one = neg_one_for.to_biguint(); + let neg_two = + ForeignElement::::from_biguint(BigUint::from_u32(2).unwrap()) + .neg(&secp256k1_modulus()); let (witness_neg, _index) = test_ffadd( secp256k1_modulus(), vec![neg_one.clone(), neg_one], @@ -810,8 +814,8 @@ fn test_zero_sub_fmax() { &[FFOps::Sub], false, ); - let negated = - ForeignElement::::from_biguint(secp256k1_max()).neg(&secp256k1_modulus()); + let negated = ForeignElement::::from_biguint(secp256k1_max()) + .neg(&secp256k1_modulus()); check_result(witness, vec![negated]); } @@ -830,7 +834,7 @@ fn test_pasta_add_max_vesta() { false, ); let right = right_input % vesta_modulus; - let right_foreign = ForeignElement::::from_biguint(right); + let right_foreign = ForeignElement::::from_biguint(right); check_result(witness, vec![right_foreign]); } @@ -846,7 +850,7 @@ fn test_pasta_sub_max_vesta() { false, ); let neg_max_vesta = - ForeignElement::::from_biguint(right_input).neg(&vesta_modulus); + ForeignElement::::from_biguint(right_input).neg(&vesta_modulus); check_result(witness, vec![neg_max_vesta]); } @@ -862,7 +866,7 @@ fn test_pasta_add_max_pallas() { false, ); let right = right_input % vesta_modulus; - let foreign_right = ForeignElement::::from_biguint(right); + let foreign_right = ForeignElement::::from_biguint(right); check_result(witness, vec![foreign_right]); } @@ -878,7 +882,7 @@ fn test_pasta_sub_max_pallas() { false, ); let neg_max_pallas = - ForeignElement::::from_biguint(right_input).neg(&vesta_modulus); + ForeignElement::::from_biguint(right_input).neg(&vesta_modulus); check_result(witness, vec![neg_max_pallas]); } @@ -897,8 +901,9 @@ fn test_random_add() { &[FFOps::Add], false, ); - let result = - ForeignElement::::from_biguint((left_big + right_big) % foreign_mod); + let result = ForeignElement::::from_biguint( + (left_big + right_big) % foreign_mod, + ); check_result(witness, vec![result]); } @@ -918,9 +923,11 @@ fn test_random_sub() { false, ); let result = if left_big < right_big { - ForeignElement::::from_biguint(left_big + foreign_mod - right_big) + ForeignElement::::from_biguint( + left_big + foreign_mod - right_big, + ) } else { - ForeignElement::::from_biguint(left_big - right_big) + ForeignElement::::from_biguint(left_big - right_big) }; check_result(witness, vec![result]); } @@ -943,12 +950,12 @@ fn test_foreign_is_native_add() { ); // check result was computed correctly let sum_big = compute_sum(pallas, &left_input, &right_input); - let result = ForeignElement::::from_biguint(sum_big.clone()); + let result = ForeignElement::::from_biguint(sum_big.clone()); check_result(witness, vec![result.clone()]); // check result is in the native field let two_to_limb = PallasField::from(TWO_TO_LIMB); - let left = ForeignElement::::from_be(&left_input); - let right = ForeignElement::::from_be(&right_input); + let left = ForeignElement::::from_be(&left_input); + let right = ForeignElement::::from_be(&right_input); let left = (left[HI] * two_to_limb + left[MI]) * two_to_limb + left[LO]; let right = (right[HI] * two_to_limb + right[MI]) * two_to_limb + right[LO]; let sum = left + right; @@ -976,12 +983,12 @@ fn test_foreign_is_native_sub() { ); // check result was computed correctly let dif_big = compute_dif(pallas, &left_input, &right_input); - let result = ForeignElement::::from_biguint(dif_big.clone()); + let result = 
ForeignElement::::from_biguint(dif_big.clone()); check_result(witness, vec![result.clone()]); // check result is in the native field let two_to_limb = PallasField::from(TWO_TO_LIMB); - let left = ForeignElement::::from_be(&left_input); - let right = ForeignElement::::from_be(&right_input); + let left = ForeignElement::::from_be(&left_input); + let right = ForeignElement::::from_be(&right_input); let left = (left[HI] * two_to_limb + left[MI]) * two_to_limb + left[LO]; let right = (right[HI] * two_to_limb + right[MI]) * two_to_limb + right[LO]; let dif = left - right; @@ -1012,7 +1019,9 @@ fn test_random_small_add() { let result = compute_sum(foreign_mod, &left_input, &right_input); check_result( witness, - vec![ForeignElement::::from_biguint(result)], + vec![ForeignElement::::from_biguint( + result, + )], ); } @@ -1038,7 +1047,9 @@ fn test_random_small_sub() { let result = compute_dif(foreign_mod, &left_input, &right_input); check_result( witness, - vec![ForeignElement::::from_biguint(result)], + vec![ForeignElement::::from_biguint( + result, + )], ); } @@ -1228,7 +1239,7 @@ fn test_random_chain() { let operations = (0..nops).map(|_| random_operation(rng)).collect::>(); let (witness, _index) = test_ffadd(secp256k1_modulus(), big_inputs, &operations, true); let mut left = vec![inputs[0].clone()]; - let results: Vec> = operations + let results: Vec> = operations .iter() .enumerate() .map(|(i, op)| { @@ -1237,7 +1248,7 @@ fn test_random_chain() { FFOps::Sub => compute_dif(foreign_mod.clone(), &left[i], &inputs[i + 1]), }; left.push(result.to_bytes_be()); - ForeignElement::::from_biguint(result) + ForeignElement::::from_biguint(result) }) .collect(); check_result(witness, results); diff --git a/kimchi/src/tests/foreign_field_mul.rs b/kimchi/src/tests/foreign_field_mul.rs index 4a293fc15c..aa30ecb911 100644 --- a/kimchi/src/tests/foreign_field_mul.rs +++ b/kimchi/src/tests/foreign_field_mul.rs @@ -3,7 +3,12 @@ use crate::{ constraints::ConstraintSystem, gate::{CircuitGate, CircuitGateError, CircuitGateResult, Connect, GateType}, polynomial::COLUMNS, - polynomials::foreign_field_mul, + polynomials::{ + foreign_field_common::{ + BigUintArrayCompose, BigUintForeignFieldHelpers, FieldArrayCompose, + }, + foreign_field_mul, + }, }, curve::KimchiCurve, plonk_sponge::FrSponge, @@ -12,19 +17,14 @@ use crate::{ use ark_ec::AffineRepr; use ark_ff::{Field, PrimeField, Zero}; use mina_curves::pasta::{Fp, Fq, Pallas, PallasParameters, Vesta, VestaParameters}; -use num_bigint::BigUint; -use num_traits::One; -use o1_utils::{ - foreign_field::{BigUintArrayCompose, BigUintForeignFieldHelpers, FieldArrayCompose}, - FieldHelpers, Two, -}; - use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, sponge::{DefaultFqSponge, DefaultFrSponge}, FqSponge, }; -use num_bigint::RandBigInt; +use num_bigint::{BigUint, RandBigInt}; +use num_traits::One; +use o1_utils::{FieldHelpers, Two}; type PallasField = ::BaseField; type VestaField = ::BaseField; diff --git a/kimchi/src/tests/framework.rs b/kimchi/src/tests/framework.rs index b49fdf2aad..eec339da2a 100644 --- a/kimchi/src/tests/framework.rs +++ b/kimchi/src/tests/framework.rs @@ -27,7 +27,7 @@ use groupmap::GroupMap; use mina_poseidon::sponge::FqSponge; use num_bigint::BigUint; use poly_commitment::{ - commitment::CommitmentCurve, evaluation_proof::OpeningProof as DlogOpeningProof, OpenProof, + commitment::CommitmentCurve, ipa::OpeningProof as DlogOpeningProof, OpenProof, }; use rand_core::{CryptoRng, RngCore}; use std::{fmt::Write, time::Instant}; diff --git 
a/kimchi/src/tests/generic.rs b/kimchi/src/tests/generic.rs index 4997439c9f..c3f1abd22f 100644 --- a/kimchi/src/tests/generic.rs +++ b/kimchi/src/tests/generic.rs @@ -9,6 +9,7 @@ use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, sponge::{DefaultFqSponge, DefaultFrSponge}, }; +use poly_commitment::SRS; use std::array; type SpongeParams = PlonkSpongeConstantsKimchi; @@ -91,20 +92,15 @@ fn test_generic_gate_pub_empty() { #[cfg(feature = "bn254")] #[test] -fn test_generic_gate_pairing() { +fn test_generic_gate_kzg() { type Fp = ark_bn254::Fr; type SpongeParams = PlonkSpongeConstantsKimchi; type BaseSponge = DefaultFqSponge; type ScalarSponge = DefaultFrSponge; - use ark_ff::UniformRand; - let public = vec![Fp::from(3u8); 5]; let gates = create_circuit(0, public.len()); - let rng = &mut rand::rngs::OsRng; - let x = Fp::rand(rng); - // create witness let mut witness: [Vec; COLUMNS] = array::from_fn(|_| vec![Fp::zero(); gates.len()]); fill_in_witness(0, &mut witness, &public); @@ -112,13 +108,13 @@ fn test_generic_gate_pairing() { // create and verify proof based on the witness >, + poly_commitment::kzg::KZGProof>, > as Default>::default() .gates(gates) .witness(witness) .public_inputs(public) - .setup_with_custom_srs(|d1, usize| { - let srs = poly_commitment::pairing_proof::PairingSRS::create(x, usize); + .setup_with_custom_srs(|d1, srs_size| { + let srs = poly_commitment::kzg::PairingSRS::create(srs_size); srs.full_srs.get_lagrange_basis(d1); srs }) diff --git a/kimchi/src/tests/keccak.rs b/kimchi/src/tests/keccak.rs index a29e13c40b..cb0f0e5387 100644 --- a/kimchi/src/tests/keccak.rs +++ b/kimchi/src/tests/keccak.rs @@ -1,54 +1,147 @@ use std::array; -use crate::circuits::{ - constraints::ConstraintSystem, - gate::CircuitGate, - polynomials::keccak::{self, ROT_TAB}, - wires::Wire, +use crate::{ + circuits::{ + constraints::ConstraintSystem, + gate::{CircuitGate, GateType}, + polynomials::keccak::{constants::KECCAK_COLS, witness::extend_keccak_witness, Keccak}, + wires::Wire, + }, + curve::KimchiCurve, }; -use ark_ec::AffineRepr; -use mina_curves::pasta::{Fp, Pallas, Vesta}; -use rand::Rng; +use ark_ff::{Field, PrimeField, Zero}; +use mina_curves::pasta::Pallas; +use num_bigint::BigUint; +use o1_utils::{BigUintHelpers, FieldHelpers}; -//use super::framework::TestFramework; -type PallasField = ::BaseField; - -fn create_test_constraint_system() -> ConstraintSystem { - let (mut next_row, mut gates) = { CircuitGate::::create_keccak(0) }; +fn create_test_constraint_system( + bytelength: usize, +) -> ConstraintSystem +where + G::BaseField: PrimeField, +{ + let mut gates = vec![]; + let next_row = CircuitGate::extend_keccak(&mut gates, bytelength); + // Adding dummy row to avoid out of bounds in squeeze constraints accessing Next row + gates.push(CircuitGate { + typ: GateType::Zero, + wires: Wire::for_row(next_row), + coeffs: vec![], + }); ConstraintSystem::create(gates).build().unwrap() } -#[test] -// Test that all of the offsets in the rotation table work fine -fn test_keccak_table() { - let cs = create_test_constraint_system(); - let state = array::from_fn(|_| { - array::from_fn(|_| rand::thread_rng().gen_range(0..2u128.pow(64)) as u64) - }); - let witness = keccak::create_witness_keccak_rot(state); - for row in 0..=48 { - assert_eq!( - cs.gates[row].verify_witness::( - row, - &witness, - &cs, - &witness[0][0..cs.public].to_vec() - ), - Ok(()) - ); +fn create_keccak_witness(message: BigUint) -> [Vec; KECCAK_COLS] +where + G::BaseField: PrimeField, +{ + let mut witness: [Vec; KECCAK_COLS] 
= + array::from_fn(|_| vec![G::ScalarField::zero(); 0]); + extend_keccak_witness(&mut witness, message); + // Adding dummy row to avoid out of bounds in squeeze constraints accessing Next row + let dummy_row: [Vec; KECCAK_COLS] = + array::from_fn(|_| vec![G::ScalarField::zero()]); + for col in 0..KECCAK_COLS { + witness[col].extend(dummy_row[col].iter()); + } + witness +} + +fn eprint_witness(witness: &[Vec; KECCAK_COLS], round: usize) { + fn to_u64(elem: F) -> u64 { + let mut bytes = FieldHelpers::::to_bytes(&elem); + bytes.reverse(); + bytes.iter().fold(0, |acc: u64, x| (acc << 8) + *x as u64) } - let mut rot = 0; - for (x, row) in ROT_TAB.iter().enumerate() { - for (y, &bits) in row.iter().enumerate() { - if bits == 0 { - continue; + fn eprint_matrix(state: &[u64]) { + for x in 0..5 { + eprint!(" "); + for y in 0..5 { + let quarters = &state[4 * (5 * y + x)..4 * (5 * y + x) + 4]; + let word = + Keccak::compose(&Keccak::collapse(&Keccak::reset(&Keccak::shift(quarters)))); + eprint!("{:016x} ", word); } - assert_eq!( - PallasField::from(state[x][y].rotate_left(bits)), - witness[1][1 + 2 * rot], - ); - rot += 1; + eprintln!(); } } + + let row = witness + .iter() + .map(|x| to_u64::(x[round])) + .collect::>(); + let next = witness + .iter() + .map(|x| to_u64::(x[round + 1])) + .collect::>(); + + eprintln!("----------------------------------------"); + eprintln!("ROUND {}", round); + eprintln!("State A:"); + eprint_matrix(&row[0..100]); + eprintln!("State G:"); + eprint_matrix(&next[0..100]); +} + +// Sets up test for a given message and desired input bytelength +fn setup_keccak_test(message: BigUint) -> BigUint +where + G::BaseField: PrimeField, +{ + let bytelength = message.to_bytes_be().len(); + let padded_len = { + let mut sized = message.to_bytes_be(); + sized.resize(bytelength - sized.len(), 0); + Keccak::pad(&sized).len() + }; + let _index = create_test_constraint_system::(padded_len); + let witness = create_keccak_witness::(message); + + for r in 1..=24 { + eprint_witness::(&witness, r); + } + + let mut hash = vec![]; + let hash_row = witness[0].len() - 2; // Hash row is dummy row + eprintln!(); + eprintln!("----------------------------------------"); + eprint!("Hash: "); + for b in 0..32 { + hash.push(FieldHelpers::to_bytes(&witness[200 + b][hash_row])[0]); + eprint!("{:02x}", hash[b]); + } + eprintln!(); + eprintln!(); + + BigUint::from_bytes_be(&hash) +} + +#[test] +// Tests a random block of 1080 bits +fn test_random_block() { + let expected_random = setup_keccak_test::( + BigUint::from_hex("832588523900cca2ea9b8c0395d295aa39f9a9285a982b71cc8475067a8175f38f235a2234abc982a2dfaaddff2895a28598021895206a733a22bccd21f124df1413858a8f9a1134df285a888b099a8c2235eecdf2345f3afd32f3ae323526689172850672938104892357aad32523523f423423a214325d13523aadb21414124aaadf32523126"), + ); + let hash_random = + BigUint::from_hex("845e9dd4e22b4917a80c5419a0ddb3eebf5f4f7cc6035d827314a18b718f751f"); + assert_eq!(expected_random, hash_random); +} + +#[test] +// Test hash of message zero with 1 byte input length +fn test_zero() { + let expected1 = setup_keccak_test::(BigUint::from_bytes_be(&[0x00])); + let hash1 = + BigUint::from_hex("bc36789e7a1e281436464229828f817d6612f7b477d66591ff96a9e064bcc98a"); + assert_eq!(expected1, hash1); +} + +#[test] +// Test hash of message using 3 blocks +fn test_blocks() { + let expected_3blocks = 
setup_keccak_test::(BigUint::from_hex("832588523900cca2ea9b8c0395d295aa39f9a9285a982b71cc8475067a8175f38f235a2234abc982a2dfaaddff2895a28598021895206a733a22bccd21f124df1413858a8f9a1134df285a888b099a8c2235eecdf2345f3afd32f3ae323526689172850672938104892357aad32523523f423423a214325d13523aadb21414124aaadf32523126832588523900cca2ea9b8c0395d295aa39f9a9285a982b71cc8475067a8175f38f235a2234abc982a2dfaaddff2895a28598021895206a733a22bccd21f124df1413858a8f9a1134df285a888b099a8c2235eecdf2345f3afd32f3ae323526689172850672938104892357aad32523523f423423a214325d13523aadb21414124aaadf32523126832588523900cca2ea9b8c0395d295aa39f9a9285a982b71cc8475067a8175f38f235a2234abc982a2dfaaddff2895a28598021895206a733a22bccd21f124df1413858a8f9a1134df285a888b099a8c2235eecdf2345f3afd32f3ae323526689172850672938104892357aad32523523f")); + let hash_3blocks = + BigUint::from_hex("7e369e1a4362148fca24c67c76f14dbe24b75c73e9b0efdb8c46056c8514287e"); + assert_eq!(expected_3blocks, hash_3blocks); } diff --git a/kimchi/src/tests/lookup.rs b/kimchi/src/tests/lookup.rs index 630445b5d7..9c4c443873 100644 --- a/kimchi/src/tests/lookup.rs +++ b/kimchi/src/tests/lookup.rs @@ -22,9 +22,13 @@ type BaseSponge = DefaultFqSponge; type ScalarSponge = DefaultFrSponge; fn setup_lookup_proof(use_values_from_table: bool, num_lookups: usize, table_sizes: Vec) { + let seed: [u8; 32] = thread_rng().gen(); + eprintln!("Seed: {:?}", seed); + let mut rng = StdRng::from_seed(seed); + let mut lookup_table_values: Vec> = table_sizes .iter() - .map(|size| (0..*size).map(|_| rand::random()).collect()) + .map(|size| (0..*size).map(|_| rng.gen()).collect()) .collect(); // Zero table must have a zero row lookup_table_values[0][0] = From::from(0); @@ -56,16 +60,16 @@ fn setup_lookup_proof(use_values_from_table: bool, num_lookups: usize, table_siz let num_tables = table_sizes.len(); let mut tables_used = std::collections::HashSet::new(); for _ in 0..num_lookups { - let table_id = rand::random::() % num_tables; + let table_id = rng.gen::() % num_tables; tables_used.insert(table_id); let lookup_table_values: &Vec = &lookup_table_values[table_id]; lookup_table_ids.push((table_id as u64).into()); for i in 0..3 { - let index = rand::random::() % lookup_table_values.len(); + let index = rng.gen::() % lookup_table_values.len(); let value = if use_values_from_table { lookup_table_values[index] } else { - rand::random() + rng.gen() }; lookup_indexes[i].push((index as u64).into()); lookup_values[i].push(value); @@ -128,7 +132,7 @@ fn lookup_gate_rejects_bad_lookups_multiple_tables() { setup_lookup_proof(false, 500, vec![100, 50, 50, 2, 2]) } -fn setup_successfull_runtime_table_test( +fn setup_successful_runtime_table_test( runtime_table_cfgs: Vec>, runtime_tables: Vec>, lookups: Vec, @@ -554,7 +558,7 @@ fn test_runtime_table_only_one_table_with_id_zero_with_non_zero_entries_fixed_va let lookups: Vec = [0; 20].into(); - setup_successfull_runtime_table_test(vec![cfg], vec![runtime_table], lookups); + setup_successful_runtime_table_test(vec![cfg], vec![runtime_table], lookups); } #[test] @@ -581,7 +585,7 @@ fn test_runtime_table_only_one_table_with_id_zero_with_non_zero_entries_random_v let lookups: Vec = [0; 20].into(); - setup_successfull_runtime_table_test(vec![cfg], vec![runtime_table], lookups); + setup_successful_runtime_table_test(vec![cfg], vec![runtime_table], lookups); } // This test verifies that if there is a table with ID 0, it contains a row with only zeroes. 
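(Aside) The lookup test changes above swap ambient `rand::random()` calls for a seeded `StdRng` whose seed is printed on each run, so a failing CI run can be replayed with the same inputs. A minimal, self-contained sketch of that pattern; the names and sizes here are illustrative and not taken from the diff:

```rust
use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};

fn main() {
    // Draw a fresh seed, but print it so that a failing run can be
    // reproduced later by hardcoding the printed bytes.
    let seed: [u8; 32] = thread_rng().gen();
    eprintln!("Seed: {:?}", seed);
    let mut rng = StdRng::from_seed(seed);

    // Route all test randomness through `rng` instead of `rand::random()`,
    // so the one printed seed determines every generated value.
    let table_size = 100usize;
    let index = rng.gen::<usize>() % table_size;
    assert!(index < table_size);
}
```

Printing to stderr keeps the seed visible in captured test output, which is exactly what the refactored `setup_lookup_proof` relies on.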
diff --git a/kimchi/src/tests/mod.rs b/kimchi/src/tests/mod.rs index ec35f9635f..27d21ff483 100644 --- a/kimchi/src/tests/mod.rs +++ b/kimchi/src/tests/mod.rs @@ -1,3 +1,4 @@ +// IMPROVEME: move all tests in top-level directory tests mod and; mod chunked; mod ec; @@ -7,6 +8,7 @@ mod foreign_field_add; mod foreign_field_mul; mod framework; mod generic; +mod keccak; mod lookup; mod not; mod poseidon; @@ -14,6 +16,5 @@ mod range_check; mod recursion; mod rot; mod serde; -mod turshi; mod varbasemul; mod xor; diff --git a/kimchi/src/tests/not.rs b/kimchi/src/tests/not.rs index 2f7e127b44..dcf03c7363 100644 --- a/kimchi/src/tests/not.rs +++ b/kimchi/src/tests/not.rs @@ -23,7 +23,7 @@ use mina_poseidon::{ }; use num_bigint::BigUint; use o1_utils::{BigUintHelpers, BitwiseOps, FieldHelpers, RandomField}; -use poly_commitment::evaluation_proof::OpeningProof; +use poly_commitment::ipa::OpeningProof; type PallasField = ::BaseField; type VestaField = ::BaseField; diff --git a/kimchi/src/tests/range_check.rs b/kimchi/src/tests/range_check.rs index d8d75bfb29..42d5929489 100644 --- a/kimchi/src/tests/range_check.rs +++ b/kimchi/src/tests/range_check.rs @@ -4,41 +4,36 @@ use crate::{ gate::{CircuitGate, CircuitGateError, GateType}, polynomial::COLUMNS, polynomials::{ + foreign_field_common::{ + BigUintArrayFieldHelpers, BigUintForeignFieldHelpers, FieldArrayCompact, + KimchiForeignElement, + }, generic::GenericGateSpec, range_check::{self}, }, wires::Wire, }, proof::ProverProof, - prover_index::testing::new_index_for_test_with_lookups, + prover_index::{testing::new_index_for_test_with_lookups, ProverIndex}, + verifier::verify, }; - use ark_ec::AffineRepr; use ark_ff::{Field, One, Zero}; use ark_poly::EvaluationDomain; -use mina_curves::pasta::{Fp, Pallas, Vesta, VestaParameters}; -use num_bigint::{BigUint, RandBigInt}; -use o1_utils::{ - foreign_field::{ - BigUintArrayFieldHelpers, BigUintForeignFieldHelpers, FieldArrayCompact, - ForeignFieldHelpers, - }, - FieldHelpers, -}; - -use std::{array, sync::Arc}; - -use crate::{prover_index::ProverIndex, verifier::verify}; use groupmap::GroupMap; +use mina_curves::pasta::{Fp, Pallas, Vesta, VestaParameters}; use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, sponge::{DefaultFqSponge, DefaultFrSponge}, }; +use num_bigint::{BigUint, RandBigInt}; +use o1_utils::{foreign_field::ForeignFieldHelpers, FieldHelpers}; use poly_commitment::{ commitment::CommitmentCurve, - evaluation_proof::OpeningProof, - srs::{endos, SRS}, + ipa::{endos, OpeningProof, SRS}, + SRS as _, }; +use std::{array, sync::Arc}; use super::framework::TestFramework; @@ -1034,7 +1029,7 @@ fn verify_range_check1_test_next_row_lookups() { witness[col - 1][row] -= PallasField::one(); witness[col - 1 + 2 * row + 2][3] -= PallasField::one(); } else { - witness[col - 1][row] += PallasField::two_to_limb(); + witness[col - 1][row] += KimchiForeignElement::::two_to_limb(); } witness[col - 1 + 2 * row + 3][3] += PallasField::from(2u64.pow(12)); diff --git a/kimchi/src/tests/rot.rs b/kimchi/src/tests/rot.rs index 662f6faf2a..17c70d8d31 100644 --- a/kimchi/src/tests/rot.rs +++ b/kimchi/src/tests/rot.rs @@ -8,6 +8,7 @@ use crate::{ polynomial::COLUMNS, polynomials::{ generic::GenericGateSpec, + keccak::{constants::DIM, OFF}, rot::{self, RotMode}, }, wires::Wire, @@ -27,8 +28,8 @@ use mina_poseidon::{ }; use o1_utils::Two; use poly_commitment::{ - evaluation_proof::OpeningProof, - srs::{endos, SRS}, + ipa::{endos, OpeningProof, SRS}, + SRS as _, }; use rand::Rng; @@ -167,22 +168,20 @@ fn test_rot_random() 
{ test_rot::(word, rot, RotMode::Right); } -#[should_panic] #[test] // Test that a bad rotation fails as expected fn test_zero_rot() { let rng = &mut o1_utils::tests::make_test_rng(None); let word = rng.gen_range(0..2u128.pow(64)) as u64; - create_rot_witness::(word, 0, RotMode::Left); + test_rot::(word, 0, RotMode::Left); } -#[should_panic] #[test] // Test that a bad rotation fails as expected fn test_large_rot() { let rng = &mut o1_utils::tests::make_test_rng(None); let word = rng.gen_range(0..2u128.pow(64)) as u64; - create_rot_witness::(word, 64, RotMode::Left); + test_rot::(word, 64, RotMode::Left); } #[test] @@ -244,11 +243,17 @@ fn test_bad_constraints() { // modify excess witness[2][1] += PallasField::one(); + witness[0][3] += PallasField::one(); assert_eq!( cs.gates[1].verify_witness::(1, &witness, &cs, &witness[0][0..cs.public]), Err(CircuitGateError::Constraint(GateType::Rot64, 9)) ); + assert_eq!( + cs.gates[3].verify_witness::(3, &witness, &cs, &witness[0][0..cs.public]), + Err(CircuitGateError::Constraint(GateType::RangeCheck0, 9)) + ); witness[2][1] -= PallasField::one(); + witness[0][3] -= PallasField::one(); // modify shifted witness[0][2] += PallasField::one(); @@ -260,6 +265,7 @@ fn test_bad_constraints() { cs.gates[2].verify_witness::(2, &witness, &cs, &witness[0][0..cs.public]), Err(CircuitGateError::Constraint(GateType::RangeCheck0, 9)) ); + witness[0][2] -= PallasField::one(); // modify value of shifted to be more than 64 bits witness[0][2] += PallasField::two_pow(64); @@ -278,6 +284,27 @@ fn test_bad_constraints() { dst: Wire { row: 0, col: 0 } }) ); + witness[2][2] -= PallasField::one(); + witness[0][2] -= PallasField::two_pow(64); + + // modify value of excess to be more than 64 bits + witness[0][3] += PallasField::two_pow(64); + witness[2][1] += PallasField::two_pow(64); + assert_eq!( + cs.gates[3].verify_witness::(3, &witness, &cs, &witness[0][0..cs.public]), + Err(CircuitGateError::Constraint(GateType::RangeCheck0, 9)) + ); + // Update decomposition + witness[2][3] += PallasField::one(); + // Make sure the 64-bit check fails + assert_eq!( + cs.gates[3].verify_witness::(3, &witness, &cs, &witness[0][0..cs.public]), + Err(CircuitGateError::CopyConstraint { + typ: GateType::RangeCheck0, + src: Wire { row: 3, col: 2 }, + dst: Wire { row: 2, col: 2 } + }) + ); } #[test] @@ -353,3 +380,63 @@ fn test_rot_finalization() { .prove_and_verify::() .unwrap(); } + +#[test] +// Test that all of the offsets in the rotation table work fine +fn test_keccak_table() { + let zero_row = 0; + let mut gates = vec![CircuitGate::::create_generic_gadget( + Wire::for_row(zero_row), + GenericGateSpec::Pub, + None, + )]; + let mut rot_row = zero_row + 1; + for col in OFF { + for rot in col { + // if rotation by 0 bits, no need to create a gate for it + if rot == 0 { + continue; + } + let mut rot64_gates = CircuitGate::create_rot64(rot_row, rot as u32); + rot_row += rot64_gates.len(); + // Append them to the full gates vector + gates.append(&mut rot64_gates); + // Check that 2 most significant limbs of shifted are zero + gates.connect_64bit(zero_row, rot_row - 1); + } + } + let cs = ConstraintSystem::create(gates).build().unwrap(); + + let state: [[u64; DIM]; DIM] = array::from_fn(|_| { + array::from_fn(|_| rand::thread_rng().gen_range(0..2u128.pow(64)) as u64) + }); + let mut witness: [Vec; COLUMNS] = array::from_fn(|_| vec![PallasField::zero()]); + for (y, col) in OFF.iter().enumerate() { + for (x, &rot) in col.iter().enumerate() { + if rot == 0 { + continue; + } + rot::extend_rot(&mut 
witness, state[x][y], rot as u32, RotMode::Left); + } + } + + for row in 0..=48 { + assert_eq!( + cs.gates[row].verify_witness::(row, &witness, &cs, &witness[0][0..cs.public]), + Ok(()) + ); + } + let mut rot = 0; + for (y, col) in OFF.iter().enumerate() { + for (x, &bits) in col.iter().enumerate() { + if bits == 0 { + continue; + } + assert_eq!( + PallasField::from(state[x][y].rotate_left(bits as u32)), + witness[1][1 + 3 * rot], + ); + rot += 1; + } + } +} diff --git a/kimchi/src/tests/serde.rs b/kimchi/src/tests/serde.rs index 7e9a8e4a49..e2e949d88a 100644 --- a/kimchi/src/tests/serde.rs +++ b/kimchi/src/tests/serde.rs @@ -17,7 +17,11 @@ use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, sponge::{DefaultFqSponge, DefaultFrSponge}, }; -use poly_commitment::{commitment::CommitmentCurve, evaluation_proof::OpeningProof, srs::SRS}; +use poly_commitment::{ + commitment::CommitmentCurve, + ipa::{OpeningProof, SRS}, + SRS as _, +}; use std::{array, time::Instant}; type SpongeParams = PlonkSpongeConstantsKimchi; diff --git a/kimchi/src/tests/xor.rs b/kimchi/src/tests/xor.rs index 8a5d8c7f01..cbf946b700 100644 --- a/kimchi/src/tests/xor.rs +++ b/kimchi/src/tests/xor.rs @@ -22,8 +22,8 @@ use mina_poseidon::{ use num_bigint::BigUint; use o1_utils::{BigUintHelpers, BitwiseOps, FieldHelpers, RandomField}; use poly_commitment::{ - evaluation_proof::OpeningProof, - srs::{endos, SRS}, + ipa::{endos, OpeningProof, SRS}, + SRS as _, }; use super::framework::TestFramework; diff --git a/kimchi/src/verifier.rs b/kimchi/src/verifier.rs index 6359180507..09369e56ab 100644 --- a/kimchi/src/verifier.rs +++ b/kimchi/src/verifier.rs @@ -3,8 +3,9 @@ use crate::{ circuits::{ argument::ArgumentType, + berkeley_columns::{BerkeleyChallenges, Column}, constraints::ConstraintSystem, - expr::{Column, Constants, PolishToken}, + expr::{Constants, PolishToken}, gate::GateType, lookup::{lookups::LookupPattern, tables::combine_table}, polynomials::permutation, @@ -86,6 +87,8 @@ impl<'a, G: KimchiCurve, OpeningProof: OpenProof> Context<'a, G, OpeningProof ForeignFieldMul => Some(self.verifier_index.foreign_field_mul_comm.as_ref()?), Xor16 => Some(self.verifier_index.xor_comm.as_ref()?), Rot64 => Some(self.verifier_index.rot_comm.as_ref()?), + KeccakRound => todo!(), + KeccakSponge => todo!(), } } } @@ -133,7 +136,9 @@ where let zk_rows = index.zk_rows; - //~ 1. Setup the Fq-Sponge. + //~ 1. Setup the Fq-Sponge. This sponge mostly absorbs group + // elements (points as tuples over the base field), but it + // squeezes out elements of the group's scalar field. let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params()); //~ 1. Absorb the digest of the VerifierIndex. @@ -235,11 +240,11 @@ where let alpha = alpha_chal.to_field(endo_r); //~ 1. Enforce that the length of the $t$ commitment is of size 7. - if self.commitments.t_comm.elems.len() > chunk_size * 7 { + if self.commitments.t_comm.len() > chunk_size * 7 { return Err(VerifyError::IncorrectCommitmentLength( "t", chunk_size * 7, - self.commitments.t_comm.elems.len(), + self.commitments.t_comm.len(), )); } @@ -253,7 +258,11 @@ where //~ 1. Derive $\zeta$ from $\zeta'$ using the endomorphism (TODO: specify). let zeta = zeta_chal.to_field(endo_r); - //~ 1. Setup the Fr-Sponge. + //~ 1. Setup the Fr-Sponge. This sponge absorbs elements from + // the scalar field of the curve (equal to the base field of + // the previous recursion round), and squeezes scalar elements + // of the field. The squeeze result is the same as with the + // `fq_sponge`. 
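+        // Illustrative aside (added commentary, not part of this diff): the
+        // Fq/Fr transcript handoff is the digest/absorb pair just below,
+        // conceptually
+        //   let digest = fq_sponge.clone().digest();
+        //   fr_sponge.absorb(&digest);
+        // which binds the Fr-sponge to everything the Fq-sponge absorbed.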
let digest = fq_sponge.clone().digest(); let mut fr_sponge = EFrSponge::new(G::sponge_params()); @@ -366,13 +375,13 @@ where fr_sponge.absorb_multiple(&public_evals[1]); fr_sponge.absorb_evaluations(&self.evals); - //~ 1. Sample $v'$ with the Fr-Sponge. + //~ 1. Sample the "polyscale" $v'$ with the Fr-Sponge. let v_chal = fr_sponge.challenge(); //~ 1. Derive $v$ from $v'$ using the endomorphism (TODO: specify). let v = v_chal.to_field(endo_r); - //~ 1. Sample $u'$ with the Fr-Sponge. + //~ 1. Sample the "evalscale" $u'$ with the Fr-Sponge. let u_chal = fr_sponge.challenge(); //~ 1. Derive $u$ from $u'$ using the endomorphism (TODO: specify). @@ -436,14 +445,19 @@ where ft_eval0 += numerator * denominator; let constants = Constants { - alpha, - beta, - gamma, - joint_combiner: joint_combiner.as_ref().map(|j| j.1), endo_coefficient: index.endo, mds: &G::sponge_params().mds, zk_rows, }; + let challenges = BerkeleyChallenges { + alpha, + beta, + gamma, + joint_combiner: joint_combiner + .as_ref() + .map(|j| j.1) + .unwrap_or(G::ScalarField::zero()), + }; ft_eval0 -= PolishToken::evaluate( &index.linearization.constant_term, @@ -451,6 +465,7 @@ where zeta, &evals, &constants, + &challenges, ) .unwrap(); @@ -798,7 +813,7 @@ where } let lgr_comm = verifier_index .srs() - .get_lagrange_basis(verifier_index.domain.size()); + .get_lagrange_basis(verifier_index.domain); let com: Vec<_> = lgr_comm.iter().take(verifier_index.public).collect(); if public_input.is_empty() { PolyComm::new(vec![verifier_index.srs().blinding_commitment(); chunk_size]) @@ -868,16 +883,22 @@ where // other gates are implemented using the expression framework { - // TODO: Reuse constants from oracles function + // TODO: Reuse constants and challenges from oracles function let constants = Constants { - alpha: oracles.alpha, - beta: oracles.beta, - gamma: oracles.gamma, - joint_combiner: oracles.joint_combiner.as_ref().map(|j| j.1), endo_coefficient: verifier_index.endo, mds: &G::sponge_params().mds, zk_rows, }; + let challenges = BerkeleyChallenges { + alpha: oracles.alpha, + beta: oracles.beta, + gamma: oracles.gamma, + joint_combiner: oracles + .joint_combiner + .as_ref() + .map(|j| j.1) + .unwrap_or(G::ScalarField::zero()), + }; for (col, tokens) in &verifier_index.linearization.index_terms { let scalar = PolishToken::evaluate( @@ -886,6 +907,7 @@ where oracles.zeta, &evals, &constants, + &challenges, ) .expect("should evaluate"); diff --git a/kimchi/src/verifier_index.rs b/kimchi/src/verifier_index.rs index 18fc60b515..0409f56c23 100644 --- a/kimchi/src/verifier_index.rs +++ b/kimchi/src/verifier_index.rs @@ -4,6 +4,7 @@ use crate::{ alphas::Alphas, circuits::{ + berkeley_columns::{BerkeleyChallengeTerm, Column}, expr::{Linearization, PolishToken}, lookup::{index::LookupSelectors, lookups::LookupInfo}, polynomials::permutation::{vanishes_on_last_n_rows, zk_w}, @@ -17,7 +18,7 @@ use ark_poly::{univariate::DensePolynomial, Radix2EvaluationDomain as D}; use mina_poseidon::FqSponge; use once_cell::sync::OnceCell; use poly_commitment::{ - commitment::{CommitmentCurve, PolyComm}, + commitment::{absorb_commitment, CommitmentCurve, PolyComm}, OpenProof, SRS as _, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -144,7 +145,8 @@ pub struct VerifierIndex> { pub lookup_index: Option>, #[serde(skip)] - pub linearization: Linearization>>, + pub linearization: + Linearization>, Column>, /// The mapping between powers of alpha and constraints #[serde(skip)] pub powers_of_alpha: Alphas, @@ -439,42 +441,42 @@ impl> 
VerifierIndex // Always present for comm in sigma_comm.iter() { - fq_sponge.absorb_g(&comm.elems); + absorb_commitment(&mut fq_sponge, comm); } for comm in coefficients_comm.iter() { - fq_sponge.absorb_g(&comm.elems); + absorb_commitment(&mut fq_sponge, comm); } - fq_sponge.absorb_g(&generic_comm.elems); - fq_sponge.absorb_g(&psm_comm.elems); - fq_sponge.absorb_g(&complete_add_comm.elems); - fq_sponge.absorb_g(&mul_comm.elems); - fq_sponge.absorb_g(&emul_comm.elems); - fq_sponge.absorb_g(&endomul_scalar_comm.elems); + absorb_commitment(&mut fq_sponge, generic_comm); + absorb_commitment(&mut fq_sponge, psm_comm); + absorb_commitment(&mut fq_sponge, complete_add_comm); + absorb_commitment(&mut fq_sponge, mul_comm); + absorb_commitment(&mut fq_sponge, emul_comm); + absorb_commitment(&mut fq_sponge, endomul_scalar_comm); // Optional gates if let Some(range_check0_comm) = range_check0_comm { - fq_sponge.absorb_g(&range_check0_comm.elems); + absorb_commitment(&mut fq_sponge, range_check0_comm); } if let Some(range_check1_comm) = range_check1_comm { - fq_sponge.absorb_g(&range_check1_comm.elems); + absorb_commitment(&mut fq_sponge, range_check1_comm); } if let Some(foreign_field_mul_comm) = foreign_field_mul_comm { - fq_sponge.absorb_g(&foreign_field_mul_comm.elems); + absorb_commitment(&mut fq_sponge, foreign_field_mul_comm); } if let Some(foreign_field_add_comm) = foreign_field_add_comm { - fq_sponge.absorb_g(&foreign_field_add_comm.elems); + absorb_commitment(&mut fq_sponge, foreign_field_add_comm); } if let Some(xor_comm) = xor_comm { - fq_sponge.absorb_g(&xor_comm.elems); + absorb_commitment(&mut fq_sponge, xor_comm); } if let Some(rot_comm) = rot_comm { - fq_sponge.absorb_g(&rot_comm.elems); + absorb_commitment(&mut fq_sponge, rot_comm); } // Lookup index; optional @@ -496,26 +498,26 @@ impl> VerifierIndex }) = lookup_index { for entry in lookup_table { - fq_sponge.absorb_g(&entry.elems); + absorb_commitment(&mut fq_sponge, entry); } if let Some(table_ids) = table_ids { - fq_sponge.absorb_g(&table_ids.elems); + absorb_commitment(&mut fq_sponge, table_ids); } if let Some(runtime_tables_selector) = runtime_tables_selector { - fq_sponge.absorb_g(&runtime_tables_selector.elems); + absorb_commitment(&mut fq_sponge, runtime_tables_selector); } if let Some(xor) = xor { - fq_sponge.absorb_g(&xor.elems); + absorb_commitment(&mut fq_sponge, xor); } if let Some(lookup) = lookup { - fq_sponge.absorb_g(&lookup.elems); + absorb_commitment(&mut fq_sponge, lookup); } if let Some(range_check) = range_check { - fq_sponge.absorb_g(&range_check.elems); + absorb_commitment(&mut fq_sponge, range_check); } if let Some(ffmul) = ffmul { - fq_sponge.absorb_g(&ffmul.elems); + absorb_commitment(&mut fq_sponge, ffmul); } } fq_sponge.digest_fq() diff --git a/kimchi/tests/test_constraints.rs b/kimchi/tests/test_constraints.rs new file mode 100644 index 0000000000..a5880df27d --- /dev/null +++ b/kimchi/tests/test_constraints.rs @@ -0,0 +1,88 @@ +use ark_ff::Zero; +use kimchi::circuits::{ + constraints::ConstraintSystem, + gate::{CircuitGate, GateType}, + lookup::{runtime_tables::RuntimeTableCfg, tables::LookupTable}, + wires::{Wire, PERMUTS}, +}; +use mina_curves::pasta::{Fp, Fq}; +use o1_utils::FieldHelpers; + +#[test] +pub fn test_domains_computation_with_runtime_tables() { + let dummy_gate = CircuitGate { + typ: GateType::Generic, + wires: [Wire::new(0, 0); PERMUTS], + coeffs: vec![Fp::zero()], + }; + // inputs + expected output + let data = [((10, 10), 128), ((0, 0), 8), ((5, 100), 512)]; + for ((number_of_rt_cfgs, 
size), expected_domain_size) in data.into_iter() { + let builder = ConstraintSystem::create(vec![dummy_gate.clone(), dummy_gate.clone()]); + let table_ids: Vec = (0..number_of_rt_cfgs).collect(); + let rt_cfgs: Vec> = table_ids + .into_iter() + .map(|table_id| { + let indexes: Vec = (0..size).collect(); + let first_column: Vec = indexes.into_iter().map(Fp::from).collect(); + RuntimeTableCfg { + id: table_id, + first_column, + } + }) + .collect(); + let res = builder.runtime(Some(rt_cfgs)).build().unwrap(); + assert_eq!(res.domain.d1.size, expected_domain_size) + } +} + +#[test] +fn test_lookup_domain_size_computation() { + let (next_start, range_check_gates_0) = CircuitGate::::create_range_check(0); /* 1 range_check gate */ + let (next_start, range_check_gates_1) = CircuitGate::::create_range_check(next_start); /* 1 range_check gate */ + let (next_start, xor_gates_0) = CircuitGate::::create_xor_gadget(next_start, 3); /* 1 xor gate */ + let (next_start, xor_gates_1) = CircuitGate::::create_xor_gadget(next_start, 3); /* 1 xor gate */ + let (_, ffm_gates) = + CircuitGate::::create_foreign_field_mul(next_start, &Fq::modulus_biguint()); /* 1 foreign field multiplication gate */ + let circuit_gates: Vec> = range_check_gates_0 + .into_iter() + .chain(range_check_gates_1) + .chain(xor_gates_0) + .chain(xor_gates_1) + .chain(ffm_gates) + .collect(); /* 2 range check gates + 2 xor gates + 1 foreign field multiplication */ + + // inputs + expected output + let data = [ + ( + (10, 10), + 8192, /* 8192 > 10 * 10 + 1 * 4096 + 1 * 256 + 1 + zk_row */ + ), + ( + (0, 0), + 8192, /* 8192 > 0 * 0 + 1 * 4096 + 1 * 256 + 1 + zk_row */ + ), + ( + (5, 100), + 8192, /* 8192 > 5 * 100 + 1 * 4096 + 1 * 256 + 1 + zk_row */ + ), + ]; + data.into_iter() + .for_each(|((number_of_table_ids, size), expected_domain_size)| { + let builder = ConstraintSystem::create(circuit_gates.clone()); + let table_ids: Vec = (3..number_of_table_ids + 3).collect(); + let lookup_tables: Vec> = table_ids + .into_iter() + .map(|id| { + let indexes: Vec = (0..size).collect(); + let data: Vec = indexes.into_iter().map(Fp::from).collect(); + LookupTable { + id, + data: vec![data], + } + }) + .collect(); + let res = builder.lookup(lookup_tables).build().unwrap(); + assert_eq!(res.domain.d1.size, expected_domain_size); + }); +} diff --git a/kimchi/tests/test_domain.rs b/kimchi/tests/test_domain.rs new file mode 100644 index 0000000000..6ae00d1c72 --- /dev/null +++ b/kimchi/tests/test_domain.rs @@ -0,0 +1,16 @@ +use ark_ff::Field; +use kimchi::circuits::domains::EvaluationDomains; +use mina_curves::pasta::Fp; + +#[test] +fn test_create_domain() { + if let Ok(d) = EvaluationDomains::::create(usize::MAX) { + assert!(d.d4.group_gen.pow([4]) == d.d1.group_gen); + assert!(d.d8.group_gen.pow([2]) == d.d4.group_gen); + println!("d8 = {:?}", d.d8.group_gen); + println!("d8^2 = {:?}", d.d8.group_gen.pow([2])); + println!("d4 = {:?}", d.d4.group_gen); + println!("d4 = {:?}", d.d4.group_gen.pow([4])); + println!("d1 = {:?}", d.d1.group_gen); + } +} diff --git a/kimchi/tests/test_expr.rs b/kimchi/tests/test_expr.rs new file mode 100644 index 0000000000..72a1b5d080 --- /dev/null +++ b/kimchi/tests/test_expr.rs @@ -0,0 +1,191 @@ +use ark_ff::{Field, One, UniformRand, Zero}; +use ark_poly::{domain::EvaluationDomain, univariate::DensePolynomial}; +use kimchi::{ + circuits::{ + berkeley_columns::{ + index, witness, witness_curr, BerkeleyChallengeTerm, BerkeleyChallenges, Environment, E, + }, + constraints::ConstraintSystem, + domains::EvaluationDomains, + 
expr::{constraints::ExprOps, *}, + gate::{CircuitGate, CurrOrNext, GateType}, + polynomials::generic::GenericGateSpec, + wires::{Wire, COLUMNS}, + }, + curve::KimchiCurve, + prover_index::ProverIndex, +}; +use mina_curves::pasta::{Fp, Pallas, Vesta}; +use poly_commitment::{ + ipa::{endos, OpeningProof, SRS}, + SRS as _, +}; +use rand::{prelude::StdRng, SeedableRng}; +use std::{ + array, + collections::{HashMap, HashSet}, + sync::Arc, +}; + +#[test] +#[should_panic] +fn test_failed_linearize() { + // w0 * w1 + let mut expr: E = E::zero(); + expr += witness_curr(0); + expr *= witness_curr(1); + + // since none of w0 or w1 is evaluated this should panic + let evaluated = HashSet::new(); + expr.linearize(evaluated).unwrap(); +} + +#[test] +#[should_panic] +fn test_degree_tracking() { + // The selector CompleteAdd has degree n-1 (so can be tracked with n evaluations in the domain d1 of size n). + // Raising a polynomial of degree n-1 to the power 8 makes it degree 8*(n-1) (and so it needs `8(n-1) + 1` evaluations). + // Since `d8` is of size `8n`, we are still good with that many evaluations to track the new polynomial. + // Raising it to the power 9 pushes us out of the domain d8, which will panic. + let mut expr: E = E::zero(); + expr += index(GateType::CompleteAdd); + let expr = expr.pow(9); + + // create a dummy env + let one = Fp::from(1u32); + let gates = vec![ + CircuitGate::create_generic_gadget( + Wire::for_row(0), + GenericGateSpec::Const(1u32.into()), + None, + ), + CircuitGate::create_generic_gadget( + Wire::for_row(1), + GenericGateSpec::Const(1u32.into()), + None, + ), + ]; + let index = { + let constraint_system = ConstraintSystem::fp_for_testing(gates); + let srs = SRS::::create(constraint_system.domain.d1.size()); + srs.get_lagrange_basis(constraint_system.domain.d1); + let srs = Arc::new(srs); + + let (endo_q, _endo_r) = endos::(); + ProverIndex::>::create(constraint_system, endo_q, srs) + }; + + let witness_cols: [_; COLUMNS] = array::from_fn(|_| DensePolynomial::zero()); + let permutation = DensePolynomial::zero(); + let domain_evals = index.cs.evaluate(&witness_cols, &permutation); + + let env = Environment { + constants: Constants { + endo_coefficient: one, + mds: &Vesta::sponge_params().mds, + zk_rows: 3, + }, + challenges: BerkeleyChallenges { + alpha: one, + beta: one, + gamma: one, + joint_combiner: one, + }, + witness: &domain_evals.d8.this.w, + coefficient: &index.column_evaluations.coefficients8, + vanishes_on_zero_knowledge_and_previous_rows: &index + .cs + .precomputations() + .vanishes_on_zero_knowledge_and_previous_rows, + z: &domain_evals.d8.this.z, + l0_1: l0_1(index.cs.domain.d1), + domain: index.cs.domain, + index: HashMap::new(), + lookup: None, + }; + + // this should panic as we don't have a domain large enough + expr.evaluations(&env); +} + +#[test] +fn test_unnormalized_lagrange_basis() { + let zk_rows = 3; + let domain = EvaluationDomains::::create(2usize.pow(10) + zk_rows) + .expect("failed to create evaluation domain"); + let rng = &mut StdRng::from_seed([17u8; 32]); + + // Check that both ways of computing lagrange basis give the same result + let d1_size: i32 = domain.d1.size().try_into().expect("domain size too big"); + for i in 1..d1_size { + let pt = Fp::rand(rng); + assert_eq!( + unnormalized_lagrange_basis(&domain.d1, d1_size - i, &pt), + unnormalized_lagrange_basis(&domain.d1, -i, &pt) + ); + } +} + +#[test] +fn test_arithmetic_ops() { + fn test_1>() -> T { + T::zero() + T::one() + } + assert_eq!(test_1::>(), E::zero() + E::one()); + 
assert_eq!(test_1::(), Fp::one()); + + fn test_2>() -> T { + T::one() + T::one() + } + assert_eq!(test_2::>(), E::one() + E::one()); + assert_eq!(test_2::(), Fp::from(2u64)); + + fn test_3>(x: T) -> T { + T::from(2u64) * x + } + assert_eq!( + test_3::>(E::from(3u64)), + E::from(2u64) * E::from(3u64) + ); + assert_eq!(test_3(Fp::from(3u64)), Fp::from(6u64)); + + fn test_4>(x: T) -> T { + x.clone() * (x.square() + T::from(7u64)) + } + assert_eq!( + test_4::>(E::from(5u64)), + E::from(5u64) * (Expr::square(E::from(5u64)) + E::from(7u64)) + ); + assert_eq!(test_4::(Fp::from(5u64)), Fp::from(160u64)); +} + +#[test] +fn test_combining_constraints_does_not_increase_degree() { + // Combining two constraints of degree 2 gives a degree 2 combined + // constraint. + // In other words, using the challenge `alpha` doesn't increase the degree. + // Testing with Berkeley configuration + + let mut expr1: E = E::zero(); + // (X0 + X1) * X2 + expr1 += witness(0, CurrOrNext::Curr); + expr1 += witness(1, CurrOrNext::Curr); + expr1 *= witness(2, CurrOrNext::Curr); + assert_eq!(expr1.degree(1, 0), 2); + + // (X2 + X0) * X1 + let mut expr2: E = E::zero(); + expr2 += witness(2, CurrOrNext::Curr); + expr2 += witness(0, CurrOrNext::Curr); + expr2 *= witness(1, CurrOrNext::Curr); + assert_eq!(expr2.degree(1, 0), 2); + + let combined_expr = Expr::combine_constraints(0..2, vec![expr1.clone(), expr2.clone()]); + assert_eq!(combined_expr.degree(1, 0), 2); + + expr2 *= witness(3, CurrOrNext::Curr); + assert_eq!(expr2.degree(1, 0), 3); + + let combined_expr = Expr::combine_constraints(0..2, vec![expr1.clone(), expr2.clone()]); + assert_eq!(combined_expr.degree(1, 0), 3); +} diff --git a/kimchi/src/tests/turshi.rs b/kimchi/tests/turshi.rs similarity index 99% rename from kimchi/src/tests/turshi.rs rename to kimchi/tests/turshi.rs index 7a6e1f0775..854f9d61a6 100644 --- a/kimchi/src/tests/turshi.rs +++ b/kimchi/tests/turshi.rs @@ -1,4 +1,4 @@ -use crate::circuits::{ +use kimchi::circuits::{ gate::CircuitGate, polynomials::turshi::{testing::*, witness::*}, }; diff --git a/msm/Cargo.toml b/msm/Cargo.toml new file mode 100644 index 0000000000..5a56b733f5 --- /dev/null +++ b/msm/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "kimchi-msm" +version = "0.1.0" +description = "Large MSM circuit" +repository = "https://github.com/o1-labs/proof-systems" +homepage = "https://o1-labs.github.io/proof-systems/" +documentation = "https://o1-labs.github.io/proof-systems/rustdoc/" +readme = "README.md" +edition = "2021" +license = "Apache-2.0" + +[lib] +path = "src/lib.rs" + +[dependencies] +ark-std.workspace = true +ark-bn254.workspace = true +ark-serialize.workspace = true +log.workspace = true +o1-utils.workspace = true +itertools.workspace = true +kimchi.workspace = true +folding.workspace = true +poly-commitment.workspace = true +groupmap.workspace = true +mina-curves.workspace = true +mina-poseidon.workspace = true +num-bigint.workspace = true +num-integer.workspace = true +num-traits.workspace = true +rmp-serde.workspace = true +serde_json.workspace = true +serde.workspace = true +serde_with.workspace = true +strum.workspace = true +strum_macros.workspace = true +ark-poly.workspace = true +ark-ff.workspace = true +ark-ec.workspace = true +rand.workspace = true +rayon.workspace = true +thiserror.workspace = true \ No newline at end of file diff --git a/msm/README.md b/msm/README.md new file mode 100644 index 0000000000..1e34f08d9e --- /dev/null +++ b/msm/README.md @@ -0,0 +1,51 @@ +## MSM circuit + +This is a specialised circuit to 
compute a large MSM where the elements are +defined over a foreign field, like the one used by the Pickles verifier. +For simplicity, we suppose the foreign field is at most 256 bits. + +The witness is split into two parts: +- Columns: these are the columns of the circuit. +- Lookups: this contains a list of lookups per row. + [MVLookup](https://eprint.iacr.org/2022/1530.pdf) is used. It contains + the individual polynomials $f_i(X)$ containing the looked-up values, the + table polynomial $t(X)$ that the polynomials $f_i(X)$ look into, and the + multiplicity polynomial $m(X)$, as defined in the paper. + +The individual columns of the circuit are defined in `src/column.rs`. +The polynomials can be expressed using the expression framework, instantiated in `src/column.rs`, with +```rust +type E = Expr<ConstantExpr<Fp>, MSMColumn> +``` + +The proof structure is defined in `src/proof.rs`. A proof consists of +commitments to the circuit columns and the lookup polynomials, in addition to +their evaluations, the evaluation points and the opening proofs. +The prover and verifier are in `src/prover.rs` and `src/verifier.rs`, respectively. +Note that the prover/verifier implementation is generic enough to be used to +implement other circuits. + +The protocol used is a Plonk-ish arithmetisation without the permutation +argument. The circuit is wide enough (i.e. has enough +columns) to handle one elliptic curve addition on one row. + +The foreign field elements are defined with 16 limbs of 16 bits. Limb-wise +elementary addition and multiplication are used, together with 16-bit range +checks. +An addition of two field elements $a = (a_{1}, ..., a_{16})$ and $b = (b_{1}, ..., +b_{16})$ is performed on one row, using one power of alpha for each limb (i.e. one +constraint per limb). + +The elliptic curve points are represented in affine coordinates. +As a reminder, the equations for addition are: + +```math +\lambda = \frac{y_2 - y_1}{x_2 - x_1} +``` + +```math +\begin{align} +x_3 & = \lambda^2 - x_{1} - x_{2} \\ +y_3 & = \lambda (x_{1} - x_{3}) - y_{1} +\end{align} +``` diff --git a/msm/range-coeffs.sage b/msm/range-coeffs.sage new file mode 100644 index 0000000000..e558e479b7 --- /dev/null +++ b/msm/range-coeffs.sage @@ -0,0 +1,128 @@ +# Range computations for the EC ADD circuit. + +#var('b') # Symbolic integer division is hard + +# Works with (15,4) too, or with pretty much anything?
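+# Illustrative commentary (added, not in the original script): b is the limb
+# bit size and n the number of limbs, so (b, n) = (75, 4) models 4 limbs of
+# 75 bits. Each compute_c_eqK(i) recursively bounds the carry that limb
+# column i of equation K can produce (a limb is at most 2^b - 1, so a product
+# of two limbs is at most (2^b - 1)^2). Each compute_c_eqK_direct is a
+# closed-form candidate for the same bound; the prints below check that the
+# two definitions agree for every i in [0..2*n-2].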
+b = 75 +n = 4 + +def compute_c_eq1(i): + if i == -1 or i > 2*n-2: + return 0 + elif i < n: + res = 2*(i+1)*(2^b - 1)^2 + (2^b - 1) + compute_c_eq1(i-1) + ci = res // 2^b + return ci + else: #if n <= i <= 2*n-2 + res = 2*(2*n-i-1)*(2^b - 1)^2 + compute_c_eq1(i-1) + ci = res // 2^b + return ci + +def compute_c_eq1_direct(i): + if i == -1 or i > 2*n-2: + return 0 + elif i < n: + return (i+1)*2^(b+1) - 2*i - 3 + else: # n <= i <= 2*n-2 + return (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 + + + +print([compute_c_eq1(i) for i in [0..2*n-2]]) +print(max([compute_c_eq1(i) for i in [0..2*n-2]])) +print([compute_c_eq1(i) == compute_c_eq1_direct(i) for i in [0..2*n-2]]) + + +def compute_c_eq2(i): + if i == -1 or i > 2*n-2: + return (0,0) + (cprev1, cprev2) = compute_c_eq2(i-1) + if i < n: + res1 = 2*(i+1)*(2^b - 1)^2 + cprev1 + res2 = (i+1)*(2^b - 1)^2 + 3 * (2^b - 1) + cprev2 + c1 = res1 // 2^b + c2 = res2 // 2^b + return (c1,c2) + else: # n <= i <= 2*n-2 + res1 = 2*(2*n-i-1)*(2^b - 1)^2 + cprev1 + res2 = 2*(2*n-i-1)*(2^b - 1)^2 + cprev2 + c1 = res1 // 2^b + c2 = res2 // 2^b + return (c1,c2) + + +def compute_c_eq2_direct(i): + if i == -1 or i > 2*n-2: + return (0,0) + if i < n: + c1 = (i+1)*2^(b+1) - 2*i - 4 + if i == 0: + c2 = 2^b + else: + c2 = i*2^b + 2^b - i + return (c1,c2) + else: + c1 = (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 + if i == n: + c2 = (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 - (n-1) + else: + c2 = (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 + return (c1,c2) + + +print([compute_c_eq2(i) for i in [0..2*n-2]]) +print(max([compute_c_eq2(i)[0] for i in [0..2*n-2]])) +print(max([compute_c_eq2(i)[1] for i in [0..2*n-2]])) +print([compute_c_eq2(i) == compute_c_eq2_direct(i) for i in [0..2*n-2]]) + + +def compute_c_eq3(i): + if i == -1 or i > 2*n-2: + return (0,0) + (cprev1, cprev2) = compute_c_eq3(i-1) + if i < n: + res1 = 2*(i+1)*(2^b - 1)^2 + cprev1 + res2 = (i+1)*(2^b - 1)^2 + 2 * (2^b - 1) + cprev2 + c1 = res1 // 2^b + c2 = res2 // 2^b + return (c1,c2) + else: + res1 = 2*(2*n-i-1)*(2^b - 1)^2 + cprev1 + res2 = 2*(2*n-i-1)*(2^b - 1)^2 + cprev2 + c1 = res1 // 2^b + c2 = res2 // 2^b + return (c1,c2) + + +def compute_c_eq3_direct(i): + if i == -1 or i > 2*n-2: + return (0,0) + if i < n: + c1 = (i+1)*2^(b+1) - 2*i - 4 + c2 = (i + 1) * 2^b - i - 1 + return (c1,c2) + else: + c1 = (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 + if i == n: + c2 = (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 - (n-1) + else: + c2 = (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 + return (c1,c2) + + + +print([compute_c_eq3(i) for i in [0..2*n-2]]) +print(max([compute_c_eq3(i)[0] for i in [0..2*n-2]])) +print(max([compute_c_eq3(i)[1] for i in [0..2*n-2]])) +print([compute_c_eq3(i) == compute_c_eq3_direct(i) for i in [0..2*n-2]]) + +#print([compute_c_eq3(i)[0] // 2^b for i in [0..n-1]]) +#print([compute_c_eq3(i)[0] % 2^b - 2^b for i in [0..n-1]]) +#print([compute_c_eq3(i)[0] // 2^b for i in [n..2*n-2]]) +#print([compute_c_eq3(i)[0] % 2^b - 2^b for i in [n..2*n-2]]) +#print([compute_c_eq3_direct(i)[0] // 2^b for i in [n..2*n-2]]) +#print([compute_c_eq3_direct(i)[0] % 2^b - 2^b for i in [n..2*n-2]]) +#print([compute_c_eq3(i)[1] // 2^b for i in [n..2*n-2]]) +#print([compute_c_eq3(i)[1] % 2^b - 2^b for i in [n..2*n-2]]) +#print([compute_c_eq3_direct(i)[1] // 2^b for i in [n..2*n-2]]) +#print([compute_c_eq3_direct(i)[1] % 2^b - 2^b for i in [n..2*n-2]]) diff --git a/msm/src/circuit_design/capabilities.rs b/msm/src/circuit_design/capabilities.rs new file mode 100644 index 0000000000..5e7780b343 --- /dev/null +++ b/msm/src/circuit_design/capabilities.rs @@ -0,0 +1,172 @@ +/// Provides 
generic environment traits that allow manipulating columns and +/// requesting lookups. +/// +/// Every trait implies two categories of implementations: +/// constraint ones (that operate over expressions, building a +/// circuit), and witness ones (that operate over values, building +/// values for the circuit). +use crate::{columns::ColumnIndexer, logup::LookupTableID}; +use ark_ff::PrimeField; + +/// Environment capability for accessing and reading columns. This is necessary for +/// building constraints. +pub trait ColAccessCap { + // NB: 'static here means that `Variable` does not contain any + // references with a lifetime less than 'static. Which is true in + // our case. Necessary for `set_assert_mapper` + type Variable: Clone + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Neg + + From + + std::fmt::Debug + + 'static; + + /// Asserts that the value is zero. + fn assert_zero(&mut self, cst: Self::Variable); + + /// Sets an assert predicate `f(X)` such that when assert_zero is + /// called on x, it will actually perform `assert_zero(f(x))`. + fn set_assert_mapper(&mut self, mapper: Box Self::Variable>); + + /// Reads value from a column position. + fn read_column(&self, col: CIx) -> Self::Variable; + + /// Turns a constant value into a variable. + fn constant(value: F) -> Self::Variable; +} + +/// Environment capability similar to `ColAccessCap` but for /also +/// writing/ columns. Used on the witness side. +pub trait ColWriteCap +where + Self: ColAccessCap, +{ + fn write_column(&mut self, col: CIx, value: &Self::Variable); +} + +/// Capability for invoking table lookups. +pub trait LookupCap +where + Self: ColAccessCap, +{ + /// Look up (read) value from a lookup table. + fn lookup(&mut self, lookup_id: LT, value: Vec); + + /// Write a value into a runtime table. Panics if called on a fixed table. + fn lookup_runtime_write(&mut self, lookup_id: LT, value: Vec); +} + +/// Capability for reading and moving forward in a multirow fashion. +/// Holds a "current" row that can be moved forward with `next_row`. +/// The `ColWriteCap` and `ColAccessCap` reason in terms of current +/// row. The two other methods can be used to read/write previous. +pub trait MultiRowReadCap +where + Self: ColWriteCap, +{ + /// Read value from a (row,column) position. + fn read_row_column(&mut self, row: usize, col: CIx) -> Self::Variable; + + /// Progresses to the next row. + fn next_row(&mut self); + + /// Returns the current row. + fn curr_row(&self) -> usize; +} + +// TODO this trait is very powerful. It basically abstract +// WitnessBuilderEnv (and other, similar environments). Nothing +// similar can be implemented for constraint building envs. +// +// Where possible, do your computation over Variable or directly via +// F-typed inputs to a function. +/// A direct field access capability modelling an abstract witness +/// builder. Not for constraint building. +pub trait DirectWitnessCap +where + Self: MultiRowReadCap, +{ + /// Convert an abstract variable to a field element! Inverse of Env::constant(). + fn variable_to_field(value: Self::Variable) -> F; +} + +//////////////////////////////////////////////////////////////////////////// +// Hybrid capabilities +//////////////////////////////////////////////////////////////////////////// + +/// Capability for computing arithmetic functions and enforcing +/// constraints simultaneously. +/// +/// The "hybrid" in the name of the trait (and other traits here) +/// means "maybe". 
+///
+/// That is, it allows computations which /might be/ no-ops (even
+/// partially) in the constraint builder case. For example, "hcopy",
+/// despite its name, does not do any "write", so hcopy !=>
+/// write_column.
+pub trait HybridCopyCap
+where
+    Self: ColAccessCap,
+{
+    /// Given variable `x` and position `ix`, it (hybrid) writes `x`
+    /// into `ix`, and returns the value.
+    fn hcopy(&mut self, x: &Self::Variable, ix: CIx) -> Self::Variable;
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Helpers
+////////////////////////////////////////////////////////////////////////////
+
+/// Read an array of values simultaneously.
+pub fn read_column_array(
+    env: &mut Env,
+    column_map: ColMap,
+) -> [Env::Variable; ARR_N]
+where
+    F: PrimeField,
+    Env: ColAccessCap,
+    ColMap: Fn(usize) -> CIx,
+{
+    core::array::from_fn(|i| env.read_column(column_map(i)))
+}
+
+/// Write a field element directly as a constant.
+pub fn write_column_const(env: &mut Env, col: CIx, var: &F)
+where
+    F: PrimeField,
+    Env: ColWriteCap,
+{
+    env.write_column(col, &Env::constant(*var));
+}
+
+/// Write an array of values simultaneously.
+pub fn write_column_array(
+    env: &mut Env,
+    input: [Env::Variable; ARR_N],
+    column_map: ColMap,
+) where
+    F: PrimeField,
+    Env: ColWriteCap,
+    ColMap: Fn(usize) -> CIx,
+{
+    input.iter().enumerate().for_each(|(i, var)| {
+        env.write_column(column_map(i), var);
+    })
+}
+
+/// Write an array of /field/ values simultaneously.
+pub fn write_column_array_const(
+    env: &mut Env,
+    input: &[F; ARR_N],
+    column_map: ColMap,
+) where
+    F: PrimeField,
+    Env: ColWriteCap,
+    ColMap: Fn(usize) -> CIx,
+{
+    input.iter().enumerate().for_each(|(i, var)| {
+        env.write_column(column_map(i), &Env::constant(*var));
+    })
+}
diff --git a/msm/src/circuit_design/composition.rs b/msm/src/circuit_design/composition.rs
new file mode 100644
index 0000000000..15028c4f44
--- /dev/null
+++ b/msm/src/circuit_design/composition.rs
@@ -0,0 +1,419 @@
+/// Tools to /compose/ different circuit designers.
+///
+/// Assume we have several sets of columns:
+/// Col0 ⊂ Col1 ⊂ Col2, and
+/// Col1' ⊂ Col2
+///
+///
+/// For example they are laid out like this:
+///
+/// |-------------------------------------------| Col2
+/// |-| |----------------------| Col1
+/// |-----| |--| |---| Col0
+/// |---| |-----| |---------------| Col1'
+///
+/// Some columns might even be shared (e.g. Col1 and Col1' share column#0).
+///
+/// Using capabilities one can define functions that operate over different sets of columns,
+/// and do not "know" in which bigger context they operate.
+/// - function0>(env: Env, ...)
+/// - function1>(env: Env, ...)
+/// - function1'>(env: Env, ...)
+/// - function2>(env: Env, ...)
+///
+/// This is similar to memory separation: a program function0 might
+/// need just three columns for an A * B - C constraint, and if it runs
+/// in a 1000-column environment it needs to be told /which three
+/// exactly/ it will see.
+///
+/// One only needs a single concrete Env (e.g. WitnessBuilder or
+/// Constraint Builder) over the "top level" Col2, and then all these
+/// functions can be called using lenses. Each lens describes how the
+/// layouts will be mapped.
+///
+/// |-------------------------------------------| Col2
+/// | | |
+/// | Col2To1Lens | |
+/// V | |
+/// |-| |----------------------| | (compose(Col2To1Lens . 
Col1To0Lens) Col1 +/// | | | +/// | Col1To0Lens / | +/// | / | +/// V V | Col2To1'Lens +/// |-----| |--| |---| | Col0 +/// V +/// |---| |-----| |---------------| Col1' +/// +/// +/// Similar "mapping" intuition applies to lookup tables. +use crate::{ + circuit_design::capabilities::{ + ColAccessCap, ColWriteCap, DirectWitnessCap, HybridCopyCap, LookupCap, MultiRowReadCap, + }, + columns::ColumnIndexer, + logup::LookupTableID, +}; +use ark_ff::PrimeField; +use std::marker::PhantomData; + +/// `MPrism` allows one to Something like a Prism, but for Maybe and not just any Applicative. +/// +/// See +/// - +/// - +pub trait MPrism { + /// The lens source type, i.e., the object containing the field. + type Source; + + /// The lens target type, i.e., the field to be accessed or modified. + type Target; + + fn traverse(&self, source: Self::Source) -> Option; + + fn re_get(&self, target: Self::Target) -> Self::Source; +} + +/// Identity `MPrism` from any type `T` to itself. +/// +/// Can be used in many situations. E.g. when `foo1` and `foo2` both +/// call `bar` that is parameterised by a lens, and `foo1` has +/// identical context to `bar` (so requires the ID lens), but `foo2` +/// needs an actual non-ID lens. +#[derive(Clone, Copy, Debug)] +pub struct IdMPrism(pub PhantomData); + +impl Default for IdMPrism { + fn default() -> Self { + IdMPrism(PhantomData) + } +} + +impl MPrism for IdMPrism { + type Source = T; + type Target = T; + + fn traverse(&self, source: Self::Source) -> Option { + Some(source) + } + + fn re_get(&self, target: Self::Target) -> Self::Source { + target + } +} + +pub struct ComposedMPrism { + /// The left-hand side of the composition. + lhs: LHS, + + /// The right-hand side of the composition. + rhs: RHS, +} + +impl ComposedMPrism +where + LHS: MPrism, + LHS::Target: 'static, + RHS: MPrism, +{ + pub fn compose(lhs: LHS, rhs: RHS) -> ComposedMPrism { + ComposedMPrism { lhs, rhs } + } +} + +impl MPrism for ComposedMPrism +where + LHS: MPrism, + LHS::Target: 'static, + RHS: MPrism, +{ + type Source = LHS::Source; + type Target = RHS::Target; + + fn traverse(&self, source: Self::Source) -> Option { + let r1: Option<_> = self.lhs.traverse(source); + let r2: Option<_> = r1.and_then(|x| self.rhs.traverse(x)); + r2 + } + + fn re_get(&self, target: Self::Target) -> Self::Source { + self.lhs.re_get(self.rhs.re_get(target)) + } +} + +//////////////////////////////////////////////////////////////////////////// +// Interpreter and sub-interpreter +//////////////////////////////////////////////////////////////////////////// + +/// Generic sub-environment struct: don't use directly. It's an +/// internal object to avoid copy-paste. +/// +/// We can't use `SubEnv` directly because rust is not idris: it is +/// impossible to instantiate `SubEnv` with two /completely/ different +/// lenses and then write proper trait implementations. Rust complains +/// about conflicting trait implementations. +struct SubEnv<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: ColAccessCap, L> { + env: &'a mut Env1, + lens: L, + phantom: PhantomData<(F, CIx1)>, +} + +/// Sub environment with a lens that is mapping columns. +pub struct SubEnvColumn<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: ColAccessCap, L>( + SubEnv<'a, F, CIx1, Env1, L>, +); + +/// Sub environment with a lens that is mapping lookup tables. 
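+/// This lets code written against a smaller set of lookup tables run
+/// unchanged inside an environment hosting a larger set: lookup calls
+/// are translated through the lens (see the `LookupCap` impl below).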
+pub struct SubEnvLookup<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: ColAccessCap, L>( + SubEnv<'a, F, CIx1, Env1, L>, +); + +//////////////////////////////////////////////////////////////////////////// +// Trait implementations +//////////////////////////////////////////////////////////////////////////// + +impl<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: ColAccessCap, L> + SubEnv<'a, F, CIx1, Env1, L> +{ + pub fn new(env: &'a mut Env1, lens: L) -> Self { + SubEnv { + env, + lens, + phantom: Default::default(), + } + } +} + +impl<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: ColAccessCap, L> + SubEnvColumn<'a, F, CIx1, Env1, L> +{ + pub fn new(env: &'a mut Env1, lens: L) -> Self { + SubEnvColumn(SubEnv::new(env, lens)) + } +} + +impl<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: ColAccessCap, L> + SubEnvLookup<'a, F, CIx1, Env1, L> +{ + pub fn new(env: &'a mut Env1, lens: L) -> Self { + SubEnvLookup(SubEnv::new(env, lens)) + } +} + +impl< + 'a, + F: PrimeField, + CIx1: ColumnIndexer, + CIx2: ColumnIndexer, + Env1: ColAccessCap, + L: MPrism, + > ColAccessCap for SubEnv<'a, F, CIx1, Env1, L> +{ + type Variable = Env1::Variable; + + fn assert_zero(&mut self, cst: Self::Variable) { + self.env.assert_zero(cst); + } + + fn set_assert_mapper(&mut self, mapper: Box Self::Variable>) { + self.env.set_assert_mapper(mapper); + } + + fn constant(value: F) -> Self::Variable { + Env1::constant(value) + } + + fn read_column(&self, ix: CIx2) -> Self::Variable { + self.env.read_column(self.lens.re_get(ix)) + } +} + +impl< + 'a, + F: PrimeField, + CIx1: ColumnIndexer, + CIx2: ColumnIndexer, + Env1: ColWriteCap, + L: MPrism, + > ColWriteCap for SubEnv<'a, F, CIx1, Env1, L> +{ + fn write_column(&mut self, ix: CIx2, value: &Self::Variable) { + self.env.write_column(self.lens.re_get(ix), value) + } +} + +impl< + 'a, + F: PrimeField, + CIx1: ColumnIndexer, + CIx2: ColumnIndexer, + Env1: HybridCopyCap, + L: MPrism, + > HybridCopyCap for SubEnv<'a, F, CIx1, Env1, L> +{ + fn hcopy(&mut self, x: &Self::Variable, ix: CIx2) -> Self::Variable { + self.env.hcopy(x, self.lens.re_get(ix)) + } +} + +impl< + 'a, + F: PrimeField, + CIx1: ColumnIndexer, + CIx2: ColumnIndexer, + Env1: ColAccessCap, + L: MPrism, + > ColAccessCap for SubEnvColumn<'a, F, CIx1, Env1, L> +{ + type Variable = Env1::Variable; + + fn assert_zero(&mut self, cst: Self::Variable) { + self.0.assert_zero(cst); + } + + fn set_assert_mapper(&mut self, mapper: Box Self::Variable>) { + self.0.set_assert_mapper(mapper); + } + + fn constant(value: F) -> Self::Variable { + Env1::constant(value) + } + + fn read_column(&self, ix: CIx2) -> Self::Variable { + self.0.read_column(ix) + } +} + +impl< + 'a, + F: PrimeField, + CIx1: ColumnIndexer, + CIx2: ColumnIndexer, + Env1: ColWriteCap, + L: MPrism, + > ColWriteCap for SubEnvColumn<'a, F, CIx1, Env1, L> +{ + fn write_column(&mut self, ix: CIx2, value: &Self::Variable) { + self.0.write_column(ix, value); + } +} + +impl< + 'a, + F: PrimeField, + CIx1: ColumnIndexer, + CIx2: ColumnIndexer, + Env1: HybridCopyCap, + L: MPrism, + > HybridCopyCap for SubEnvColumn<'a, F, CIx1, Env1, L> +{ + fn hcopy(&mut self, x: &Self::Variable, ix: CIx2) -> Self::Variable { + self.0.hcopy(x, ix) + } +} + +impl<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: ColAccessCap, L> ColAccessCap + for SubEnvLookup<'a, F, CIx1, Env1, L> +{ + type Variable = Env1::Variable; + + fn assert_zero(&mut self, cst: Self::Variable) { + self.0.env.assert_zero(cst); + } + + fn set_assert_mapper(&mut self, mapper: Box Self::Variable>) { + 
self.0.env.set_assert_mapper(mapper); + } + + fn constant(value: F) -> Self::Variable { + Env1::constant(value) + } + + fn read_column(&self, ix: CIx1) -> Self::Variable { + self.0.env.read_column(ix) + } +} + +impl<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: ColWriteCap, L> ColWriteCap + for SubEnvLookup<'a, F, CIx1, Env1, L> +{ + fn write_column(&mut self, ix: CIx1, value: &Self::Variable) { + self.0.env.write_column(ix, value); + } +} + +impl<'a, F: PrimeField, CIx1: ColumnIndexer, Env1: HybridCopyCap, L> HybridCopyCap + for SubEnvLookup<'a, F, CIx1, Env1, L> +{ + fn hcopy(&mut self, x: &Self::Variable, ix: CIx1) -> Self::Variable { + self.0.env.hcopy(x, ix) + } +} + +impl< + 'a, + F: PrimeField, + CIx: ColumnIndexer, + LT1: LookupTableID, + LT2: LookupTableID, + Env1: LookupCap, + L: MPrism, + > LookupCap for SubEnvLookup<'a, F, CIx, Env1, L> +{ + fn lookup(&mut self, lookup_id: LT2, value: Vec) { + self.0.env.lookup(self.0.lens.re_get(lookup_id), value) + } + + fn lookup_runtime_write(&mut self, lookup_id: LT2, value: Vec) { + self.0 + .env + .lookup_runtime_write(self.0.lens.re_get(lookup_id), value) + } +} + +impl< + 'a, + F: PrimeField, + CIx1: ColumnIndexer, + CIx2: ColumnIndexer, + LT: LookupTableID, + Env1: LookupCap, + L: MPrism, + > LookupCap for SubEnvColumn<'a, F, CIx1, Env1, L> +{ + fn lookup(&mut self, lookup_id: LT, value: Vec) { + self.0.env.lookup(lookup_id, value) + } + + fn lookup_runtime_write(&mut self, lookup_id: LT, value: Vec) { + self.0.env.lookup_runtime_write(lookup_id, value) + } +} + +impl<'a, F: PrimeField, CIx: ColumnIndexer, Env1: MultiRowReadCap, L> + MultiRowReadCap for SubEnvLookup<'a, F, CIx, Env1, L> +{ + /// Read value from a (row,column) position. + fn read_row_column(&mut self, row: usize, col: CIx) -> Self::Variable { + self.0.env.read_row_column(row, col) + } + + /// Progresses to the next row. + fn next_row(&mut self) { + self.0.env.next_row(); + } + + /// Returns the current row. + fn curr_row(&self) -> usize { + self.0.env.curr_row() + } +} + +impl<'a, F: PrimeField, CIx: ColumnIndexer, Env1: DirectWitnessCap, L> + DirectWitnessCap for SubEnvLookup<'a, F, CIx, Env1, L> +{ + fn variable_to_field(value: Self::Variable) -> F { + Env1::variable_to_field(value) + } +} + +// TODO add traits for SubEnvColumn diff --git a/msm/src/circuit_design/constraints.rs b/msm/src/circuit_design/constraints.rs new file mode 100644 index 0000000000..d9bd7396af --- /dev/null +++ b/msm/src/circuit_design/constraints.rs @@ -0,0 +1,122 @@ +use ark_ff::PrimeField; +use kimchi::circuits::{ + berkeley_columns::BerkeleyChallengeTerm, + expr::{ConstantExpr, ConstantTerm, Expr, ExprInner, Variable}, + gate::CurrOrNext, +}; +use std::collections::BTreeMap; + +use crate::{ + circuit_design::capabilities::{ColAccessCap, HybridCopyCap, LookupCap}, + columns::{Column, ColumnIndexer}, + expr::E, + logup::{constraint_lookups, LookupTableID}, +}; + +pub struct ConstraintBuilderEnv { + /// An indexed set of constraints. + pub constraints: Vec, Column>>, + /// Aggregated lookups or "reads". + pub lookup_reads: BTreeMap>>>, + /// Aggregated "write" lookups, for runtime tables. + pub lookup_writes: BTreeMap>>>, + /// The function that maps the argument of `assert_zero`. 
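+    /// It is applied to every constraint passed to `assert_zero` before
+    /// the constraint is recorded; by default (see `create`) it is the
+    /// identity function.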
+ pub assert_mapper: Box) -> E>, +} + +impl ConstraintBuilderEnv { + pub fn create() -> Self { + Self { + constraints: vec![], + lookup_reads: BTreeMap::new(), + lookup_writes: BTreeMap::new(), + assert_mapper: Box::new(|x| x), + } + } +} + +impl ColAccessCap + for ConstraintBuilderEnv +{ + type Variable = E; + + fn assert_zero(&mut self, cst: Self::Variable) { + self.constraints.push((self.assert_mapper)(cst)); + } + + fn set_assert_mapper(&mut self, mapper: Box Self::Variable>) { + self.assert_mapper = mapper; + } + + fn read_column(&self, position: CIx) -> Self::Variable { + Expr::Atom(ExprInner::Cell(Variable { + col: position.to_column(), + row: CurrOrNext::Curr, + })) + } + + fn constant(value: F) -> Self::Variable { + let cst_expr_inner = ConstantExpr::from(ConstantTerm::Literal(value)); + Expr::Atom(ExprInner::Constant(cst_expr_inner)) + } +} + +impl HybridCopyCap + for ConstraintBuilderEnv +{ + fn hcopy(&mut self, x: &Self::Variable, position: CIx) -> Self::Variable { + let y = Expr::Atom(ExprInner::Cell(Variable { + col: position.to_column(), + row: CurrOrNext::Curr, + })); + as ColAccessCap>::assert_zero( + self, + y.clone() - x.clone(), + ); + y + } +} + +impl LookupCap + for ConstraintBuilderEnv +{ + fn lookup(&mut self, table_id: LT, value: Vec<>::Variable>) { + self.lookup_reads + .entry(table_id) + .or_default() + .push(value.clone()); + } + + fn lookup_runtime_write(&mut self, table_id: LT, value: Vec) { + assert!( + !table_id.is_fixed(), + "lookup_runtime_write must be called on non-fixed tables only" + ); + self.lookup_writes + .entry(table_id) + .or_default() + .push(value.clone()); + } +} + +impl ConstraintBuilderEnv { + /// Get constraints related to the application logic itself. + pub fn get_relation_constraints(&self) -> Vec> { + self.constraints.clone() + } + + /// Get constraints related to the lookup argument. + pub fn get_lookup_constraints(&self) -> Vec> { + constraint_lookups(&self.lookup_reads, &self.lookup_writes) + } + + /// Get all relevant constraints generated by the constraint builder. + pub fn get_constraints(&self) -> Vec> { + let mut constraints: Vec> = vec![]; + constraints.extend(self.get_relation_constraints()); + if !self.lookup_reads.is_empty() { + constraints.extend(self.get_lookup_constraints()); + } + constraints + } +} diff --git a/msm/src/circuit_design/helpers.rs b/msm/src/circuit_design/helpers.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/msm/src/circuit_design/mod.rs b/msm/src/circuit_design/mod.rs new file mode 100644 index 0000000000..2c521e61cb --- /dev/null +++ b/msm/src/circuit_design/mod.rs @@ -0,0 +1,12 @@ +pub mod capabilities; +pub mod composition; +pub mod constraints; +pub mod witness; + +// Reexport main types +pub use crate::circuit_design::{ + capabilities::*, + composition::{MPrism, SubEnvColumn, SubEnvLookup}, + constraints::ConstraintBuilderEnv, + witness::WitnessBuilderEnv, +}; diff --git a/msm/src/circuit_design/witness.rs b/msm/src/circuit_design/witness.rs new file mode 100644 index 0000000000..7266594f6b --- /dev/null +++ b/msm/src/circuit_design/witness.rs @@ -0,0 +1,677 @@ +use crate::{ + circuit_design::capabilities::{ + ColAccessCap, ColWriteCap, DirectWitnessCap, HybridCopyCap, LookupCap, MultiRowReadCap, + }, + columns::{Column, ColumnIndexer}, + logup::{Logup, LogupWitness, LookupTableID}, + proof::ProofInputs, + witness::Witness, +}; +use ark_ff::PrimeField; +use log::debug; +use std::{collections::BTreeMap, iter, marker::PhantomData}; + +/// Witness builder environment. 
Operates on multiple rows at the same +/// time. `CIx::N_COL` must be equal to `N_WIT + N_FSEL`; passing these two +/// separately is due to a rust limitation. +pub struct WitnessBuilderEnv< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, +> { + /// The witness columns that the environment is working with. + /// Every element of the vector is a row, and the builder is + /// always processing the last row. + pub witness: Vec>, + + /// Lookup multiplicities, a vector of values `m_i` per lookup + /// table, where `m_i` is how many times the lookup value number + /// `i` was looked up. + pub lookup_multiplicities: BTreeMap>, + + /// Lookup "read" requests per table. Each element of the map is a + /// vector of <#number_of_reads_per_row> columns. Each column is a + /// vector of elements. Each element is vector of field elements. + /// + /// - `lookup_reads[table_id][read_i]` is a column corresponding to a read #`read_i` per row. + /// - `lookup_reads[table_id][read_i][row_i]` is a value-vector that's looked up at `row_i` + pub lookup_reads: BTreeMap>>>, + + /// Values for runtime tables. Each element (value) in the map is + /// a set of on-the-fly built columns, one column per write. + /// + /// Format is the same as `lookup_reads`. + /// + /// - `runtime_tables[table_id][write_i]` is a column corresponding to a write #`write_i` per row. + /// - `runtime_tables[table_id][write_i][row_i]` is a value-vector that's looked up at `row_i` + pub runtime_lookup_writes: BTreeMap>>>, + + /// Fixed values for selector columns. `fixed_selectors[i][j]` is the + /// value for row #j of the selector #i. + pub fixed_selectors: Vec>, + + /// Function used to map assertions. + pub assert_mapper: Box F>, + + // A Phantom Data for CIx -- right now WitnessBUilderEnv does not + // depend on CIx, but in the future (with associated generics + // enabled?) it might be convenient to put all the `NT_COL` (and + // other) constants into `CIx`. Logically, all these constants + // "belong" to CIx, so there's an extra type parameter, and a + // phantom data to support it. + pub phantom_cix: PhantomData, +} + +impl< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + > ColAccessCap for WitnessBuilderEnv +{ + // Requiring an F element as we would need to compute values up to 180 bits + // in the 15 bits decomposition. 
+ type Variable = F; + + fn assert_zero(&mut self, cst: Self::Variable) { + assert_eq!((self.assert_mapper)(cst), F::zero()); + } + + fn set_assert_mapper(&mut self, mapper: Box Self::Variable>) { + self.assert_mapper = mapper; + } + + fn constant(value: F) -> Self::Variable { + value + } + + fn read_column(&self, ix: CIx) -> Self::Variable { + match ix.to_column() { + Column::Relation(i) => self.witness.last().unwrap().cols[i], + Column::FixedSelector(i) => self.fixed_selectors[i][self.witness.len() - 1], + other => panic!("WitnessBuilderEnv::read_column does not support {other:?}"), + } + } +} + +impl< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + > ColWriteCap for WitnessBuilderEnv +{ + fn write_column(&mut self, ix: CIx, value: &Self::Variable) { + self.write_column_raw(ix.to_column(), *value); + } +} + +/// If `Env` implements real write ("for sure" writes), you can implement +/// hybrid copy (that is only required to "maybe" copy). The other way +/// around violates the semantics. +/// +/// Sadly, rust does not allow "cover" instances to define this impl +/// for every `T: ColWriteCap`. +impl< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + > HybridCopyCap for WitnessBuilderEnv +{ + fn hcopy(&mut self, value: &Self::Variable, ix: CIx) -> Self::Variable { + as ColWriteCap>::write_column( + self, ix, value, + ); + *value + } +} + +impl< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + > MultiRowReadCap for WitnessBuilderEnv +{ + /// Read value from a (row,column) position. + fn read_row_column(&mut self, row: usize, col: CIx) -> Self::Variable { + let Column::Relation(i) = col.to_column() else { + todo!() + }; + self.witness[row].cols[i] + } + + /// Progresses to the next row. + fn next_row(&mut self) { + self.next_row(); + } + + /// Returns the current row. + fn curr_row(&self) -> usize { + self.witness.len() - 1 + } +} + +impl< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + > DirectWitnessCap for WitnessBuilderEnv +{ + /// Convert an abstract variable to a field element! Inverse of Env::constant(). + fn variable_to_field(value: Self::Variable) -> F { + value + } +} + +impl< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + > LookupCap for WitnessBuilderEnv +{ + fn lookup(&mut self, table_id: LT, value: Vec<>::Variable>) { + // Recording the lookup read into the corresponding slot in `lookup_reads`. + { + let curr_row = self.curr_row(); + + let lookup_read_table = self.lookup_reads.get_mut(&table_id).unwrap(); + + let curr_write_number = + (0..lookup_read_table.len()).find(|i| lookup_read_table[*i].len() <= curr_row); + + // If we're at row 0, we can declare as many reads as we want. + // If we're at non-zero row, we cannot declare more reads than before. + let curr_write_number = if let Some(v) = curr_write_number { + v + } else { + // TODO: This must be a panic; however, we don't yet have support + // different number of lookups on different rows. 
+ // + // See https://github.com/o1-labs/proof-systems/issues/2440 + if curr_row != 0 { + eprintln!( + "ERROR: Number of writes in row {curr_row:?} is different from row 0", + ); + } + lookup_read_table.push(vec![]); + lookup_read_table.len() - 1 + }; + + lookup_read_table[curr_write_number].push(value.clone()); + } + + // If the table is fixed we also compute multiplicities on the fly. + if table_id.is_fixed() { + let value_ix = table_id + .ix_by_value(&value) + .expect("Could not resolve lookup for a fixed table"); + + let multiplicities = self.lookup_multiplicities.get_mut(&table_id).unwrap(); + // Since we allow multiple lookups per row, runtime tables + // can in theory grow bigger than the domain size. We + // still collect multiplicities as if runtime table vector + // is not height-bounded, but we will split it into chunks + // later. + if !table_id.is_fixed() && value_ix > multiplicities.len() { + multiplicities.resize(value_ix, 0u64); + } + multiplicities[value_ix] += 1; + } + } + + fn lookup_runtime_write(&mut self, table_id: LT, value: Vec) { + assert!( + !table_id.is_fixed() && !table_id.runtime_create_column(), + "lookup_runtime_write must be called on non-fixed tables that work with dynamic writes only" + ); + + // We insert value into runtime table in any case, for each row. + let curr_row = self.witness.len() - 1; + + let runtime_table = self.runtime_lookup_writes.get_mut(&table_id).unwrap(); + + let curr_write_number = + (0..runtime_table.len()).find(|i| runtime_table[*i].len() <= curr_row); + + let curr_write_number = if let Some(v) = curr_write_number { + v + } else { + assert!( + curr_row == 0, + "Number of writes in row {curr_row:?} is different from row 0" + ); + runtime_table.push(vec![]); + runtime_table.len() - 1 + }; + + runtime_table[curr_write_number].push(value); + } +} + +impl< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + > WitnessBuilderEnv +{ + pub fn write_column_raw(&mut self, position: Column, value: F) { + match position { + Column::Relation(i) => self.witness.last_mut().unwrap().cols[i] = value, + Column::FixedSelector(_) => { + panic!("Witness environment can't write into fixed selector columns."); + } + Column::DynamicSelector(_) => { + // TODO: Do we want to allow writing to dynamic selector columns only 1 or 0? + panic!( + "This is a dynamic selector column. The environment is + supposed to write only in witness columns" + ); + } + Column::LookupPartialSum(_) => { + panic!( + "This is a lookup related column. The environment is + supposed to write only in witness columns" + ); + } + Column::LookupMultiplicity(_) => { + panic!( + "This is a lookup related column. The environment is + supposed to write only in witness columns" + ); + } + Column::LookupAggregation => { + panic!( + "This is a lookup related column. The environment is + supposed to write only in witness columns" + ); + } + Column::LookupFixedTable(_) => { + panic!( + "This is a lookup related column. The environment is + supposed to write only in witness columns" + ); + } + } + } + + /// Progress to the computations on the next row. + pub fn next_row(&mut self) { + self.witness.push(Witness { + cols: Box::new([F::zero(); N_WIT]), + }); + } + + /// Getting multiplicities for range check tables less or equal + /// than 15 bits. Return value is a vector of columns, where each + /// column represents a "read". 
Fixed lookup tables always return + /// a single-column vector, while runtime tables may return more. + pub fn get_lookup_multiplicities(&self, domain_size: usize, table_id: LT) -> Vec> { + if table_id.is_fixed() { + let mut m = Vec::with_capacity(domain_size); + m.extend( + self.lookup_multiplicities[&table_id] + .iter() + .map(|x| F::from(*x)), + ); + if table_id.length() < domain_size { + let n_repeated_dummy_value: usize = domain_size - table_id.length() - 1; + let repeated_dummy_value: Vec = iter::repeat(-F::one()) + .take(n_repeated_dummy_value) + .collect(); + m.extend(repeated_dummy_value); + m.push(F::from(n_repeated_dummy_value as u64)); + } + assert_eq!(m.len(), domain_size); + vec![m] + } else { + // For runtime tables, multiplicities are computed post + // factum, since we explicitly want (e.g. for RAM lookups) + // reads and writes to be parallel -- in many cases we + // haven't previously written the value we want to read. + + let runtime_table = self.runtime_lookup_writes.get(&table_id).unwrap(); + let num_writes = if table_id.runtime_create_column() { + assert!(runtime_table.is_empty(), "runtime_table is expected to be unused for runtime tables with on-the-fly table creation"); + 1 + } else { + runtime_table.len() + }; + + // A runtime table resolver; the inverse of `runtime_tables`: + // maps runtime lookup table to `(column, row)` + let mut resolver: BTreeMap, (usize, usize)> = BTreeMap::new(); + + { + // Populate resolver map either from "reads" or from "writes" + if table_id.runtime_create_column() { + let columns = &self.lookup_reads.get(&table_id).unwrap(); + assert!( + columns.len() == 1, + "We only allow 1 read for runtime tables yet" + ); + let column = &columns[0]; + assert!(column.len() <= domain_size,); + for (row_i, value) in column.iter().enumerate() { + if resolver.get_mut(value).is_none() { + resolver.insert(value.clone(), (0, row_i)); + } + } + } else { + for (col_i, col) in runtime_table.iter().take(num_writes).enumerate() { + for (row_i, value) in col.iter().enumerate() { + if resolver.get_mut(value).is_none() { + resolver.insert(value.clone(), (col_i, row_i)); + } + } + } + } + } + + // Resolve reads and build multiplicities vector + + let mut multiplicities = vec![vec![0u64; domain_size]; num_writes]; + + for lookup_read_column in self.lookup_reads.get(&table_id).unwrap().iter() { + for value in lookup_read_column.iter() { + if let Some((col_i, row_i)) = resolver.get_mut(value) { + multiplicities[*col_i][*row_i] += 1; + } else { + panic!("Could not resolve a runtime table read"); + } + } + } + + multiplicities + .into_iter() + .map(|v| v.into_iter().map(|x| F::from(x)).collect()) + .collect() + } + } +} + +impl< + F: PrimeField, + CIx: ColumnIndexer, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + > WitnessBuilderEnv +{ + /// Create a new empty-state witness builder. 
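+    /// The builder starts with a single all-zero row; lookup state is
+    /// initialised per table, and fixed selectors start out empty (set
+    /// them via `set_fixed_selectors` or `set_fixed_selector_cix`).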
+ pub fn create() -> Self { + let mut lookup_reads = BTreeMap::new(); + let mut lookup_multiplicities = BTreeMap::new(); + let mut runtime_lookup_writes = BTreeMap::new(); + let fixed_selectors = vec![vec![]; N_FSEL]; + for table_id in LT::all_variants().into_iter() { + lookup_reads.insert(table_id, vec![]); + if table_id.is_fixed() { + lookup_multiplicities.insert(table_id, vec![0u64; table_id.length()]); + } else { + runtime_lookup_writes.insert(table_id, vec![]); + } + } + + Self { + witness: vec![Witness { + cols: Box::new([F::zero(); N_WIT]), + }], + + lookup_multiplicities, + lookup_reads, + runtime_lookup_writes, + fixed_selectors, + phantom_cix: PhantomData, + assert_mapper: Box::new(|x| x), + } + } + + /// Sets a fixed selector, the vector of length equal to the + /// domain size (circuit height). + pub fn set_fixed_selector_cix(&mut self, sel: CIx, sel_values: Vec) { + if let Column::FixedSelector(i) = sel.to_column() { + self.fixed_selectors[i] = sel_values; + } else { + panic!("Tried to assign values to non-fixed-selector typed column {sel:?}"); + } + } + + /// Sets all fixed selectors directly. Each item in `selectors` is + /// a vector of `domain_size` length. + pub fn set_fixed_selectors(&mut self, selectors: Vec>) { + self.fixed_selectors = selectors + } + + pub fn get_relation_witness(&self, domain_size: usize) -> Witness> { + // Boxing to avoid stack overflow + let mut witness: Box>> = Box::new(Witness { + cols: Box::new(std::array::from_fn(|_| Vec::with_capacity(domain_size))), + }); + + // Filling actually used rows first + for witness_row in self.witness.iter().take(domain_size) { + for j in 0..N_REL { + witness.cols[j].push(witness_row.cols[j]); + } + } + + // Then filling witness rows up with zeroes to the domain size + // FIXME: Maybe this is not always wise, as default instance can be non-zero. + if self.witness.len() < domain_size { + for i in 0..N_REL { + witness.cols[i].extend(vec![F::zero(); domain_size - self.witness.len()]); + } + } + + // Fill out dynamic selectors. + for i in 0..N_DSEL { + // TODO FIXME Fill out dynamic selectors! + witness.cols[N_REL + i] = vec![F::zero(); domain_size]; + } + + for i in 0..(N_REL + N_DSEL) { + assert!( + witness.cols[i].len() == domain_size, + "Witness columns length {:?} for column {:?} does not match domain size {:?}", + witness.cols[i].len(), + i, + domain_size + ); + } + + *witness + } + + /// Return all runtime tables collected so far, padded to the domain size. + pub fn get_runtime_tables(&self, domain_size: usize) -> BTreeMap>>> { + let mut runtime_tables: BTreeMap = BTreeMap::new(); + for table_id in LT::all_variants() + .into_iter() + .filter(|table_id| !table_id.is_fixed()) + { + if table_id.runtime_create_column() { + // For runtime tables with no explicit writes, we + // store only read requests, so we assemble read + // requests into a column. + runtime_tables.insert(table_id, self.lookup_reads.get(&table_id).unwrap().clone()); + } else { + // For runtime tables /with/ explicit writes, these + // writes are stored in self.runtime_tables. + runtime_tables.insert( + table_id, + self.runtime_lookup_writes.get(&table_id).unwrap().clone(), + ); + } + // We pad the runtime table with dummies if it's too small. 
+ for column in runtime_tables.get_mut(&table_id).unwrap() { + if column.len() < domain_size { + let dummy_value = column[0].clone(); // we assume runtime tables are never empty + column.append(&mut vec![dummy_value; domain_size - column.len()]); + } + } + } + runtime_tables + } + + pub fn get_logup_witness( + &self, + domain_size: usize, + lookup_tables_data: BTreeMap>>>, + ) -> BTreeMap> { + // Building lookup values + let mut lookup_tables: BTreeMap>>> = BTreeMap::new(); + if !lookup_tables_data.is_empty() { + for table_id in LT::all_variants().into_iter() { + // Find how many lookups are done per table. + let number_of_lookup_reads = self.lookup_reads.get(&table_id).unwrap().len(); + let number_of_lookup_writes = + if table_id.is_fixed() || table_id.runtime_create_column() { + 1 + } else { + self.runtime_lookup_writes[&table_id].len() + }; + + // +1 for the fixed table + lookup_tables.insert( + table_id, + vec![vec![]; number_of_lookup_reads + number_of_lookup_writes], + ); + } + } else { + debug!("No lookup tables data provided. Skipping lookup tables."); + } + + for (table_id, columns) in self.lookup_reads.iter() { + for (read_i, column) in columns.iter().enumerate() { + lookup_tables.get_mut(table_id).unwrap()[read_i] = column + .iter() + .map(|value| Logup { + table_id: *table_id, + numerator: F::one(), + value: value.clone(), + }) + .collect(); + } + } + + // FIXME add runtime tables, runtime_lookup_reads must be used here + + let mut lookup_multiplicities: BTreeMap>> = BTreeMap::new(); + + // Counting multiplicities & adding fixed column into the last column of every table. + for (table_id, table) in lookup_tables.iter_mut() { + let lookup_m: Vec> = self.get_lookup_multiplicities(domain_size, *table_id); + lookup_multiplicities.insert(*table_id, lookup_m.clone()); + + if table_id.is_fixed() || table_id.runtime_create_column() { + assert!(lookup_m.len() == 1); + assert!( + lookup_tables_data[table_id].len() == 1, + "table {table_id:?} must have exactly one column, got {:?}", + lookup_tables_data[table_id].len() + ); + let lookup_t = lookup_tables_data[table_id][0] + .iter() + .enumerate() + .map(|(i, v)| Logup { + table_id: *table_id, + numerator: -lookup_m[0][i], + value: v.clone(), + }) + .collect(); + *(table.last_mut().unwrap()) = lookup_t; + } else { + // Add multiplicity vectors for runtime tables. + for (col_i, lookup_column) in lookup_tables_data[table_id].iter().enumerate() { + let lookup_t = lookup_column + .iter() + .enumerate() + .map(|(i, v)| Logup { + table_id: *table_id, + numerator: -lookup_m[col_i][i], + value: v.clone(), + }) + .collect(); + + let pos = table.len() - self.runtime_lookup_writes[table_id].len() + col_i; + + (*table)[pos] = lookup_t; + } + } + } + + for (table_id, m) in lookup_multiplicities.iter() { + if !table_id.is_fixed() { + // Temporary assertion; to be removed when we support bigger + // runtime table/RAMlookups functionality. + assert!(m.len() <= domain_size, + "We do not _yet_ support wrapping runtime tables that are bigger than domain size."); + } + } + + lookup_tables + .iter() + .filter_map(|(table_id, table)| { + // Only add a table if it's used. Otherwise lookups fail. + if !table.is_empty() && !table[0].is_empty() { + Some(( + *table_id, + LogupWitness { + f: table.clone(), + m: lookup_multiplicities[table_id].clone(), + }, + )) + } else { + None + } + }) + .collect() + } + + /// Generates proof inputs, repacking/collecting internal witness builder state. 
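+    /// This bundles the outputs of `get_relation_witness` and
+    /// `get_logup_witness` into a single `ProofInputs` value.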
+ pub fn get_proof_inputs( + &self, + domain_size: usize, + lookup_tables_data: BTreeMap>>>, + ) -> ProofInputs { + let evaluations = self.get_relation_witness(domain_size); + let logups = self.get_logup_witness(domain_size, lookup_tables_data); + + ProofInputs { + evaluations, + logups, + } + } +} diff --git a/msm/src/column_env.rs b/msm/src/column_env.rs new file mode 100644 index 0000000000..808ac598ba --- /dev/null +++ b/msm/src/column_env.rs @@ -0,0 +1,177 @@ +use ark_ff::FftField; +use ark_poly::{Evaluations, Radix2EvaluationDomain}; + +use crate::{logup, logup::LookupTableID, witness::Witness}; +use kimchi::circuits::{ + berkeley_columns::{BerkeleyChallengeTerm, BerkeleyChallenges}, + domains::{Domain, EvaluationDomains}, + expr::{ColumnEnvironment as TColumnEnvironment, Constants}, +}; + +/// The collection of polynomials (all in evaluation form) and constants +/// required to evaluate an expression as a polynomial. +/// +/// All are evaluations. +pub struct ColumnEnvironment< + 'a, + const N: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + F: FftField, + ID: LookupTableID, +> { + /// The witness column polynomials. Includes relation columns, + /// fixed selector columns, and dynamic selector columns. + pub witness: &'a Witness>>, + /// Fixed selectors. These are "predefined" with the circuit, and, + /// unlike public input or dynamic selectors, are not part of the + /// witness that users are supposed to change after the circuit is + /// fixed. + pub fixed_selectors: &'a [Evaluations>; N_FSEL], + /// The value `prod_{j != 1} (1 - omega^j)`, used for efficiently + /// computing the evaluations of the unnormalized Lagrange basis polynomials. + pub l0_1: F, + /// Constant values required + pub constants: Constants, + /// Challenges from the IOP. + pub challenges: BerkeleyChallenges, + /// The domains used in the PLONK argument. 
+ pub domain: EvaluationDomains, + + /// Lookup specific polynomials + pub lookup: Option>, +} + +impl< + 'a, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + F: FftField, + ID: LookupTableID, + > TColumnEnvironment<'a, F, BerkeleyChallengeTerm, BerkeleyChallenges> + for ColumnEnvironment<'a, N_WIT, N_REL, N_DSEL, N_FSEL, F, ID> +{ + type Column = crate::columns::Column; + + fn get_column( + &self, + col: &Self::Column, + ) -> Option<&'a Evaluations>> { + // TODO: when non-literal constant generics are available, substitute N with N_REG + N_DSEL + N_FSEL + assert!(N_WIT == N_REL + N_DSEL); + assert!(N_WIT == self.witness.len()); + match *col { + // Handling the "relation columns" at the beginning of the witness columns + Self::Column::Relation(i) => { + // TODO: add a test for this + assert!(i < N_REL,"Requested column with index {:?} but the given witness is meant for {:?} relation columns", i, N_REL); + let res = &self.witness[i]; + Some(res) + } + // Handling the "dynamic selector columns" at the end of the witness columns + Self::Column::DynamicSelector(i) => { + assert!(i < N_DSEL, "Requested dynamic selector with index {:?} but the given witness is meant for {:?} dynamic selector columns", i, N_DSEL); + let res = &self.witness[N_REL + i]; + Some(res) + } + Self::Column::FixedSelector(i) => { + assert!(i < N_FSEL, "Requested fixed selector with index {:?} but the given witness is meant for {:?} fixed selector columns", i, N_FSEL); + let res = &self.fixed_selectors[i]; + Some(res) + } + Self::Column::LookupPartialSum((table_id, i)) => { + if let Some(ref lookup) = self.lookup { + let table_id = ID::from_u32(table_id); + Some(&lookup.lookup_terms_evals_d8[&table_id][i]) + } else { + panic!("No lookup provided") + } + } + Self::Column::LookupAggregation => { + if let Some(ref lookup) = self.lookup { + Some(lookup.lookup_aggregation_evals_d8) + } else { + panic!("No lookup provided") + } + } + Self::Column::LookupMultiplicity((table_id, i)) => { + if let Some(ref lookup) = self.lookup { + Some(&lookup.lookup_counters_evals_d8[&ID::from_u32(table_id)][i]) + } else { + panic!("No lookup provided") + } + } + Self::Column::LookupFixedTable(table_id) => { + if let Some(ref lookup) = self.lookup { + Some(&lookup.fixed_tables_evals_d8[&ID::from_u32(table_id)]) + } else { + panic!("No lookup provided") + } + } + } + } + + fn get_domain(&self, d: Domain) -> Radix2EvaluationDomain { + match d { + Domain::D1 => self.domain.d1, + Domain::D2 => self.domain.d2, + Domain::D4 => self.domain.d4, + Domain::D8 => self.domain.d8, + } + } + + fn column_domain(&self, col: &Self::Column) -> Domain { + match *col { + Self::Column::Relation(_) + | Self::Column::DynamicSelector(_) + | Self::Column::FixedSelector(_) => { + let domain_size = match *col { + Self::Column::Relation(i) => self.witness[i].domain().size, + Self::Column::DynamicSelector(i) => self.witness[N_REL + i].domain().size, + Self::Column::FixedSelector(i) => self.fixed_selectors[i].domain().size, + _ => panic!("Impossible"), + }; + if self.domain.d1.size == domain_size { + Domain::D1 + } else if self.domain.d2.size == domain_size { + Domain::D2 + } else if self.domain.d4.size == domain_size { + Domain::D4 + } else if self.domain.d8.size == domain_size { + Domain::D8 + } else { + panic!("Domain not supported. 
We do support the following multiples of the domain registered in the environment: 1, 2, 4, 8")
+                }
+            }
+            Self::Column::LookupAggregation
+            | Self::Column::LookupFixedTable(_)
+            | Self::Column::LookupMultiplicity(_)
+            | Self::Column::LookupPartialSum(_) => {
+                // When there is a lookup, we assume the domain is always D8
+                // and we have at least 6 lookups per row.
+                Domain::D8
+            }
+        }
+    }
+
+    fn get_constants(&self) -> &Constants {
+        &self.constants
+    }
+
+    fn get_challenges(&self) -> &BerkeleyChallenges {
+        &self.challenges
+    }
+
+    fn vanishes_on_zero_knowledge_and_previous_rows(
+        &self,
+    ) -> &'a Evaluations> {
+        panic!("Not supposed to be used in MSM")
+    }
+
+    fn l0_1(&self) -> F {
+        self.l0_1
+    }
+}
diff --git a/msm/src/columns.rs b/msm/src/columns.rs
new file mode 100644
index 0000000000..eef53a5035
--- /dev/null
+++ b/msm/src/columns.rs
@@ -0,0 +1,101 @@
+use std::collections::HashMap;
+
+use folding::expressions::FoldingColumnTrait;
+use kimchi::circuits::expr::{CacheId, FormattedOutput};
+
+/// Describes a generic indexed variable X_{i}.
+#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash)]
+pub enum Column {
+    /// Columns related to the relation encoded in the circuit
+    Relation(usize),
+    /// Columns related to dynamic selectors to indicate gate type
+    DynamicSelector(usize),
+    /// Constant column that is /always/ fixed for a given circuit.
+    FixedSelector(usize),
+    // Columns related to the lookup protocol
+    /// Partial sums. This corresponds to the `h_i`.
+    /// It is indexed first by the table ID, and then by an internal index.
+    LookupPartialSum((u32, usize)),
+    /// Multiplicities, indexed. This corresponds to the `m_i`. First
+    /// indexed by table ID, then by an internal index.
+    LookupMultiplicity((u32, usize)),
+    /// The lookup aggregation, i.e. `phi`
+    LookupAggregation,
+    /// The fixed tables. The parameter is the ID of the indexed table.
+    LookupFixedTable(u32),
+}
+
+impl Column {
+    /// Adds an offset if the column is `Relation`. Panics otherwise.
+    pub fn add_rel_offset(self, offset: usize) -> Column {
+        let Column::Relation(i) = self else {
+            todo!("add_rel_offset is only implemented for the relation columns")
+        };
+        Column::Relation(offset + i)
+    }
+}
+
+impl FormattedOutput for Column {
+    fn latex(&self, _cache: &mut HashMap) -> String {
+        match self {
+            Column::Relation(i) => format!("x_{{{i}}}"),
+            Column::FixedSelector(i) => format!("fs_{{{i}}}"),
+            Column::DynamicSelector(i) => format!("ds_{{{i}}}"),
+            Column::LookupPartialSum((table_id, i)) => format!("h_{{{table_id}, {i}}}"),
+            Column::LookupMultiplicity((table_id, i)) => format!("m_{{{table_id}, {i}}}"),
+            Column::LookupFixedTable(i) => format!("t_{{{i}}}"),
+            Column::LookupAggregation => String::from("φ"),
+        }
+    }
+
+    fn text(&self, _cache: &mut HashMap) -> String {
+        match self {
+            Column::Relation(i) => format!("x[{i}]"),
+            Column::FixedSelector(i) => format!("fs[{i}]"),
+            Column::DynamicSelector(i) => format!("ds[{i}]"),
+            Column::LookupPartialSum((table_id, i)) => format!("h[{table_id}, {i}]"),
+            Column::LookupMultiplicity((table_id, i)) => format!("m[{table_id}, {i}]"),
+            Column::LookupFixedTable(i) => format!("t[{i}]"),
+            Column::LookupAggregation => String::from("φ"),
+        }
+    }
+
+    fn ocaml(&self, _cache: &mut HashMap) -> String {
+        // FIXME
+        unimplemented!("Not used at the moment")
+    }
+
+    fn is_alpha(&self) -> bool {
+        // FIXME
+        unimplemented!("Not used at the moment")
+    }
+}
+
+/// A datatype expressing a generalized column, but with a potentially
+/// more convenient interface than a bare column.
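+///
+/// As a minimal (hypothetical) illustration, a three-column circuit
+/// could name its relation columns like this:
+///
+/// ```
+/// use kimchi_msm::columns::{Column, ColumnIndexer};
+///
+/// #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+/// enum MyColumn {
+///     InputA,
+///     InputB,
+///     Output,
+/// }
+///
+/// impl ColumnIndexer for MyColumn {
+///     const N_COL: usize = 3;
+///     fn to_column(self) -> Column {
+///         match self {
+///             MyColumn::InputA => Column::Relation(0),
+///             MyColumn::InputB => Column::Relation(1),
+///             MyColumn::Output => Column::Relation(2),
+///         }
+///     }
+/// }
+/// ```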
+pub trait ColumnIndexer: core::fmt::Debug + Copy + Eq + Ord { + /// Total number of columns in this index. + const N_COL: usize; + + /// Flatten the column "alias" into the integer-like column. + fn to_column(self) -> Column; +} + +// Implementation to be compatible with folding if we use generic column +// constraints +impl FoldingColumnTrait for Column { + fn is_witness(&self) -> bool { + match self { + // Witness + Column::Relation(_) => true, + Column::DynamicSelector(_) => true, + // FIXME: check if we want to treat lookups differently + Column::LookupPartialSum(_) => true, + Column::LookupMultiplicity(_) => true, + Column::LookupAggregation => true, + // Not witness/public values + Column::FixedSelector(_) => false, + Column::LookupFixedTable(_) => false, + } + } +} diff --git a/msm/src/expr.rs b/msm/src/expr.rs new file mode 100644 index 0000000000..48f0d6e4d9 --- /dev/null +++ b/msm/src/expr.rs @@ -0,0 +1,71 @@ +// @volhovm Not sure a whole file is necessary for just one type +// alias, but maybe more code will come. +// Consider moving to lib.rs + +use ark_ff::Field; +use kimchi::circuits::{ + berkeley_columns::BerkeleyChallengeTerm, + expr::{ConstantExpr, Expr, ExprInner, Variable}, + gate::CurrOrNext, +}; + +use crate::columns::Column; + +/// An expression over /generic/ (not circuit-specific) columns +/// defined in the msm project. To represent constraints as multi +/// variate polynomials. The variables are over the columns. +/// +/// For instance, if there are 3 columns X1, X2, X3, then to constraint X3 to be +/// equals to sum of the X1 and X2 on a row, we would use the multivariate +/// polynomial `X3 - X1 - X2 = 0`. +/// Using the expression framework, this constraint would be +/// ``` +/// use kimchi::circuits::expr::{ConstantExprInner, ExprInner, Operations, Variable}; +/// use kimchi::circuits::gate::CurrOrNext; +/// use kimchi::circuits::berkeley_columns::BerkeleyChallengeTerm; +/// use kimchi_msm::columns::Column; +/// use kimchi_msm::expr::E; +/// pub type Fp = ark_bn254::Fr; +/// let x1 = E::::Atom( +/// ExprInner::>, Column>::Cell(Variable { +/// col: Column::Relation(1), +/// row: CurrOrNext::Curr, +/// }), +/// ); +/// let x2 = E::::Atom( +/// ExprInner::>, Column>::Cell(Variable { +/// col: Column::Relation(1), +/// row: CurrOrNext::Curr, +/// }), +/// ); +/// let x3 = E::::Atom( +/// ExprInner::>, Column>::Cell(Variable { +/// col: Column::Relation(1), +/// row: CurrOrNext::Curr, +/// }), +/// ); +/// let constraint = x3 - x1 - x2; +/// ``` +/// A list of such constraints is used to represent the entire circuit and will +/// be used to build the quotient polynomial. +pub type E = Expr, Column>; + +pub fn curr_cell(col: Column) -> E { + E::Atom(ExprInner::Cell(Variable { + col, + row: CurrOrNext::Curr, + })) +} + +pub fn next_cell(col: Column) -> E { + E::Atom(ExprInner::Cell(Variable { + col, + row: CurrOrNext::Next, + })) +} + +#[test] +fn test_debug_can_be_called_on_expr() { + use crate::{columns::Column::*, Fp}; + println!("{:}", curr_cell::(Relation(0)) + curr_cell(Relation(1))) +} diff --git a/msm/src/fec/columns.rs b/msm/src/fec/columns.rs new file mode 100644 index 0000000000..1d6c0d2267 --- /dev/null +++ b/msm/src/fec/columns.rs @@ -0,0 +1,160 @@ +use crate::{ + columns::{Column, ColumnIndexer}, + serialization::interpreter::{N_LIMBS_LARGE, N_LIMBS_SMALL}, +}; + +/// Number of columns in the FEC circuits. 
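+/// With `N_LIMBS_LARGE = 4` and `N_LIMBS_SMALL = 17`, this amounts to
+/// 16 (input) + 34 (output) + 195 (intermediate) = 245 columns.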
+pub const FEC_N_COLUMNS: usize = + FECColumnInput::N_COL + FECColumnOutput::N_COL + FECColumnInter::N_COL; + +/// FEC ADD inputs: two points = four coordinates, and each in 4 +/// "large format" limbs. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum FECColumnInput { + XP(usize), // 4 + YP(usize), // 4 + XQ(usize), // 4 + YQ(usize), // 4 +} + +/// FEC ADD outputs: one point, each in 17 limb output format. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum FECColumnOutput { + XR(usize), // 17 + YR(usize), // 17 +} + +/// FEC ADD intermediate (work) columns. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum FECColumnInter { + F(usize), // 4 + S(usize), // 17 + Q1(usize), // 17 + Q2(usize), // 17 + Q3(usize), // 17 + Q1Sign, // 1 + Q2Sign, // 1 + Q3Sign, // 1 + Q1L(usize), // 4 + Q2L(usize), // 4 + Q3L(usize), // 4 + Carry1(usize), // 36 + Carry2(usize), // 36 + Carry3(usize), // 36 +} + +/// Columns used by the FEC Addition subcircuit. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum FECColumn { + Input(FECColumnInput), + Output(FECColumnOutput), + Inter(FECColumnInter), +} + +impl ColumnIndexer for FECColumnInput { + const N_COL: usize = 4 * N_LIMBS_LARGE; + fn to_column(self) -> Column { + match self { + FECColumnInput::XP(i) => { + assert!(i < N_LIMBS_LARGE); + Column::Relation(i) + } + FECColumnInput::YP(i) => { + assert!(i < N_LIMBS_LARGE); + Column::Relation(N_LIMBS_LARGE + i) + } + FECColumnInput::XQ(i) => { + assert!(i < N_LIMBS_LARGE); + Column::Relation(2 * N_LIMBS_LARGE + i) + } + FECColumnInput::YQ(i) => { + assert!(i < N_LIMBS_LARGE); + Column::Relation(3 * N_LIMBS_LARGE + i) + } + } + } +} + +impl ColumnIndexer for FECColumnOutput { + const N_COL: usize = 2 * N_LIMBS_SMALL; + fn to_column(self) -> Column { + match self { + FECColumnOutput::XR(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(i) + } + FECColumnOutput::YR(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(N_LIMBS_SMALL + i) + } + } + } +} + +impl ColumnIndexer for FECColumnInter { + const N_COL: usize = 4 * N_LIMBS_LARGE + 10 * N_LIMBS_SMALL + 9; + fn to_column(self) -> Column { + match self { + FECColumnInter::F(i) => { + assert!(i < N_LIMBS_LARGE); + Column::Relation(i) + } + FECColumnInter::S(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(N_LIMBS_LARGE + i) + } + FECColumnInter::Q1(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(N_LIMBS_LARGE + N_LIMBS_SMALL + i) + } + FECColumnInter::Q2(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(N_LIMBS_LARGE + 2 * N_LIMBS_SMALL + i) + } + FECColumnInter::Q3(i) => { + assert!(i < N_LIMBS_SMALL); + Column::Relation(N_LIMBS_LARGE + 3 * N_LIMBS_SMALL + i) + } + FECColumnInter::Q1Sign => Column::Relation(N_LIMBS_LARGE + 4 * N_LIMBS_SMALL), + FECColumnInter::Q2Sign => Column::Relation(N_LIMBS_LARGE + 4 * N_LIMBS_SMALL + 1), + FECColumnInter::Q3Sign => Column::Relation(N_LIMBS_LARGE + 4 * N_LIMBS_SMALL + 2), + FECColumnInter::Q1L(i) => { + assert!(i < N_LIMBS_LARGE); + Column::Relation(N_LIMBS_LARGE + 4 * N_LIMBS_SMALL + 3 + i) + } + FECColumnInter::Q2L(i) => { + assert!(i < N_LIMBS_LARGE); + Column::Relation(2 * N_LIMBS_LARGE + 4 * N_LIMBS_SMALL + 3 + i) + } + FECColumnInter::Q3L(i) => { + assert!(i < N_LIMBS_LARGE); + Column::Relation(3 * N_LIMBS_LARGE + 4 * N_LIMBS_SMALL + 3 + i) + } + FECColumnInter::Carry1(i) => { + assert!(i < 2 * N_LIMBS_SMALL + 2); + Column::Relation(4 * N_LIMBS_LARGE + 4 * N_LIMBS_SMALL + 3 + i) + } + 
+            FECColumnInter::Carry2(i) => {
+                assert!(i < 2 * N_LIMBS_SMALL + 2);
+                Column::Relation(4 * N_LIMBS_LARGE + 6 * N_LIMBS_SMALL + 5 + i)
+            }
+            FECColumnInter::Carry3(i) => {
+                assert!(i < 2 * N_LIMBS_SMALL + 2);
+                Column::Relation(4 * N_LIMBS_LARGE + 8 * N_LIMBS_SMALL + 7 + i)
+            }
+        }
+    }
+}
+
+impl ColumnIndexer for FECColumn {
+    const N_COL: usize = FEC_N_COLUMNS;
+    fn to_column(self) -> Column {
+        match self {
+            FECColumn::Input(input) => input.to_column(),
+            FECColumn::Inter(inter) => inter.to_column().add_rel_offset(FECColumnInput::N_COL),
+            FECColumn::Output(output) => output
+                .to_column()
+                .add_rel_offset(FECColumnInput::N_COL + FECColumnInter::N_COL),
+        }
+    }
+}
diff --git a/msm/src/fec/interpreter.rs b/msm/src/fec/interpreter.rs
new file mode 100644
index 0000000000..ce3f78b39c
--- /dev/null
+++ b/msm/src/fec/interpreter.rs
@@ -0,0 +1,699 @@
+use crate::{
+    circuit_design::{
+        capabilities::{read_column_array, write_column_array_const, write_column_const},
+        ColAccessCap, ColWriteCap, LookupCap,
+    },
+    fec::{
+        columns::{FECColumn, FECColumnInput, FECColumnInter, FECColumnOutput},
+        lookups::LookupTable,
+    },
+    serialization::interpreter::{
+        bigint_to_biguint_f, combine_carry, combine_small_to_large, fold_choice2,
+        limb_decompose_biguint, limb_decompose_ff, LIMB_BITSIZE_LARGE, LIMB_BITSIZE_SMALL,
+        N_LIMBS_LARGE, N_LIMBS_SMALL,
+    },
+};
+use ark_ff::{PrimeField, Zero};
+use core::marker::PhantomData;
+use num_bigint::{BigInt, BigUint, ToBigInt};
+use num_integer::Integer;
+use num_traits::sign::Signed;
+use o1_utils::field_helpers::FieldHelpers;
+
+/// Convenience function for printing.
+pub fn limbs_to_bigints<F: PrimeField, const N: usize>(input: [F; N]) -> Vec<BigInt> {
+    input
+        .iter()
+        .map(|f| f.to_bigint_positive())
+        .collect::<Vec<_>>()
+}
+
+/// When P = (xP,yP) and Q = (xQ,yQ) are not negatives of each other, this
+/// function ensures
+///
+/// P + Q = R where
+///
+/// s = (yP - yQ) / (xP - xQ)
+///
+/// xR = s^2 - xP - xQ and yR = -yP + s(xP - xR)
+///
+///
+/// Equations that we check:
+/// 1. s (xP - xQ) - (yP - yQ) - q_1 f = 0
+/// 2. xR - s^2 + xP + xQ - q_2 f = 0
+/// 3. yR + yP - s (xP - xR) - q_3 f = 0
+///
+/// We will use several different "packing" formats.
+///
+/// === Limb equations
+///
+/// The standard (small limb) one, using 17 limbs of 15 bits each, is
+/// mostly helpful for range-checking the element, because 15-bit
+/// range checks are easy to perform. Additionally, this format is
+/// helpful for checking that the value is ∈ [0,f), where f is a
+/// foreign field modulus.
+///
+/// We will additionally use a "large limb" format, where each limb is
+/// 75 bits, so fitting exactly 5 small limbs. This format is
+/// effective for trusted values that we do not need to range check.
+/// Additionally, verifying equations on 75 bits is more effective in
+/// terms of minimising constraints.
+///
+/// Regarding our /concrete limb/ equations, they are different from
+/// the generic ones above in that they have carries. The carries are
+/// stored in a third format. Let us illustrate the design on the
+/// first equation of the three. Its concrete final form is as follows:
+///
+/// for i ∈ [0..2L-2]:
+///    \sum_{j,k < L | k+j = i} s_j (xP_k - xQ_k)
+///     - ((yP_i - yQ_i) if i < L else 0)
+///     - q_1_sign * \sum_{j,k < L | k+j = i} q_1_j f_k
+///     - (c_i * 2^B if i < 2L-2 else 0)
+///     + (c_{i-1} if i > 0 else 0) = 0
+///
+/// First, note that the single equation has turned into 2L-1 limb
+/// equations, one for each i ∈ [0, 2L-2].
+/// This is because the form of multiplication (and there are two
+/// multiplications here, s*(xP-xQ) and q*f) implies a quadratic number
+/// of limb multiplications; but because our operations are modulo f,
+/// every single element except for q in this equation is in the
+/// field.
+///
+/// Instead of having one limb-multiplication per row (e.g.
+/// q_1_5*f_6), which would lead to a quadratic number of constraints,
+/// and a quadratic number of intermediate-representation columns, we
+/// "batch" all multiplications for degree $i$ in one constraint as
+/// above.
+///
+/// Second, note that the carries are non-uniform in the loop: for the
+/// first limb, we only subtract c_0*2^B, while for the last limb we
+/// only add the previous carry c_{2L-3}. This means that, as usual,
+/// the number of carries is one less than the number of
+/// limb-equations. In our case, every equation relies on 2L-2 "large"
+/// carries.
+///
+/// Finally, a small remark: for simplicity we carry the sign of
+/// q separately from its absolute value. Note that in the original
+/// generic equation s (xP - xQ) - (yP - yQ) - q_1 f = 0 that holds
+/// over the integers, the only value (except for f) that can actually
+/// be outside of the field range [0,f-1) is q_1. In fact, while every
+/// other value is strictly positive, q_1 can actually be negative.
+/// Carrying its sign separately greatly simplifies modelling limbs at
+/// the expense of just 3 extra columns per circuit. So the q_1 limbs
+/// actually contain the absolute value of q_1, while q_1_sign is in
+/// {-1,1}.
+///
+/// === Data layout
+///
+/// Now we are ready to present the data layout and to discuss the
+/// representation modes.
+///
+/// Let
+/// L := N_LIMBS_LARGE
+/// S := N_LIMBS_SMALL
+///
+/// variable    offset        length  comment
+/// ---------------------------------------------------------------
+/// xP:         0             L       Always trusted, not range checked
+/// yP:         1*L           L       Always trusted, not range checked
+/// xQ:         2*L           L       Always trusted, not range checked
+/// yQ:         3*L           L       Always trusted, not range checked
+/// f:          4*L           L       Always trusted, not range checked
+/// xR:         5*L           S
+/// yR:         5*L + 1*S     S
+/// s:          5*L + 2*S     S
+/// q_1:        5*L + 3*S     S
+/// q_2:        5*L + 4*S     S
+/// q_3:        5*L + 5*S     S
+/// q_1_sign:   5*L + 6*S     1
+/// q_2_sign:   5*L + 6*S + 1 1
+/// q_3_sign:   5*L + 6*S + 2 1
+/// carry_1:    5*L + 6*S + 3 2*S+2
+/// carry_2:    5*L + 8*S + 5 2*S+2
+/// carry_3:    5*L + 10*S + 7 2*S+2
+///----------------------------------------------------------------
+///
+///
+/// As we said before, all elements that are either S small limbs or 1
+/// column wide are range-checked. The only unusual part here is that
+/// the carries are represented in 2*S+2 limbs. Let us explain.
+///
+/// As we said, we need 2*L-2 carries, which is 6 (since L = 4). Because
+/// our operations contain not just one limb multiplication, but several
+/// limb multiplications and extra additions, our carries will /not/
+/// fit into 75 bits. But we can prove (below) that they always fit
+/// into 79 bits. Therefore, every large carry will be represented
+/// not by 5 15-bit chunks, but by 6 15-bit chunks. This gives us 6
+/// chunks * 6 carries = 36 chunks, and every 6th chunk is 4 bits only.
+/// This matches the 2*S+2 = 36, since S = 17.
+///
+/// Note however that since the 79-bit carry is signed, we will store it
+/// as a list of [15 15 15 15 15 9]-bit limbs, where limbs are signed.
+/// E.g. 15-bit limbs are in [-2^14, 2^14-1]. This allows us to use
+/// 14-bit-abs range checks.
+///
+/// === Ranges
+///
+/// Carries for our three equations have the following generic range
+/// form (inclusive over integers). Note that all three equations look
+/// exactly the same for i >= n _except_ for the carry from the previous
+/// limbs.
+///
+///
+/// Eq1.
+/// - i ∈ [0,n-1]:  c1_i ∈ [-((i+1)*2^(b+1) - 2*i - 3),
+///                         (i+1)*2^(b+1) - 2*i - 3] (symmetric)
+/// - i ∈ [n,2n-2]: c1_i ∈ [-((2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3),
+///                         (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3] (symmetric)
+///
+/// Eq2.
+/// - i ∈ [0,n-1]:  c2_i ∈ [-((i+1)*2^(b+1) - 2*i - 4),
+///                         if i == 0 2^b else (i+1)*2^b - i]
+/// - i ∈ [n,2n-2]: c2_i ∈ [-((2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3),
+///                         (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 - (if i == n { n-1 } else 0) ]
+///
+/// Eq3.
+/// - i ∈ [0,n-1]:  c3_i ∈ [-((i+1)*2^(b+1) - 2*i - 4),
+///                         (i+1)*2^b - i - 1]
+/// - i ∈ [n,2n-2]: c3_i ∈ [-((2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3),
+///                         (2*n-i-1)*2^(b+1) - 2*(2*n-i) + 3 - (if i == n { n-1 } else 0) ]
+///
+/// Absolute maximum values for all carries:
+/// Eq1.
+/// * Upper bound = -lower bound is achieved at i = n-1, n*2^(b+1) - 2*(n-1) - 3
+///   * (+-) 302231454903657293676535
+///
+/// Eq2 and Eq3:
+/// * Upper bound is achieved at i = n, (n-1)*2^(b+1) - 2*n + 3 - (n-1)
+///   * 226673591177742970257400
+/// * Lower bound is achieved at i = n-1, n*2^(b+1) - 2*(n-1) - 4
+///   * (-) 302231454903657293676534
+///
+/// As we can see, the values are about 2*n=8 times bigger than 2^b,
+/// so concretely 4 extra bits per carry will be enough. This implies
+/// that we can /definitely/ fit a large carry into 6 small limbs,
+/// since it has 15 "free" bits of which we will use 4 at most.
+///
+/// @volhovm: Soundness-wise I am not convinced that we need to
+/// enforce these more precise ranges as compared to enforcing just 4
+/// bits more for the highest limb. Even checking that the highest limb
+/// is 15 bits could be quite sound.
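The absolute bounds quoted in the doc comment above can be reproduced with plain integer arithmetic. A minimal standalone sketch (not part of this diff; `n` and `b` mirror `N_LIMBS_LARGE = 4` and `LIMB_BITSIZE_LARGE = 75`):

```rust
// Sanity-check the carry bounds from the range analysis above.
fn main() {
    let n: u128 = 4; // N_LIMBS_LARGE
    let b: u32 = 75; // LIMB_BITSIZE_LARGE
    // Eq1 extremum, reached at i = n-1: n*2^(b+1) - 2*(n-1) - 3
    let eq1_max = n * (1u128 << (b + 1)) - 2 * (n - 1) - 3;
    assert_eq!(eq1_max, 302231454903657293676535);
    // Eq2/Eq3 upper bound, reached at i = n: (n-1)*2^(b+1) - 2*n + 3 - (n-1)
    let eq23_max = (n - 1) * (1u128 << (b + 1)) - 2 * n + 3 - (n - 1);
    assert_eq!(eq23_max, 226673591177742970257400);
    // Both fit into 79 bits, i.e. into 6 signed small limbs.
    assert!(eq1_max < (1u128 << 79) && eq23_max < (1u128 << 79));
}
```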
+pub fn constrain_ec_addition<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, FECColumn> + LookupCap<F, FECColumn, LookupTable<Ff>>,
+>(
+    env: &mut Env,
+) {
+    let xp_limbs_large: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| FECColumn::Input(FECColumnInput::XP(i)));
+    let yp_limbs_large: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| FECColumn::Input(FECColumnInput::YP(i)));
+    let xq_limbs_large: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| FECColumn::Input(FECColumnInput::XQ(i)));
+    let yq_limbs_large: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| FECColumn::Input(FECColumnInput::YQ(i)));
+    let f_limbs_large: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::F(i)));
+    let xr_limbs_small: [_; N_LIMBS_SMALL] =
+        read_column_array(env, |i| FECColumn::Output(FECColumnOutput::XR(i)));
+    let yr_limbs_small: [_; N_LIMBS_SMALL] =
+        read_column_array(env, |i| FECColumn::Output(FECColumnOutput::YR(i)));
+    let s_limbs_small: [_; N_LIMBS_SMALL] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::S(i)));
+
+    let q1_limbs_small: [_; N_LIMBS_SMALL] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Q1(i)));
+    let q2_limbs_small: [_; N_LIMBS_SMALL] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Q2(i)));
+    let q3_limbs_small: [_; N_LIMBS_SMALL] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Q3(i)));
+    let q1_limbs_large: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Q1L(i)));
+    let q2_limbs_large: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Q2L(i)));
+    let q3_limbs_large: [_; N_LIMBS_LARGE] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Q3L(i)));
+
+    let q1_sign = env.read_column(FECColumn::Inter(FECColumnInter::Q1Sign));
+    let q2_sign = env.read_column(FECColumn::Inter(FECColumnInter::Q2Sign));
+    let q3_sign = env.read_column(FECColumn::Inter(FECColumnInter::Q3Sign));
+
+    let carry1_limbs_small: [_; 2 * N_LIMBS_SMALL + 2] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Carry1(i)));
+    let carry2_limbs_small: [_; 2 * N_LIMBS_SMALL + 2] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Carry2(i)));
+    let carry3_limbs_small: [_; 2 * N_LIMBS_SMALL + 2] =
+        read_column_array(env, |i| FECColumn::Inter(FECColumnInter::Carry3(i)));
+
+    // FIXME get rid of cloning
+
+    // u128 covers our limb-size shifts, which is good
+    let constant_u128 = |x: u128| -> Env::Variable { Env::constant(From::from(x)) };
+
+    // Slope and result variables must be in the field.
+    for (i, x) in s_limbs_small
+        .iter()
+        .chain(xr_limbs_small.iter())
+        .chain(yr_limbs_small.iter())
+        .enumerate()
+    {
+        if i % N_LIMBS_SMALL == N_LIMBS_SMALL - 1 {
+            // If it's the highest limb, we need to check that it's representing a field element.
+            env.lookup(
+                LookupTable::RangeCheckFfHighest(PhantomData),
+                vec![x.clone()],
+            );
+        } else {
+            env.lookup(LookupTable::RangeCheck15, vec![x.clone()]);
+        }
+    }
+
+    // Quotient limbs must fit into 15 bits, but we don't care if they're in the field.
+    for x in q1_limbs_small
+        .iter()
+        .chain(q2_limbs_small.iter())
+        .chain(q3_limbs_small.iter())
+    {
+        env.lookup(LookupTable::RangeCheck15, vec![x.clone()]);
+    }
+
+    // Signs must be -1 or 1.
+    for x in [&q1_sign, &q2_sign, &q3_sign] {
+        env.assert_zero(x.clone() * x.clone() - Env::constant(F::one()));
+    }
+
+    // Carry limbs need to be in particular ranges.
+    for (i, x) in carry1_limbs_small
+        .iter()
+        .chain(carry2_limbs_small.iter())
+        .chain(carry3_limbs_small.iter())
+        .enumerate()
+    {
+        if i % 6 == 5 {
+            // This should be a different range check depending on which big-limb we're processing?
+            // So instead of one type of lookup we will have 5 different ones?
+            env.lookup(LookupTable::RangeCheck9Abs, vec![x.clone()]);
+        } else {
+            env.lookup(LookupTable::RangeCheck14Abs, vec![x.clone()]);
+        }
+    }
+
+    // Make sure qi_limbs_large are properly constructed from qi_limbs_small and qi_sign
+    {
+        let q1_limbs_large_abs_expected =
+            combine_small_to_large::<_, _, Env>(q1_limbs_small.clone());
+        for j in 0..N_LIMBS_LARGE {
+            env.assert_zero(
+                q1_limbs_large[j].clone()
+                    - q1_sign.clone() * q1_limbs_large_abs_expected[j].clone(),
+            );
+        }
+        let q2_limbs_large_abs_expected =
+            combine_small_to_large::<_, _, Env>(q2_limbs_small.clone());
+        for j in 0..N_LIMBS_LARGE {
+            env.assert_zero(
+                q2_limbs_large[j].clone()
+                    - q2_sign.clone() * q2_limbs_large_abs_expected[j].clone(),
+            );
+        }
+        let q3_limbs_large_abs_expected =
+            combine_small_to_large::<_, _, Env>(q3_limbs_small.clone());
+        for j in 0..N_LIMBS_LARGE {
+            env.assert_zero(
+                q3_limbs_large[j].clone()
+                    - q3_sign.clone() * q3_limbs_large_abs_expected[j].clone(),
+            );
+        }
+    }
+
+    let xr_limbs_large = combine_small_to_large::<_, _, Env>(xr_limbs_small.clone());
+    let yr_limbs_large = combine_small_to_large::<_, _, Env>(yr_limbs_small.clone());
+    let s_limbs_large = combine_small_to_large::<_, _, Env>(s_limbs_small.clone());
+
+    let carry1_limbs_large: [_; 2 * N_LIMBS_LARGE - 2] =
+        combine_carry::<_, _, Env>(carry1_limbs_small.clone());
+    let carry2_limbs_large: [_; 2 * N_LIMBS_LARGE - 2] =
+        combine_carry::<_, _, Env>(carry2_limbs_small.clone());
+    let carry3_limbs_large: [_; 2 * N_LIMBS_LARGE - 2] =
+        combine_carry::<_, _, Env>(carry3_limbs_small.clone());
+
+    let limb_size_large = constant_u128(1u128 << LIMB_BITSIZE_LARGE);
+    let add_extra_carries =
+        |i: usize, carry_limbs_large: &[Env::Variable; 2 * N_LIMBS_LARGE - 2]| -> Env::Variable {
+            if i == 0 {
+                -(carry_limbs_large[0].clone() * limb_size_large.clone())
+            } else if i < 2 * N_LIMBS_LARGE - 2 {
+                carry_limbs_large[i - 1].clone()
+                    - carry_limbs_large[i].clone() * limb_size_large.clone()
+            } else if i == 2 * N_LIMBS_LARGE - 2 {
+                carry_limbs_large[i - 1].clone()
+            } else {
+                panic!("add_extra_carries: the index {i:?} is too high")
+            }
+        };
+
+    // Equation 1
+    // General form:
+    // \sum_{k,j | k+j = i} s_j (xP_k - xQ_k) - (yP_i - yQ_i) - \sum_{k,j} q_1_k f_j - c_i * 2^B + c_{i-1} = 0
+    for i in 0..2 * N_LIMBS_LARGE - 1 {
+        let mut constraint1 = fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            s_limbs_large[j].clone() * (xp_limbs_large[k].clone() - xq_limbs_large[k].clone())
+        });
+        if i < N_LIMBS_LARGE {
+            constraint1 = constraint1 - (yp_limbs_large[i].clone() - yq_limbs_large[i].clone());
+        }
+        constraint1 = constraint1
+            - fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+                q1_limbs_large[j].clone() * f_limbs_large[k].clone()
+            });
+        constraint1 = constraint1 + add_extra_carries(i, &carry1_limbs_large);
+        env.assert_zero(constraint1);
+    }
+
+    // Equation 2
+    // General form: xR_i - \sum s_j s_k + xP_i + xQ_i - \sum q_2_j f_k - c_i * 2^B + c_{i-1} = 0
+    for i in 0..2 * N_LIMBS_LARGE - 1 {
+        let mut constraint2 = -fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            s_limbs_large[j].clone() * s_limbs_large[k].clone()
+        });
+        if i < N_LIMBS_LARGE {
+            constraint2 = constraint2
+                + xr_limbs_large[i].clone()
+                + xp_limbs_large[i].clone()
+                + xq_limbs_large[i].clone();
+        }
+        constraint2 = constraint2
+            - fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+                q2_limbs_large[j].clone() * f_limbs_large[k].clone()
+            });
+        constraint2 = constraint2 + add_extra_carries(i, &carry2_limbs_large);
+        env.assert_zero(constraint2);
+    }
+
+    // Equation 3
+    // General form: yR_i + yP_i - \sum s_j (xP_k - xR_k) - \sum q_3_j f_k - c_i * 2^B + c_{i-1} = 0
+    for i in 0..2 * N_LIMBS_LARGE - 1 {
+        let mut constraint3 = -fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            s_limbs_large[j].clone() * (xp_limbs_large[k].clone() - xr_limbs_large[k].clone())
+        });
+        if i < N_LIMBS_LARGE {
+            constraint3 = constraint3 + yr_limbs_large[i].clone() + yp_limbs_large[i].clone();
+        }
+        constraint3 = constraint3
+            - fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+                q3_limbs_large[j].clone() * f_limbs_large[k].clone()
+            });
+        constraint3 = constraint3 + add_extra_carries(i, &carry3_limbs_large);
+        env.assert_zero(constraint3)
+    }
+}
+
+/// Creates a witness for adding two points, p and q, each represented
+/// as a pair of foreign field elements. Returns a point.
+///
+/// This function is the witness-generation counterpart (called by the prover)
+/// of `constrain_ec_addition` -- see the documentation of the latter.
+pub fn ec_add_circuit<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColWriteCap<F, FECColumn> + LookupCap<F, FECColumn, LookupTable<Ff>>,
+>(
+    env: &mut Env,
+    xp: Ff,
+    yp: Ff,
+    xq: Ff,
+    yq: Ff,
+) -> (Ff, Ff) {
+    let slope: Ff = (yq - yp) / (xq - xp);
+    let xr: Ff = slope * slope - xp - xq;
+    let yr: Ff = slope * (xp - xr) - yp;
+
+    let two_bi: BigInt = TryFrom::try_from(2).unwrap();
+
+    let large_limb_size: F = From::from(1u128 << LIMB_BITSIZE_LARGE);
+
+    // Foreign field modulus
+    let f_bui: BigUint = TryFrom::try_from(Ff::MODULUS).unwrap();
+    let f_bi: BigInt = f_bui.to_bigint().unwrap();
+
+    // Native field modulus (prime)
+    let n_bui: BigUint = TryFrom::try_from(F::MODULUS).unwrap();
+    let n_bi: BigInt = n_bui.to_bigint().unwrap();
+    let n_half_bi = &n_bi / &two_bi;
+
+    let xp_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&xp);
+    let yp_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&yp);
+    let xq_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&xq);
+    let yq_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&yq);
+    let f_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(f_bui.clone());
+    let xr_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&xr);
+    let yr_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&yr);
+
+    let xr_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&xr);
+    let yr_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&yr);
+    let slope_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&slope);
+    let slope_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&slope);
+
+    write_column_array_const(env, &xp_limbs_large, |i| {
+        FECColumn::Input(FECColumnInput::XP(i))
+    });
+    write_column_array_const(env, &yp_limbs_large, |i| {
+        FECColumn::Input(FECColumnInput::YP(i))
+    });
+    write_column_array_const(env, &xq_limbs_large, |i| {
+        FECColumn::Input(FECColumnInput::XQ(i))
+    });
+    write_column_array_const(env, &yq_limbs_large, |i| {
+        FECColumn::Input(FECColumnInput::YQ(i))
+    });
+    write_column_array_const(env, &f_limbs_large, |i| {
+        FECColumn::Inter(FECColumnInter::F(i))
+    });
+    write_column_array_const(env, &xr_limbs_small, |i| {
+        FECColumn::Output(FECColumnOutput::XR(i))
+    });
+    write_column_array_const(env, &yr_limbs_small, |i| {
+        FECColumn::Output(FECColumnOutput::YR(i))
+    });
+    write_column_array_const(env, &slope_limbs_small, |i| {
+        FECColumn::Inter(FECColumnInter::S(i))
+    });
+    let xp_bi: BigInt = FieldHelpers::to_bigint_positive(&xp);
+    let yp_bi: BigInt = FieldHelpers::to_bigint_positive(&yp);
+    let xq_bi: BigInt = FieldHelpers::to_bigint_positive(&xq);
+    let yq_bi: BigInt = FieldHelpers::to_bigint_positive(&yq);
+    let slope_bi: BigInt = FieldHelpers::to_bigint_positive(&slope);
+    let xr_bi: BigInt = FieldHelpers::to_bigint_positive(&xr);
+    let yr_bi: BigInt = FieldHelpers::to_bigint_positive(&yr);
+
+    // Equation 1: s (xP - xQ) - (yP - yQ) - q_1 f = 0
+    let (q1_bi, r1_bi) = (&slope_bi * (&xp_bi - &xq_bi) - (&yp_bi - &yq_bi)).div_rem(&f_bi);
+    assert!(r1_bi.is_zero());
+    // Storing negative numbers is a mess.
+    let (q1_bi, q1_sign): (BigInt, F) = if q1_bi.is_negative() {
+        (-q1_bi, -F::one())
+    } else {
+        (q1_bi, F::one())
+    };
+
+    // Equation 2: xR - s^2 + xP + xQ - q_2 f = 0
+    let (q2_bi, r2_bi) = (&xr_bi - &slope_bi * &slope_bi + &xp_bi + &xq_bi).div_rem(&f_bi);
+    assert!(r2_bi.is_zero());
+    let (q2_bi, q2_sign): (BigInt, F) = if q2_bi.is_negative() {
+        (-q2_bi, -F::one())
+    } else {
+        (q2_bi, F::one())
+    };
+
+    // Equation 3: yR + yP - s (xP - xR) - q_3 f = 0
+    let (q3_bi, r3_bi) = (&yr_bi + &yp_bi - &slope_bi * (&xp_bi - &xr_bi)).div_rem(&f_bi);
+    assert!(r3_bi.is_zero());
+    let (q3_bi, q3_sign): (BigInt, F) = if q3_bi.is_negative() {
+        (-q3_bi, -F::one())
+    } else {
+        (q3_bi, F::one())
+    };
+
+    // TODO can this be better?
+    // Used for witness computation
+    // Big limbs /have/ sign in them.
+    let q1_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(q1_bi.to_biguint().unwrap())
+            .into_iter()
+            .map(|v| v * q1_sign)
+            .collect::<Vec<F>>()
+            .try_into()
+            .unwrap();
+    let q2_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(q2_bi.to_biguint().unwrap())
+            .into_iter()
+            .map(|v| v * q2_sign)
+            .collect::<Vec<F>>()
+            .try_into()
+            .unwrap();
+    let q3_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(q3_bi.to_biguint().unwrap())
+            .into_iter()
+            .map(|v| v * q3_sign)
+            .collect::<Vec<F>>()
+            .try_into()
+            .unwrap();
+
+    // Written into the columns
+    // small limbs are signless 15-bit
+    let q1_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(q1_bi.to_biguint().unwrap());
+    let q2_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(q2_bi.to_biguint().unwrap());
+    let q3_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(q3_bi.to_biguint().unwrap());
+
+    write_column_array_const(env, &q1_limbs_small, |i| {
+        FECColumn::Inter(FECColumnInter::Q1(i))
+    });
+    write_column_array_const(env, &q2_limbs_small, |i| {
+        FECColumn::Inter(FECColumnInter::Q2(i))
+    });
+    write_column_array_const(env, &q3_limbs_small, |i| {
+        FECColumn::Inter(FECColumnInter::Q3(i))
+    });
+
+    write_column_const(env, FECColumn::Inter(FECColumnInter::Q1Sign), &q1_sign);
+    write_column_const(env, FECColumn::Inter(FECColumnInter::Q2Sign), &q2_sign);
+    write_column_const(env, FECColumn::Inter(FECColumnInter::Q3Sign), &q3_sign);
+
+    write_column_array_const(env, &q1_limbs_large, |i| {
+        FECColumn::Inter(FECColumnInter::Q1L(i))
+    });
+    write_column_array_const(env, &q2_limbs_large, |i| {
+        FECColumn::Inter(FECColumnInter::Q2L(i))
+    });
+    write_column_array_const(env, &q3_limbs_large, |i| {
+        FECColumn::Inter(FECColumnInter::Q3L(i))
+    });
+
+    let mut carry1: F = From::from(0u64);
+    let mut carry2: F = From::from(0u64);
+    let mut carry3: F = From::from(0u64);
+
+    for i in 0..N_LIMBS_LARGE * 2 - 1 {
+        let compute_carry = |res: F| -> F {
+            // TODO enforce this as an integer division
+            let mut res_bi = res.to_bigint_positive();
+            if res_bi > n_half_bi {
+                res_bi -= &n_bi;
+            }
+            let (div, rem) = res_bi.div_rem(&large_limb_size.to_bigint_positive());
+            assert!(
+                rem.is_zero(),
+                "Cannot compute carry for step {i:?}: div {div:?}, rem {rem:?}"
+            );
+            let carry_f: BigUint = bigint_to_biguint_f(div, &n_bi);
+            F::from_biguint(&carry_f).unwrap()
+        };
+
+        fn assign_carry<F, Env, ColMap>(
+            env: &mut Env,
+            n_half_bi: &BigInt,
+            i: usize,
+            newcarry: F,
+            carryvar: &mut F,
+            column_mapper: ColMap,
+        ) where
+            F: PrimeField,
+            Env: ColWriteCap<F, FECColumn>,
+            ColMap: Fn(usize) -> FECColumn,
+        {
+            // Last carry should be zero, otherwise we record it
+            if i < N_LIMBS_LARGE * 2 - 2 {
+                // Carries will often not fit into 5 limbs, but they /should/ fit in 6 limbs I think.
+                let newcarry_sign = if &newcarry.to_bigint_positive() > n_half_bi {
+                    F::zero() - F::one()
+                } else {
+                    F::one()
+                };
+                let newcarry_abs_bui = (newcarry * newcarry_sign).to_biguint();
+                // Our big carries are at most 79 bits, so we need 6 small limbs per each.
+                // But limbs are signed, so we split into 14-bit /signed/ limbs;
+                // the last chunk is a signed 9-bit limb.
+                let newcarry_limbs: [F; 6] =
+                    limb_decompose_biguint::<F, LIMB_BITSIZE_SMALL, 6>(newcarry_abs_bui.clone());
+
+                for (j, limb) in newcarry_limbs.iter().enumerate() {
+                    write_column_const(env, column_mapper(6 * i + j), &(newcarry_sign * limb));
+                }
+
+                *carryvar = newcarry;
+            } else {
+                // Should this be checked in-circuit?
+                assert!(newcarry.is_zero(), "Last carry is non-zero");
+            }
+        }
+
+        // Equation 1: s (xP - xQ) - (yP - yQ) - q_1 f = 0
+        let mut res1 = fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            slope_limbs_large[j] * (xp_limbs_large[k] - xq_limbs_large[k])
+        });
+        if i < N_LIMBS_LARGE {
+            res1 -= yp_limbs_large[i] - yq_limbs_large[i];
+        }
+        res1 -= fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            q1_limbs_large[j] * f_limbs_large[k]
+        });
+        res1 += carry1;
+        let newcarry1 = compute_carry(res1);
+        assign_carry(env, &n_half_bi, i, newcarry1, &mut carry1, |i| {
+            FECColumn::Inter(FECColumnInter::Carry1(i))
+        });
+
+        // Equation 2: xR - s^2 + xP + xQ - q_2 f = 0
+        let mut res2 = F::zero();
+        res2 -= fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            slope_limbs_large[j] * slope_limbs_large[k]
+        });
+        if i < N_LIMBS_LARGE {
+            res2 += xr_limbs_large[i] + xp_limbs_large[i] + xq_limbs_large[i];
+        }
+        res2 -= fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            q2_limbs_large[j] * f_limbs_large[k]
+        });
+        res2 += carry2;
+        let newcarry2 = compute_carry(res2);
+        assign_carry(env, &n_half_bi, i, newcarry2, &mut carry2, |i| {
+            FECColumn::Inter(FECColumnInter::Carry2(i))
+        });
+
+        // Equation 3: yR + yP - s (xP - xR) - q_3 f = 0
+        let mut res3 = F::zero();
+        res3 -= fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            slope_limbs_large[j] * (xp_limbs_large[k] - xr_limbs_large[k])
+        });
+        if i < N_LIMBS_LARGE {
+            res3 += yr_limbs_large[i] + yp_limbs_large[i];
+        }
+        res3 -= fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            q3_limbs_large[j] * f_limbs_large[k]
+        });
+        res3 += carry3;
+        let newcarry3 = compute_carry(res3);
+        assign_carry(env, &n_half_bi, i, newcarry3, &mut carry3, |i| {
+            FECColumn::Inter(FECColumnInter::Carry3(i))
+        });
+    }
+
+    constrain_ec_addition::<F, Ff, Env>(env);
+
+    (xr, yr)
+}
diff --git a/msm/src/fec/lookups.rs b/msm/src/fec/lookups.rs
new file mode 100644
index 0000000000..2ca9d7dc66
--- /dev/null
+++ b/msm/src/fec/lookups.rs
@@ -0,0 +1,166 @@
+use crate::{logup::LookupTableID, Logup, LIMB_BITSIZE, N_LIMBS};
+use ark_ff::PrimeField;
+use num_bigint::BigUint;
+use o1_utils::FieldHelpers;
+use std::marker::PhantomData;
+use strum_macros::EnumIter;
+
+/// Enumeration of concrete lookup tables used in the FEC circuit.
+#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd, EnumIter)]
+pub enum LookupTable<Ff> {
+    /// x ∈ [0, 2^15)
+    RangeCheck15,
+    /// x ∈ [-2^14, 2^14-1]
+    RangeCheck14Abs,
+    /// x ∈ [-2^9, 2^9-1]
+    RangeCheck9Abs,
+    /// x ∈ [0, ff_highest] where ff_highest is the highest 15-bit
+    /// limb of the modulus of the foreign field `Ff`.
+    RangeCheckFfHighest(PhantomData<Ff>),
+}
+
+impl<Ff: PrimeField> LookupTableID for LookupTable<Ff> {
+    fn to_u32(&self) -> u32 {
+        match self {
+            Self::RangeCheck15 => 1,
+            Self::RangeCheck14Abs => 2,
+            Self::RangeCheck9Abs => 3,
+            Self::RangeCheckFfHighest(_) => 4,
+        }
+    }
+
+    fn from_u32(value: u32) -> Self {
+        match value {
+            1 => Self::RangeCheck15,
+            2 => Self::RangeCheck14Abs,
+            3 => Self::RangeCheck9Abs,
+            4 => Self::RangeCheckFfHighest(PhantomData),
+            _ => panic!("Invalid lookup table id"),
+        }
+    }
+
+    /// All tables are fixed tables.
+    fn is_fixed(&self) -> bool {
+        true
+    }
+
+    fn runtime_create_column(&self) -> bool {
+        panic!("No runtime tables specified");
+    }
+
+    fn length(&self) -> usize {
+        match self {
+            Self::RangeCheck15 => 1 << 15,
+            Self::RangeCheck14Abs => 1 << 15,
+            Self::RangeCheck9Abs => 1 << 10,
+            Self::RangeCheckFfHighest(_) => TryFrom::try_from(
+                crate::serialization::interpreter::ff_modulus_highest_limb::<Ff>(),
+            )
+            .unwrap(),
+        }
+    }
+
+    /// Converts a value to its index in the fixed table.
+    fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize> {
+        let value = value[0];
+        assert!(self.is_member(value));
+        Some(match self {
+            Self::RangeCheck15 => TryFrom::try_from(value.to_biguint()).unwrap(),
+            Self::RangeCheck14Abs => {
+                if value < F::from(1u64 << 14) {
+                    TryFrom::try_from(value.to_biguint()).unwrap()
+                } else {
+                    TryFrom::try_from((value + F::from(2 * (1u64 << 14))).to_biguint()).unwrap()
+                }
+            }
+            Self::RangeCheck9Abs => {
+                if value < F::from(1u64 << 9) {
+                    TryFrom::try_from(value.to_biguint()).unwrap()
+                } else {
+                    TryFrom::try_from((value + F::from(2 * (1u64 << 9))).to_biguint()).unwrap()
+                }
+            }
+            Self::RangeCheckFfHighest(_) => TryFrom::try_from(value.to_biguint()).unwrap(),
+        })
+    }
+
+    fn all_variants() -> Vec<Self> {
+        vec![
+            Self::RangeCheck15,
+            Self::RangeCheck14Abs,
+            Self::RangeCheck9Abs,
+            Self::RangeCheckFfHighest(PhantomData),
+        ]
+    }
+}
+
+impl<Ff: PrimeField> LookupTable<Ff> {
+    fn entries_ff_highest<F: PrimeField>(domain_d1_size: u64) -> Vec<F> {
+        let top_modulus_f =
+            F::from_biguint(&crate::serialization::interpreter::ff_modulus_highest_limb::<Ff>())
+                .unwrap();
+        (0..domain_d1_size)
+            .map(|i| {
+                if F::from(i) < top_modulus_f {
+                    F::from(i)
+                } else {
+                    F::zero()
+                }
+            })
+            .collect()
+    }
+
+    /// Provides a full list of entries for the given table.
+    pub fn entries<F: PrimeField>(&self, domain_d1_size: u64) -> Vec<F> {
+        assert!(domain_d1_size >= (1 << 15));
+        match self {
+            Self::RangeCheck15 => (0..domain_d1_size).map(|i| F::from(i)).collect(),
+            Self::RangeCheck14Abs => (0..domain_d1_size)
+                .map(|i| {
+                    if i < (1 << 14) {
+                        F::from(i)
+                    } else if i < 2 * (1 << 14) {
+                        F::from(i) - F::from(2u64 * (1 << 14))
+                    } else {
+                        F::zero()
+                    }
+                })
+                .collect(),
+            Self::RangeCheck9Abs => (0..domain_d1_size)
+                .map(|i| {
+                    if i < (1 << 9) {
+                        // [0, 1, 2, ..., (1<<9)-1]
+                        F::from(i)
+                    } else if i < 2 * (1 << 9) {
+                        // [-(1<<9), ..., -2, -1]
+                        F::from(i) - F::from(2u64 * (1 << 9))
+                    } else {
+                        F::zero()
+                    }
+                })
+                .collect(),
+            Self::RangeCheckFfHighest(_) => Self::entries_ff_highest::<F>(domain_d1_size),
+        }
+    }
+
+    /// Checks if a value is in a given table.
+    pub fn is_member<F: PrimeField>(&self, value: F) -> bool {
+        match self {
+            Self::RangeCheck15 => value.to_biguint() < BigUint::from(2u128.pow(15)),
+            Self::RangeCheck14Abs => {
+                value < F::from(1u64 << 14) || value >= F::zero() - F::from(1u64 << 14)
+            }
+            Self::RangeCheck9Abs => {
+                value < F::from(1u64 << 9) || value >= F::zero() - F::from(1u64 << 9)
+            }
+            Self::RangeCheckFfHighest(_) => {
+                let f_bui: BigUint = TryFrom::try_from(Ff::MODULUS).unwrap();
+                let top_modulus_f: F =
+                    F::from_biguint(&(f_bui >> ((N_LIMBS - 1) * LIMB_BITSIZE))).unwrap();
+                value < top_modulus_f
+            }
+        }
+    }
+}
+
+pub type Lookup<F, Ff> = Logup<F, LookupTable<Ff>>;
diff --git a/msm/src/fec/mod.rs b/msm/src/fec/mod.rs
new file mode 100644
index 0000000000..6f0c17642c
--- /dev/null
+++ b/msm/src/fec/mod.rs
@@ -0,0 +1,181 @@
+pub mod columns;
+pub mod interpreter;
+pub mod lookups;
+
+#[cfg(test)]
+mod tests {
+
+    use crate::{
+        circuit_design::{ConstraintBuilderEnv, WitnessBuilderEnv},
+        columns::ColumnIndexer,
+        fec::{
+            columns::{FECColumn, FEC_N_COLUMNS},
+            interpreter::{constrain_ec_addition, ec_add_circuit},
+            lookups::LookupTable,
+        },
+        logup::LookupTableID,
+        Ff1, Fp,
+    };
+    use ark_ec::AffineRepr;
+    use ark_ff::UniformRand;
+    use rand::{CryptoRng, RngCore};
+    use std::{
+        collections::{BTreeMap, HashMap},
+        ops::Mul,
+    };
+
+    type FECWitnessBuilderEnv = WitnessBuilderEnv<
+        Fp,
+        FECColumn,
+        { <FECColumn as ColumnIndexer>::N_COL },
+        { <FECColumn as ColumnIndexer>::N_COL },
+        0,
+        0,
+        LookupTable<Ff1>,
+    >;
+
+    fn build_fec_addition_circuit<RNG: RngCore + CryptoRng>(
+        rng: &mut RNG,
+        domain_size: usize,
+    ) -> FECWitnessBuilderEnv {
+        use mina_curves::pasta::Pallas;
+        // Fq = Ff2
+        type Fq = <Pallas as AffineRepr>::ScalarField;
+
+        let mut witness_env = WitnessBuilderEnv::create();
+
+        // To support fewer rows than domain_size we need to have selectors.
+        //let row_num = rng.gen_range(0..domain_size);
+
+        let gen = Pallas::generator();
+
+        let kp: Fq = UniformRand::rand(rng);
+        let p: Pallas = gen.mul(kp).into();
+        let px: Ff1 = p.x;
+        let py: Ff1 = p.y;
+
+        for row_i in 0..domain_size {
+            let kq: Fq = UniformRand::rand(rng);
+            let q: Pallas = gen.mul(kq).into();
+
+            let qx: Ff1 = q.x;
+            let qy: Ff1 = q.y;
+
+            let (rx, ry) = ec_add_circuit(&mut witness_env, px, py, qx, qy);
+
+            let r: Pallas = ark_ec::models::short_weierstrass::Affine::new_unchecked(rx, ry);
+
+            assert!(
+                r == p + q,
+                "fec addition circuit does not compute actual p + q, expected {} got {r:?}",
+                p + q
+            );
+
+            if row_i < domain_size - 1 {
+                witness_env.next_row();
+            }
+        }
+
+        witness_env
+    }
+
+    #[test]
+    /// Builds the FEC addition circuit with random values. The witness
+    /// environment enforces the constraints internally, so it is
+    /// enough to just build the circuit to ensure it is satisfied.
+    pub fn test_fec_addition_circuit() {
+        let mut rng = o1_utils::tests::make_test_rng(None);
+        build_fec_addition_circuit(&mut rng, 1 << 4);
+    }
+
+    #[test]
+    pub fn test_regression_relation_constraints_fec() {
+        let mut constraint_env = ConstraintBuilderEnv::<Fp, LookupTable<Ff1>>::create();
+        constrain_ec_addition::<Fp, Ff1, _>(&mut constraint_env);
+        let constraints = constraint_env.get_relation_constraints();
+
+        let mut constraints_degrees = HashMap::new();
+
+        assert_eq!(constraints.len(), 36);
+
+        {
+            constraints.iter().for_each(|c| {
+                let degree = c.degree(1, 0);
+                *constraints_degrees.entry(degree).or_insert(0) += 1;
+            });
+
+            assert_eq!(constraints_degrees.get(&1), None);
+            assert_eq!(constraints_degrees.get(&2), Some(&36));
+            assert_eq!(constraints_degrees.get(&3), None);
+
+            assert!(constraints.iter().map(|c| c.degree(1, 0)).max() <= Some(3));
+        }
+    }
+
+    #[test]
+    pub fn test_regression_constraints_fec() {
+        let mut constraint_env = ConstraintBuilderEnv::<Fp, LookupTable<Ff1>>::create();
+        constrain_ec_addition::<Fp, Ff1, _>(&mut constraint_env);
+        let constraints = constraint_env.get_constraints();
+
+        let mut constraints_degrees = HashMap::new();
+
+        assert_eq!(constraints.len(), 75);
+
+        {
+            constraints.iter().for_each(|c| {
+                let degree = c.degree(1, 0);
+                *constraints_degrees.entry(degree).or_insert(0) += 1;
+            });
+
+            assert_eq!(constraints_degrees.get(&1), Some(&1));
+            assert_eq!(constraints_degrees.get(&2), Some(&38));
+            assert_eq!(constraints_degrees.get(&3), None);
+            assert_eq!(constraints_degrees.get(&4), None);
+            assert_eq!(constraints_degrees.get(&5), Some(&2));
+            assert_eq!(constraints_degrees.get(&6), None);
+            assert_eq!(constraints_degrees.get(&7), Some(&34));
+        }
+    }
+
+    #[test]
+    pub fn heavy_test_fec_completeness() {
+        let mut rng = o1_utils::tests::make_test_rng(None);
+        let domain_size = 1 << 15; // Otherwise we can't do 15-bit lookups.
+
+        let mut constraint_env = ConstraintBuilderEnv::<Fp, LookupTable<Ff1>>::create();
+        constrain_ec_addition::<Fp, Ff1, _>(&mut constraint_env);
+        let constraints = constraint_env.get_constraints();
+
+        let witness_env = build_fec_addition_circuit(&mut rng, domain_size);
+
+        // Fixed tables can be generated inside lookup_tables_data. Runtime should be generated here.
+        let mut lookup_tables_data = BTreeMap::new();
+        for table_id in LookupTable::<Ff1>::all_variants().into_iter() {
+            lookup_tables_data.insert(
+                table_id,
+                vec![table_id
+                    .entries(domain_size as u64)
+                    .into_iter()
+                    .map(|x| vec![x])
+                    .collect()],
+            );
+        }
+        let proof_inputs = witness_env.get_proof_inputs(domain_size, lookup_tables_data);
+
+        crate::test::test_completeness_generic::<
+            FEC_N_COLUMNS,
+            FEC_N_COLUMNS,
+            0,
+            0,
+            LookupTable<Ff1>,
+            _,
+        >(
+            constraints,
+            Box::new([]),
+            proof_inputs,
+            domain_size,
+            &mut rng,
+        );
+    }
+}
diff --git a/msm/src/ffa/columns.rs b/msm/src/ffa/columns.rs
new file mode 100644
index 0000000000..9d06b45ef6
--- /dev/null
+++ b/msm/src/ffa/columns.rs
@@ -0,0 +1,42 @@
+use crate::columns::{Column, ColumnIndexer};
+
+use crate::N_LIMBS;
+
+/// Number of columns in the FFA circuits.
+pub const FFA_N_COLUMNS: usize = 5 * N_LIMBS;
+pub const FFA_NPUB_COLUMNS: usize = N_LIMBS;
+
+/// Column indexer for FFA columns.
+///
+/// They represent the equation
+/// `InputA(i) + InputB(i) = ModulusF(i) * Quotient + Remainder(i) + Carry(i) * 2^LIMB_SIZE - Carry(i-1)`
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum FFAColumn {
+    InputA(usize),
+    InputB(usize),
+    ModulusF(usize),
+    Remainder(usize),
+    Carry(usize),
+    Quotient,
+}
+
+impl ColumnIndexer for FFAColumn {
+    const N_COL: usize = FFA_N_COLUMNS;
+    fn to_column(self) -> Column {
+        let to_column_inner = |offset, i| {
+            assert!(i < N_LIMBS);
+            Column::Relation(N_LIMBS * offset + i)
+        };
+        match self {
+            FFAColumn::InputA(i) => to_column_inner(0, i),
+            FFAColumn::InputB(i) => to_column_inner(1, i),
+            FFAColumn::ModulusF(i) => to_column_inner(2, i),
+            FFAColumn::Remainder(i) => to_column_inner(3, i),
+            FFAColumn::Carry(i) => {
+                assert!(i < N_LIMBS - 1);
+                to_column_inner(4, i)
+            }
+            FFAColumn::Quotient => to_column_inner(4, N_LIMBS - 1),
+        }
+    }
+}
diff --git a/msm/src/ffa/constraint.rs b/msm/src/ffa/constraint.rs
new file mode 100644
index 0000000000..66fab6ff74
--- /dev/null
+++ b/msm/src/ffa/constraint.rs
@@ -0,0 +1,63 @@
+use crate::{
+    columns::{Column, ColumnIndexer},
+    expr::E,
+    ffa::{columns::FFAColumn, interpreter::FFAInterpreterEnv},
+};
+use ark_ff::PrimeField;
+use kimchi::circuits::{
+    expr::{ConstantExpr, ConstantTerm, Expr, ExprInner, Variable},
+    gate::CurrOrNext,
+};
+
+/// Contains constraints for just one row.
+pub struct ConstraintBuilderEnv<F> {
+    pub constraints: Vec<E<F>>,
+}
+
+impl<F: PrimeField> FFAInterpreterEnv<F> for ConstraintBuilderEnv<F> {
+    type Position = Column;
+
+    type Variable = E<F>;
+
+    fn empty() -> Self {
+        ConstraintBuilderEnv {
+            constraints: vec![],
+        }
+    }
+
+    fn assert_zero(&mut self, cst: Self::Variable) {
+        self.constraints.push(cst)
+    }
+
+    fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable {
+        let y = Expr::Atom(ExprInner::Cell(Variable {
+            col: position,
+            row: CurrOrNext::Curr,
+        }));
+        self.constraints.push(y.clone() - x.clone());
+        y
+    }
+
+    fn constant(value: F) -> Self::Variable {
+        let cst_expr_inner = ConstantExpr::from(ConstantTerm::Literal(value));
+        Expr::Atom(ExprInner::Constant(cst_expr_inner))
+    }
+
+    // TODO deduplicate, remove this
+    fn column_pos(ix: FFAColumn) -> Self::Position {
+        ix.to_column()
+    }
+
+    fn read_column(&self, ix: FFAColumn) -> Self::Variable {
+        Expr::Atom(ExprInner::Cell(Variable {
+            col: ix.to_column(),
+            row: CurrOrNext::Curr,
+        }))
+    }
+
+    fn range_check_abs1(&mut self, _value: &Self::Variable) {}
+
+    fn range_check_15bit(&mut self, _value: &Self::Variable) {
+        // FIXME unimplemented
+    }
+}
diff --git a/msm/src/ffa/interpreter.rs b/msm/src/ffa/interpreter.rs
new file mode 100644
index 0000000000..f7da6355c2
--- /dev/null
+++ b/msm/src/ffa/interpreter.rs
@@ -0,0 +1,136 @@
+use crate::{
+    circuit_design::{ColAccessCap, ColWriteCap, LookupCap},
+    ffa::{columns::FFAColumn, lookups::LookupTable},
+    serialization::interpreter::{limb_decompose_biguint, limb_decompose_ff},
+    LIMB_BITSIZE, N_LIMBS,
+};
+use ark_ff::PrimeField;
+use num_bigint::BigUint;
+use num_integer::Integer;
+use o1_utils::field_helpers::FieldHelpers;
+
+// For now this function does not /compute/ anything, although it could.
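Before the interpreter functions, a quick illustration: the per-limb addition equation documented on `constrain_ff_addition_row` just below can be sanity-checked with plain integers. A standalone toy sketch (illustrative values, chosen so that one carry is non-zero; not part of the diff):

```rust
// Toy check of a_i + b_i - q*f_i - r_i - c_i*B + c_{i-1} = 0 with B = 2^15,
// a 2-limb toy modulus, and q, c_i ∈ {-1, 0, 1}.
const B: i64 = 1 << 15;

fn limbs(x: i64) -> [i64; 2] {
    [x % B, x / B]
}

fn main() {
    let f: i64 = 999_999_937; // toy modulus, fits in 2 limbs (below 2^30)
    let (a, b) = (32_767i64, 32_767i64); // one-limb values; a + b < f
    let (q, r) = ((a + b) / f, (a + b) % f); // q ∈ {0, 1} for addition
    let (al, bl, fl, rl) = (limbs(a), limbs(b), limbs(f), limbs(r));
    let mut carry = 0i64;
    for i in 0..2 {
        let res = al[i] + bl[i] - q * fl[i] - rl[i] + carry;
        assert_eq!(res % B, 0); // the limb equation balances...
        carry = res / B; // ...up to the new carry c_i (here c_0 = 1)
        assert!((-1..=1).contains(&carry));
    }
    assert_eq!(carry, 0); // the last carry must vanish
}
```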
+/// Constraint for one row of FF addition:
+///
+/// - First:        a_0 + b_0 - q * f_0 - r_0 - c_0 * 2^{15} = 0
+/// - Intermediate: a_i + b_i - q * f_i - r_i - c_i * 2^{15} + c_{i-1} = 0
+/// - Last (n=16):  a_n + b_n - q * f_n - r_n + c_{n-1} = 0
+///
+/// q, c_i ∈ {-1,0,1}
+/// a_i, b_i, f_i, r_i ∈ [0,2^15)
+pub fn constrain_ff_addition_row<
+    F: PrimeField,
+    Env: ColAccessCap<F, FFAColumn> + LookupCap<F, FFAColumn, LookupTable>,
+>(
+    env: &mut Env,
+    limb_num: usize,
+) {
+    let a: Env::Variable = Env::read_column(env, FFAColumn::InputA(limb_num));
+    let b: Env::Variable = Env::read_column(env, FFAColumn::InputB(limb_num));
+    let f: Env::Variable = Env::read_column(env, FFAColumn::ModulusF(limb_num));
+    let r: Env::Variable = Env::read_column(env, FFAColumn::Remainder(limb_num));
+    let q: Env::Variable = Env::read_column(env, FFAColumn::Quotient);
+    env.lookup(LookupTable::RangeCheck15, vec![a.clone()]);
+    env.lookup(LookupTable::RangeCheck15, vec![b.clone()]);
+    env.lookup(LookupTable::RangeCheck15, vec![f.clone()]);
+    env.lookup(LookupTable::RangeCheck15, vec![r.clone()]);
+    env.lookup(LookupTable::RangeCheck1BitSigned, vec![q.clone()]);
+    let constraint = if limb_num == 0 {
+        let limb_size = Env::constant(From::from((1 << LIMB_BITSIZE) as u64));
+        let c0: Env::Variable = Env::read_column(env, FFAColumn::Carry(limb_num));
+        env.lookup(LookupTable::RangeCheck1BitSigned, vec![c0.clone()]);
+        a + b - q * f - r - c0 * limb_size
+    } else if limb_num < N_LIMBS - 1 {
+        let limb_size = Env::constant(From::from((1 << LIMB_BITSIZE) as u64));
+        let c_prev: Env::Variable = Env::read_column(env, FFAColumn::Carry(limb_num - 1));
+        let c_cur: Env::Variable = Env::read_column(env, FFAColumn::Carry(limb_num));
+        env.lookup(LookupTable::RangeCheck1BitSigned, vec![c_prev.clone()]);
+        env.lookup(LookupTable::RangeCheck1BitSigned, vec![c_cur.clone()]);
+        a + b - q * f - r - c_cur * limb_size + c_prev
+    } else {
+        let c_prev: Env::Variable = Env::read_column(env, FFAColumn::Carry(limb_num - 1));
+        env.lookup(LookupTable::RangeCheck1BitSigned, vec![c_prev.clone()]);
+        a + b - q * f - r + c_prev
+    };
+    env.assert_zero(constraint);
+}
+
+pub fn constrain_ff_addition<
+    F: PrimeField,
+    Env: ColAccessCap<F, FFAColumn> + LookupCap<F, FFAColumn, LookupTable>,
+>(
+    env: &mut Env,
+) {
+    for limb_i in 0..N_LIMBS {
+        constrain_ff_addition_row(env, limb_i);
+    }
+}
+
+pub fn ff_addition_circuit<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, FFAColumn>
+        + ColWriteCap<F, FFAColumn>
+        + LookupCap<F, FFAColumn, LookupTable>,
+>(
+    env: &mut Env,
+    a: Ff,
+    b: Ff,
+) {
+    let f_bigint: BigUint = TryFrom::try_from(Ff::MODULUS).unwrap();
+
+    let a_limbs: [F; N_LIMBS] = limb_decompose_ff::<F, Ff, LIMB_BITSIZE, N_LIMBS>(&a);
+    let b_limbs: [F; N_LIMBS] = limb_decompose_ff::<F, Ff, LIMB_BITSIZE, N_LIMBS>(&b);
+    let f_limbs: [F; N_LIMBS] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE, N_LIMBS>(f_bigint.clone());
+    a_limbs.iter().enumerate().for_each(|(i, var)| {
+        env.write_column(FFAColumn::InputA(i), &Env::constant(*var));
+    });
+    b_limbs.iter().enumerate().for_each(|(i, var)| {
+        env.write_column(FFAColumn::InputB(i), &Env::constant(*var));
+    });
+    f_limbs.iter().enumerate().for_each(|(i, var)| {
+        env.write_column(FFAColumn::ModulusF(i), &Env::constant(*var));
+    });
+
+    let a_bigint = FieldHelpers::to_biguint(&a);
+    let b_bigint = FieldHelpers::to_biguint(&b);
+
+    // TODO FIXME this computation must be done over BigInts, not BigUInts
+    // q can be -1! But only in subtraction, so for now we don't care.
+    // for now with addition only q ∈ {0,1}
+    let (q_bigint, r_bigint) = (a_bigint + b_bigint).div_rem(&f_bigint);
+    let r_limbs: [F; N_LIMBS] = limb_decompose_biguint::<F, LIMB_BITSIZE, N_LIMBS>(r_bigint);
+    // We expect just one limb.
+    let q: F = limb_decompose_biguint::<F, LIMB_BITSIZE, N_LIMBS>(q_bigint)[0];
+
+    env.write_column(FFAColumn::Quotient, &Env::constant(q));
+    r_limbs.iter().enumerate().for_each(|(i, var)| {
+        env.write_column(FFAColumn::Remainder(i), &Env::constant(*var));
+    });
+
+    let limb_size: F = From::from((1 << LIMB_BITSIZE) as u64);
+    let mut carry: F = From::from(0u64);
+    for limb_i in 0..N_LIMBS {
+        let res = a_limbs[limb_i] + b_limbs[limb_i] - q * f_limbs[limb_i] - r_limbs[limb_i] + carry;
+        let newcarry: F = if res == limb_size {
+            // Overflow
+            F::one()
+        } else if res == -limb_size {
+            // Underflow
+            F::zero() - F::one()
+        } else if res.is_zero() {
+            // Neither overflow nor underflow, the transcendent way of being
+            F::zero()
+        } else {
+            panic!("Computed carry is not -1,0,1, impossible: limb number {limb_i:?}")
+        };
+        // Last carry should be zero, otherwise we record it
+        if limb_i < N_LIMBS - 1 {
+            env.write_column(FFAColumn::Carry(limb_i), &Env::constant(newcarry));
+            carry = newcarry;
+        } else {
+            // Should this be checked in-circuit?
+            assert!(newcarry.is_zero());
+        }
+        constrain_ff_addition_row(env, limb_i);
+    }
+}
diff --git a/msm/src/ffa/lookups.rs b/msm/src/ffa/lookups.rs
new file mode 100644
index 0000000000..9c75a20251
--- /dev/null
+++ b/msm/src/ffa/lookups.rs
@@ -0,0 +1,94 @@
+use crate::logup::LookupTableID;
+use ark_ff::PrimeField;
+use num_bigint::BigUint;
+use o1_utils::FieldHelpers;
+use strum_macros::EnumIter;
+
+/// Enumeration of concrete lookup tables used in the FFA circuit.
+#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd, EnumIter)]
+pub enum LookupTable {
+    /// x ∈ [0, 2^15)
+    RangeCheck15,
+    /// x ∈ {-1, 0, 1}
+    RangeCheck1BitSigned,
+}
+
+impl LookupTableID for LookupTable {
+    fn to_u32(&self) -> u32 {
+        match self {
+            Self::RangeCheck15 => 1,
+            Self::RangeCheck1BitSigned => 2,
+        }
+    }
+
+    fn from_u32(value: u32) -> Self {
+        match value {
+            1 => Self::RangeCheck15,
+            2 => Self::RangeCheck1BitSigned,
+            _ => panic!("Invalid lookup table id"),
+        }
+    }
+
+    /// All tables are fixed tables.
+    fn is_fixed(&self) -> bool {
+        true
+    }
+
+    fn runtime_create_column(&self) -> bool {
+        panic!("No runtime tables specified");
+    }
+
+    fn length(&self) -> usize {
+        match self {
+            Self::RangeCheck15 => 1 << 15,
+            Self::RangeCheck1BitSigned => 3,
+        }
+    }
+
+    /// Converts a value to its index in the fixed table.
+    fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize> {
+        let value = value[0];
+        Some(match self {
+            Self::RangeCheck15 => TryFrom::try_from(value.to_biguint()).unwrap(),
+            Self::RangeCheck1BitSigned => {
+                if value == F::zero() {
+                    0
+                } else if value == F::one() {
+                    1
+                } else if value == F::zero() - F::one() {
+                    2
+                } else {
+                    panic!("Invalid value for RangeCheck1BitSigned")
+                }
+            }
+        })
+    }
+
+    fn all_variants() -> Vec<Self> {
+        vec![Self::RangeCheck15, Self::RangeCheck1BitSigned]
+    }
+}
+
+impl LookupTable {
+    /// Provides a full list of entries for the given table.
+    pub fn entries<F: PrimeField>(&self, domain_d1_size: u64) -> Vec<F> {
+        assert!(domain_d1_size >= (1 << 15));
+        match self {
+            Self::RangeCheck1BitSigned => [F::zero(), F::one(), F::zero() - F::one()]
+                .into_iter()
+                .chain((3..domain_d1_size).map(|_| F::one())) // dummies are 1s
+                .collect(),
+            Self::RangeCheck15 => (0..domain_d1_size).map(|i| F::from(i)).collect(),
+        }
+    }
+
+    /// Checks if a value is in a given table.
+    pub fn is_member<F: PrimeField>(&self, value: F) -> bool {
+        match self {
+            Self::RangeCheck1BitSigned => {
+                value == F::zero() || value == F::one() || value == F::zero() - F::one()
+            }
+            Self::RangeCheck15 => value.to_biguint() < BigUint::from(2u128.pow(15)),
+        }
+    }
+}
diff --git a/msm/src/ffa/mod.rs b/msm/src/ffa/mod.rs
new file mode 100644
index 0000000000..811e1e7dcf
--- /dev/null
+++ b/msm/src/ffa/mod.rs
@@ -0,0 +1,103 @@
+pub mod columns;
+pub mod interpreter;
+pub mod lookups;
+
+#[cfg(test)]
+mod tests {
+
+    use crate::{
+        circuit_design::{ConstraintBuilderEnv, WitnessBuilderEnv},
+        columns::ColumnIndexer,
+        ffa::{
+            columns::FFAColumn,
+            interpreter::{self as ffa_interpreter},
+            lookups::LookupTable,
+        },
+        logup::LookupTableID,
+        Ff1, Fp,
+    };
+    use ark_ff::UniformRand;
+    use rand::{CryptoRng, RngCore};
+    use std::collections::BTreeMap;
+
+    type FFAWitnessBuilderEnv = WitnessBuilderEnv<
+        Fp,
+        FFAColumn,
+        { <FFAColumn as ColumnIndexer>::N_COL },
+        { <FFAColumn as ColumnIndexer>::N_COL },
+        0,
+        0,
+        LookupTable,
+    >;
+
+    /// Builds the FF addition circuit with random values. The witness
+    /// environment enforces the constraints internally, so it is
+    /// enough to just build the circuit to ensure it is satisfied.
+    fn build_ffa_circuit<RNG: RngCore + CryptoRng>(
+        rng: &mut RNG,
+        domain_size: usize,
+    ) -> FFAWitnessBuilderEnv {
+        let mut witness_env = FFAWitnessBuilderEnv::create();
+
+        for _row_i in 0..domain_size {
+            let a: Ff1 = <Ff1 as UniformRand>::rand(rng);
+            let b: Ff1 = <Ff1 as UniformRand>::rand(rng);
+
+            //use rand::Rng;
+            //let a: Ff1 = From::from(rng.gen_range(0..(1 << 50)));
+            //let b: Ff1 = From::from(rng.gen_range(0..(1 << 50)));
+            ffa_interpreter::ff_addition_circuit(&mut witness_env, a, b);
+            witness_env.next_row();
+        }
+
+        witness_env
+    }
+
+    #[test]
+    /// Tests if the FFA circuit is valid.
+    pub fn test_ffa_circuit() {
+        let mut rng = o1_utils::tests::make_test_rng(None);
+        build_ffa_circuit(&mut rng, 1 << 4);
+    }
+
+    #[test]
+    pub fn heavy_test_ffa_completeness() {
+        let mut rng = o1_utils::tests::make_test_rng(None);
+        let domain_size = 1 << 15; // Otherwise we can't do 15-bit lookups.
+
+        let mut constraint_env = ConstraintBuilderEnv::<Fp, LookupTable>::create();
+        ffa_interpreter::constrain_ff_addition(&mut constraint_env);
+        let constraints = constraint_env.get_constraints();
+
+        let witness_env = build_ffa_circuit(&mut rng, domain_size);
+
+        // Fixed tables can be generated inside lookup_tables_data. Runtime should be generated here.
+        let mut lookup_tables_data = BTreeMap::new();
+        for table_id in LookupTable::all_variants().into_iter() {
+            lookup_tables_data.insert(
+                table_id,
+                vec![table_id
+                    .entries(domain_size as u64)
+                    .into_iter()
+                    .map(|x| vec![x])
+                    .collect()],
+            );
+        }
+        let proof_inputs = witness_env.get_proof_inputs(domain_size, lookup_tables_data);
+
+        crate::test::test_completeness_generic::<
+            { <FFAColumn as ColumnIndexer>::N_COL },
+            { <FFAColumn as ColumnIndexer>::N_COL },
+            0,
+            0,
+            LookupTable,
+            _,
+        >(
+            constraints,
+            Box::new([]),
+            proof_inputs,
+            domain_size,
+            &mut rng,
+        );
+    }
+}
diff --git a/msm/src/ffa/witness.rs b/msm/src/ffa/witness.rs
new file mode 100644
index 0000000000..009b773080
--- /dev/null
+++ b/msm/src/ffa/witness.rs
@@ -0,0 +1,113 @@
+use ark_ff::{PrimeField, Zero};
+
+use crate::{
+    columns::{Column, ColumnIndexer},
+    ffa::{
+        columns::{FFAColumn, FFA_N_COLUMNS},
+        interpreter::FFAInterpreterEnv,
+    },
+    lookups::LookupTableIDs,
+    proof::ProofInputs,
+    witness::Witness,
+    Fp,
+};
+
+#[allow(dead_code)]
+/// Builder environment for a native field `F`.
+pub struct WitnessBuilderEnv<F: PrimeField> {
+    /// Aggregated witness, in raw form. For accessing [`Witness`], see the
+    /// `get_witness` method.
+    witness: Vec<Witness<FFA_N_COLUMNS, F>>,
+}
+
+impl<F: PrimeField> FFAInterpreterEnv<F> for WitnessBuilderEnv<F> {
+    type Position = Column;
+
+    type Variable = F;
+
+    fn empty() -> Self {
+        WitnessBuilderEnv {
+            witness: vec![Witness {
+                cols: Box::new([Zero::zero(); FFA_N_COLUMNS]),
+            }],
+        }
+    }
+
+    fn assert_zero(&mut self, cst: Self::Variable) {
+        assert_eq!(cst, F::zero());
+    }
+
+    fn constant(value: F) -> Self::Variable {
+        value
+    }
+
+    fn copy(&mut self, value: &Self::Variable, position: Self::Position) -> Self::Variable {
+        let Column::Relation(i) = position else { todo!() };
+        self.witness.last_mut().unwrap().cols[i] = *value;
+        *value
+    }
+
+    // TODO deduplicate, remove this
+    fn column_pos(ix: FFAColumn) -> Self::Position {
+        ix.to_column()
+    }
+
+    fn read_column(&self, ix: FFAColumn) -> Self::Variable {
+        let Column::Relation(i) = Self::column_pos(ix) else {
+            todo!()
+        };
+        self.witness.last().unwrap().cols[i]
+    }
+
+    fn range_check_abs1(&mut self, _value: &Self::Variable) {
+        // FIXME unimplemented
+    }
+
+    fn range_check_15bit(&mut self, _value: &Self::Variable) {
+        // FIXME unimplemented
+    }
+}
+
+impl WitnessBuilderEnv<Fp> {
+    /// Each witness row stands for both one row and a multirow. This
+    /// function converts from a vector of one-row instantiations to a
+    /// single multi-row form (which is a `Witness`).
+    pub fn get_witness(
+        &self,
+        domain_size: usize,
+    ) -> ProofInputs<FFA_N_COLUMNS, Fp, LookupTableIDs> {
+        let mut cols: [Vec<Fp>; FFA_N_COLUMNS] = std::array::from_fn(|_| vec![]);
+
+        if self.witness.len() > domain_size {
+            panic!("Too many witness rows added");
+        }
+
+        // Filling actually used rows
+        for w in &self.witness {
+            let Witness { cols: witness_row } = w;
+            for i in 0..FFA_N_COLUMNS {
+                cols[i].push(witness_row[i]);
+            }
+        }
+
+        // Filling the remaining rows up to the domain size
+        for _ in self.witness.len()..domain_size {
+            for col in cols.iter_mut() {
+                col.push(Zero::zero());
+            }
+        }
+
+        ProofInputs {
+            evaluations: Witness {
+                cols: Box::new(cols),
+            },
+            logups: vec![],
+        }
+    }
+
+    pub fn next_row(&mut self) {
+        self.witness.push(Witness {
+            cols: Box::new([Zero::zero(); FFA_N_COLUMNS]),
+        });
+    }
+}
diff --git a/msm/src/lib.rs b/msm/src/lib.rs
new file mode 100644
index 0000000000..c87295ad1a
--- /dev/null
+++ b/msm/src/lib.rs
@@ -0,0 +1,61 @@
+use mina_poseidon::{
+    constants::PlonkSpongeConstantsKimchi,
+    sponge::{DefaultFqSponge, DefaultFrSponge},
+};
+use poly_commitment::kzg::KZGProof;
+
+pub use logup::{
+    Logup, LogupWitness, LookupProof as LogupProof, LookupTable as LogupTable,
+    LookupTableID as LogupTableID, LookupTableID,
+};
+
+pub mod circuit_design;
+pub mod column_env;
+pub mod columns;
+pub mod expr;
+pub mod logup;
+/// Instantiations of Logups for the MSM project
+// REMOVEME. The different interpreters must define their own tables.
+pub mod lookups;
+pub mod precomputed_srs;
+pub mod proof;
+pub mod prover;
+pub mod verifier;
+pub mod witness;
+
+pub mod fec;
+pub mod ffa;
+pub mod serialization;
+pub mod test;
+
+/// Define the maximum degree we support for the evaluations.
+/// For instance, it can be used to split the looked-up functions into partial
+/// sums.
+const MAX_SUPPORTED_DEGREE: usize = 8;
+
+/// Domain size for the MSM project, equal to the BN254 SRS size.
+pub const DOMAIN_SIZE: usize = 1 << 15;
+
+// @volhovm: maybe move these to the FF circuits module later.
+/// Bitsize of the foreign field limb representation.
+pub const LIMB_BITSIZE: usize = 15;
+
+/// Number of limbs representing one foreign field element (either
+/// [`Ff1`] or [`Ff2`]).
+pub const N_LIMBS: usize = 17;
+
+pub type BN254 = ark_ec::bn::Bn<ark_bn254::Config>;
+pub type BN254G1Affine = <BN254 as ark_ec::pairing::Pairing>::G1Affine;
+pub type BN254G2Affine = <BN254 as ark_ec::pairing::Pairing>::G2Affine;
+
+/// The native field we are working with.
+pub type Fp = ark_bn254::Fr;
+
+/// The foreign field we are emulating (one of the two)
+pub type Ff1 = mina_curves::pasta::Fp;
+pub type Ff2 = mina_curves::pasta::Fq;
+
+pub type SpongeParams = PlonkSpongeConstantsKimchi;
+pub type BaseSponge = DefaultFqSponge<ark_bn254::g1::Config, SpongeParams>;
+pub type ScalarSponge = DefaultFrSponge<Fp, SpongeParams>;
+pub type OpeningProof = KZGProof<BN254>;
diff --git a/msm/src/logup.rs b/msm/src/logup.rs
new file mode 100644
index 0000000000..19ba65a191
--- /dev/null
+++ b/msm/src/logup.rs
@@ -0,0 +1,958 @@
+#![allow(clippy::type_complexity)]
+
+//! Implement a variant of the logarithmic derivative lookups based on the
+//! equations described in the paper ["Multivariate lookups based on logarithmic
+//! derivatives"](https://eprint.iacr.org/2022/1530.pdf).
+//!
+//! The variant is mostly based on the observation that the polynomial
+//! identities can be verified using the "Idealised low-degree protocols"
+//! described in the section 4 of the
+//! ["PlonK"](https://eprint.iacr.org/2019/953.pdf) paper and "the quotient
+//! polynomial" described in the round 3 of the PlonK protocol, instead of using
+//! the sumcheck protocol.
+//!
+//! The protocol is based on the following observations:
+//!
+//! The sequence (a_i) is included in (b_i) if and only if the following
+//! equation holds:
+//! ```text
+//!    k       1         l      m_i
+//!    ∑    -------  =   ∑    -------      (1)
+//!   i=1   β + a_i     i=1   β + b_i
+//! ```
+//! where m_i is the number of times b_i appears in the sequence (a_j).
+//!
+//! The sequence (b_i) will refer to the table values and the sequence (a_i) to
+//! the values the prover looks up.
+//!
+//! For readability, the table values are represented as the evaluations over a
+//! subgroup H of the field F of a
+//! polynomial t(X), and the looked-up values by the evaluations of a polynomial
+//! f(X). If we suppose the subgroup H is defined as {1, ω, ω^2, ..., ω^{n-1}},
+//! the equation (1) becomes:
+//!
+//! ```text
+//!    n        1            n      m(ω^i)
+//!    ∑    ----------  =    ∑    ----------      (2)
+//!   i=1   β + f(ω^i)      i=1   β + t(ω^i)
+//! ```
+//!
+//! In the codebase, the multiplicities m_i are called the "lookup counters".
+//!
+//! The protocol can be generalized to multiple "looked-up" polynomials f_1,
+//! ..., f_k (embedded in the structure `LogupWitness` in the codebase) and the
+//! equation (2) becomes:
+//!
+//! ```text
+//!    n   k        1            n     m(ω^i)
+//!    ∑   ∑   ------------  =   ∑   -----------      (3)
+//!   i=1 j=1  β + f_j(ω^i)     i=1  β + t(ω^i)
+//! ```
+//!
+//! which can be rewritten as:
+//! ```text
+//!    n  (  k        1            m(ω^i)    )
+//!    ∑  (  ∑   ------------  -  ----------- )  =  0      (4)
+//!   i=1 ( j=1  β + f_j(ω^i)     β + t(ω^i)  )
+//!       \                                  /
+//!        -----------------------------------
+//!             "inner sums", h(ω^i)
+//! ```
+//!
+//! The equation says that if we sum/accumulate the "inner sums" (called the
+//! "lookup terms" in the codebase) over the
+//! subgroup H, we will get a zero value. Note the analogy with the
+//! "multiplicative" accumulator used in the lookup argument called
+//! ["Plookup"](https://eprint.iacr.org/2020/315.pdf).
+//!
+//! We will define an accumulator ϕ : H -> F (called the "lookup aggregation" in
+//! the codebase) which will contain the "running inner sums": it starts at
+//! zero, and when we have finished accumulating, it must equal zero again.
+//! (The initial and final values could in fact be any value, as long as they
+//! are equal; we pick zero.) The idea of the equation (4) is that all the
+//! values have been
+//! read and written to the accumulator the right number of times, with respect
+//! to the multiplicities m.
+//! More precisely, we will have:
+//! ```text
+//! - φ(1) = 0
+//!                                           h(ω^j)
+//!                            /----------------------------------\
+//!                            (  k       1             m(ω^j)     )
+//! - φ(ω^{j + 1}) = φ(ω^j) +  (  ∑  ------------  -  -----------  )
+//!                            ( i=1 β + f_i(ω^j)     β + t(ω^j)   )
+//!
+//! - φ(ω^n) = φ(1) = 0
+//! ```
+//!
+//! We will split the inner sums into chunks of size (MAX_SUPPORTED_DEGREE - 2)
+//! to avoid having a too large degree for the quotient polynomial.
+//! As a reminder, the paper ["Multivariate lookups based on logarithmic
+//! derivatives"](https://eprint.iacr.org/2022/1530.pdf) uses the sumcheck
+//! protocol to compute the partial sums (equations 16 and 17). However, we use
+//! the PlonK polynomial IOP and therefore, we will use the quotient polynomial,
+//! and the computation of the partial sums will be translated into a constraint
+//! in a new power of alpha.
+//!
+//! Note that the inner sum h(X) can be constrained as follows:
+//! ```text
+//!          k                      k  /            k                 \
+//!  h(X) *  ᴨ (β + f_{i}(X))  =    ∑  | m_{i}(X) * ᴨ  (β + f_{j}(X))  |     (5)
+//!         i=0                    i=0 |           j=0                |
+//!                                    \           j≠i               /
+//! ```
+//! (with m_i(X) being the multiplicities for `i = 0` and `-1` otherwise, and
+//! f_0(X) being the table t(X)).
+//! More than one "inner sum" can be created in the case that `k + 2` is higher
+//! than the maximum degree supported.
+//! The quotient polynomial, defined at round 3 of the [PlonK
+//! protocol](https://eprint.iacr.org/2019/953.pdf), will be something like:
+//!
+//! ```text
+//!         ... + α^i [φ(ω X) - φ(X) - h(X)] + α^(i + 1) (5) + ...
+//! t(X) = ------------------------------------------------------
+//!                             Z_H(X)
+//! ```
+//!
+//! `k` can then be seen as the number of lookups we can make per row. The
+//! additional cost when we reach the maximum degree supported is to add a new
+//! constraint and add a new column.
+//! For rows with fewer than `k` lookups, the prover will add a dummy value,
+//! which will be a value known to be in the table, and the multiplicity must be
+//! increased appropriately.
+//!
+//! To handle more than one table, we will use a table ID and transform the
+//! single-value lookup into a vector lookup, using a random combiner.
+//! The protocol can also handle vector lookups, by using the same random
+//! combiner.
+//! The looked-up values therefore become functions f_j: H x H x ... x H -> F
+//! and each f_j is transformed into an f'_j: H -> F using a random combiner `r`.
+//!
+//! To summarize, the prover will:
+//! - commit to the multiplicities m.
+//! - commit to individual looked-up values f (which include the table t) which
+//!   should be already included in the PlonK protocol as columns.
+//! - coin an evaluation point β.
+//! - coin a random combiner `r` (used to aggregate the table ID and concatenate
+//!   vector lookups, if any).
+//! - commit to the inner sums/lookup terms h.
+//! - commit to the running sum φ.
+//! - add constraints to the quotient polynomial.
+//! - evaluate all polynomials at the evaluation points ζ and ζω (because we
+//!   access the "next" row for the accumulator in the quotient polynomial).
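As a concrete illustration of equation (1), here is a tiny standalone check over F_101 (all values illustrative; `inv` is a brute-force field inverse, not an API of this crate):

```rust
// The lookups (a_i) hit table entry 3 twice and entry 5 once, so the
// multiplicities are m = [0, 2, 1], and Σ 1/(β+a_i) = Σ m_i/(β+b_i).
const P: u64 = 101;

fn inv(x: u64) -> u64 {
    (1..P).find(|y| x * y % P == 1).unwrap()
}

fn main() {
    let table = [2u64, 3, 5]; // (b_i)
    let lookups = [3u64, 3, 5]; // (a_i)
    let mults = [0u64, 2, 1]; // m_i
    let beta = 7u64;
    let lhs = lookups.iter().map(|a| inv(beta + a)).sum::<u64>() % P;
    let rhs = table
        .iter()
        .zip(mults)
        .map(|(b, m)| m * inv(beta + b) % P)
        .sum::<u64>()
        % P;
    assert_eq!(lhs, rhs); // both sides equal 39 in F_101
}
```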
+use ark_ff::{Field, PrimeField, Zero};
+use kimchi::circuits::{
+    berkeley_columns::BerkeleyChallengeTerm,
+    expr::{ConstantExpr, ConstantTerm, Expr, ExprInner},
+};
+use std::{collections::BTreeMap, hash::Hash};
+
+use crate::{
+    columns::Column,
+    expr::{curr_cell, next_cell, E},
+    MAX_SUPPORTED_DEGREE,
+};
+
+/// Generic structure to represent a (vector) lookup into the table with ID
+/// `table_id`.
+///
+/// The structure represents the individual fraction of the sum described in
+/// the Logup protocol (for instance Eq. 8).
+///
+/// The table ID is added to the random linear combination formed with the
+/// values. The combiner for the random linear combination is coined during the
+/// proving phase by the prover.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Logup<F, ID: LookupTableID> {
+    pub(crate) table_id: ID,
+    pub(crate) numerator: F,
+    pub(crate) value: Vec<F>,
+}
+
+/// Basic trait for logarithmic lookups.
+impl<F, ID> Logup<F, ID>
+where
+    F: Clone,
+    ID: LookupTableID,
+{
+    /// Creates a new Logup
+    pub fn new(table_id: ID, numerator: F, value: &[F]) -> Self {
+        Self {
+            table_id,
+            numerator,
+            value: value.to_vec(),
+        }
+    }
+}
+
+/// Trait for lookup table variants
+pub trait LookupTableID:
+    Send + Sync + Copy + Hash + Eq + PartialEq + Ord + PartialOrd + core::fmt::Debug
+{
+    /// Assign a unique ID, as a u32 value
+    fn to_u32(&self) -> u32;
+
+    /// Build a value from a u32
+    fn from_u32(value: u32) -> Self;
+
+    /// Assign a unique ID to the lookup tables.
+    fn to_field<F: Field>(&self) -> F {
+        F::from(self.to_u32())
+    }
+
+    /// Identify fixed and RAMLookups with a boolean.
+    /// This can be used to identify the lookups whose table values are fixed,
+    /// like range checks.
+    fn is_fixed(&self) -> bool;
+
+    /// If a table is a runtime table, `true` means we should create an
+    /// explicit extra column for it to "read" from. `false` means
+    /// that this table will be reading from some existing (e.g.
+    /// relation) columns, and no extra columns should be added.
+    ///
+    /// Panics if the argument is a fixed table.
+    fn runtime_create_column(&self) -> bool;
+
+    /// Assign a unique ID to the lookup tables, as an expression.
+    fn to_constraint<F: Field>(&self) -> E<F> {
+        let f = self.to_field();
+        let f = ConstantExpr::from(ConstantTerm::Literal(f));
+        E::Atom(ExprInner::Constant(f))
+    }
+
+    /// Returns the length of each table.
+    fn length(&self) -> usize;
+
+    /// Returns None if the table is runtime (and thus the mapping value
+    /// -> index is not known at compile time).
+    fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize>;
+
+    fn all_variants() -> Vec<Self>;
+}
+
+/// A table of values that can be used for a lookup, along with the ID for the table.
+#[derive(Debug, Clone)]
+pub struct LookupTable<ID: LookupTableID, F> {
+    /// Table ID corresponding to this table
+    pub table_id: ID,
+    /// Vector of values inside each entry of the table
+    pub entries: Vec<Vec<F>>,
+}
+
+/// Represents a witness of one instance of the lookup argument
+// IMPROVEME: Possible to index by a generic const?
+// The parameter N is the number of functions/looked-up values per row. It is
+// used by the PlonK polynomial IOP to compute the number of partial sums.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct LogupWitness<F, ID: LookupTableID> {
+    /// A list of functions/looked-up values.
+    /// Invariant: for fixed lookup tables, the last value of the vector is the
+    /// lookup table t. The lookup table values must have a negative sign.
+    /// The values are represented as:
+    /// [ [f_{1}(1), ..., f_{1}(ω^{n-1})],
+    ///   [f_{2}(1), ..., f_{2}(ω^{n-1})],
+    ///   ...
+    ///   [f_{m}(1), ..., f_{m}(ω^{n-1})]
+    /// ]
+    //
+    // TODO: for efficiency, as we go through columns and after that row, we
+    // should reorganize this. While working on the interpreter, we might
+    // change this structure.
+    //
+    // TODO: for efficiency, we might want to have a single flat fixed-size
+    // array
+    pub f: Vec<Vec<Logup<F, ID>>>,
+    /// The multiplicity polynomials; by convention, this is a vector
+    /// of columns, corresponding to the `tail` of `f`. That is,
+    /// `m[last] ~ f[last]`.
+    pub m: Vec<Vec<F>>,
+}
+
+/// Represents the proof of the lookup argument
+/// It is parametrized by the type `T` which can be either:
+/// - `PolyComm` for the commitments
+/// - `F` for the evaluations at ζ (resp. ζω).
+// FIXME: We should have a fixed number of m and h. Should we encode that in
+// the type?
+#[derive(Debug, Clone)]
+pub struct LookupProof<T, ID: LookupTableID> {
+    /// The multiplicity polynomials
+    pub(crate) m: BTreeMap<ID, Vec<T>>,
+    /// The polynomial keeping the sum of each row
+    pub(crate) h: BTreeMap<ID, Vec<T>>,
+    /// The "running-sum" over the rows, coined `φ`
+    pub(crate) sum: T,
+    /// All fixed lookup tables values, indexed by their ID
+    pub(crate) fixed_tables: BTreeMap<ID, T>,
+}
+
+/// Iterator implementation to abstract the content of the structure.
+/// It can be used to iterate over the commitments (resp. the evaluations)
+/// without requiring to have a look at the inner fields.
+impl<'lt, G, ID: LookupTableID> IntoIterator for &'lt LookupProof<G, ID> {
+    type Item = &'lt G;
+    type IntoIter = std::vec::IntoIter<&'lt G>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        let mut iter_contents = vec![];
+        // First multiplicities
+        self.m.values().for_each(|m| iter_contents.extend(m));
+        self.h.values().for_each(|h| iter_contents.extend(h));
+        iter_contents.push(&self.sum);
+        // Fixed tables
+        self.fixed_tables
+            .values()
+            .for_each(|t| iter_contents.push(t));
+        iter_contents.into_iter()
+    }
+}
+
+/// Compute the following constraint:
+/// ```text
+///                        lhs
+///    |------------------------------------------|
+///    |                denominators              |
+///    |             /------------------\         |
+///    column * ( \prod_{i = 0}^{N} (β + f_{i}(X)) )
+///      =
+///    \sum_{i = 0}^{N} m_{i} * \prod_{j = 0, j ≠ i}^{N} (β + f_{j}(X))
+///    |---------------------------------------------------------------|
+///                                 rhs
+/// ```
+/// It is because h(X) (column) is defined as:
+/// ```text
+///          n    m_i(X)        n       1            m_0(X)
+/// h(X) =   ∑  ----------  =   ∑  ------------  -  ---------
+///         i=0 β + f_i(X)     i=1 β + f_i(X)        β + t(X)
+/// ```
+/// The first form is generic, the second is concrete with f_0 = t; m_0 = m;
+/// m_i = 1 for i ≠ 0.
+/// We will be thinking in the generic form.
+///
+/// For instance, if N = 2, we have
+/// ```text
+/// h(X) = m_1(X) / (β + f_1(X)) + m_2(X) / (β + f_2(X))
+///
+///         m_1(X) * (β + f_2(X)) + m_2(X) * (β + f_1(X))
+///      = -----------------------------------------------
+///                 (β + f_2(X)) * (β + f_1(X))
+/// ```
+/// which is equivalent to
+/// ```text
+/// h(X) * (β + f_2(X)) * (β + f_1(X)) = m_1(X) * (β + f_2(X)) + m_2(X) * (β + f_1(X))
+/// ```
+/// When we have f_1(X) a looked-up value, t(X) a fixed table and m_2(X) being
+/// the multiplicities, we have
+/// ```text
+/// h(X) * (β + t(X)) * (β + f(X)) = (β + t(X)) + m(X) * (β + f(X))
+/// ```
+pub fn combine_lookups<F: PrimeField, ID: LookupTableID>(
+    column: Column,
+    lookups: Vec<Logup<E<F>, ID>>,
+) -> E<F> {
+    let joint_combiner = {
+        let joint_combiner = ConstantExpr::from(BerkeleyChallengeTerm::JointCombiner);
+        E::Atom(ExprInner::Constant(joint_combiner))
+    };
+    let beta = {
+        let beta = ConstantExpr::from(BerkeleyChallengeTerm::Beta);
+        E::Atom(ExprInner::Constant(beta))
+    };
+
+    // Compute (β + f_{i}(X)) for each i.
+    // Note that f_i(X) = table_id + r * x_{1} + r^2 x_{2} + ... + r^{N} x_{N}
+    let denominators = lookups
+        .iter()
+        .map(|x| {
+            // Compute r * x_{1} + r^2 x_{2} + ... + r^{N} x_{N}
+            let combined_value = x
+                .value
+                .iter()
+                .rev()
+                .fold(E::zero(), |acc, y| acc * joint_combiner.clone() + y.clone())
+                * joint_combiner.clone();
+            // FIXME: sanity check for the domain, we should consider it in prover.rs.
+            // We only support degree one constraints in the denominator.
+            let combined_degree_real = combined_value.degree(1, 0);
+            assert!(combined_degree_real <= 1, "Only degrees zero and one are supported in the denominator of the lookup because of the maximum degree supported (8); got degree {combined_degree_real}",);
+            // add table id + evaluation point
+            beta.clone() + combined_value + x.table_id.to_constraint()
+        })
+        .collect::<Vec<_>>();
+
+    // Compute `column * (\prod_{i = 0}^{N} (β + f_{i}(X)))`
+    let lhs = denominators
+        .iter()
+        .fold(curr_cell(column), |acc, x| acc * x.clone());
+
+    // Compute `\sum_{i = 0}^{N} m_{i} * \prod_{j = 0, j ≠ i}^{N} (β + f_{j}(X))`
+    let rhs = lookups
+        .into_iter()
+        .enumerate()
+        .map(|(i, x)| {
+            denominators.iter().enumerate().fold(
+                // Compute the individual \prod_{j = 0, j ≠ i}^{N} (β + f_{j}(X)).
+                // This is the inner part of rhs. It multiplies with m_{i}.
+                x.numerator,
+                |acc, (j, y)| {
+                    if i == j {
+                        acc
+                    } else {
+                        acc * y.clone()
+                    }
+                },
+            )
+        })
+        // Individual sums
+        .reduce(|x, y| x + y)
+        .unwrap_or(E::zero());
+
+    lhs - rhs
+}
+
+/// Build the constraints for the lookup protocol.
+/// The constraints are the partial sum and the aggregation of the partial sums.
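+///
+/// As an illustrative count (the numbers are chosen for this example only):
+/// with `MAX_SUPPORTED_DEGREE = 8`, each partial-sum column packs
+/// `MAX_SUPPORTED_DEGREE - 2 = 6` fractions. A fixed table with 14 reads plus
+/// its multiplicity/table term gives 15 fractions, hence ⌈15 / 6⌉ = 3
+/// partial-sum columns h_1, h_2, h_3, and the final aggregation constraint
+/// φ(ωX) - φ(X) - h_1(X) - h_2(X) - h_3(X) = 0.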
+pub fn constraint_lookups<F: PrimeField, ID: LookupTableID>(
+    lookup_reads: &BTreeMap<ID, Vec<Vec<E<F>>>>,
+    lookup_writes: &BTreeMap<ID, Vec<Vec<E<F>>>>,
+) -> Vec<E<F>> {
+    let mut constraints: Vec<E<F>> = vec![];
+    let mut lookup_terms_cols: Vec<Column> = vec![];
+    lookup_reads.iter().for_each(|(table_id, reads)| {
+        let mut idx_partial_sum = 0;
+        let table_id_u32 = table_id.to_u32();
+
+        // FIXME: do not clone
+        let mut lookups: Vec<_> = reads
+            .clone()
+            .into_iter()
+            .map(|value| Logup {
+                table_id: *table_id,
+                numerator: Expr::Atom(ExprInner::Constant(ConstantExpr::from(
+                    ConstantTerm::Literal(F::one()),
+                ))),
+                value,
+            })
+            .collect();
+
+        if table_id.is_fixed() || table_id.runtime_create_column() {
+            let table_lookup = Logup {
+                table_id: *table_id,
+                numerator: -curr_cell(Column::LookupMultiplicity((table_id_u32, 0))),
+                value: vec![curr_cell(Column::LookupFixedTable(table_id_u32))],
+            };
+            lookups.push(table_lookup);
+        } else {
+            lookup_writes
+                .get(table_id)
+                .unwrap_or_else(|| panic!("Lookup writes for table_id {table_id:?} not present"))
+                .iter()
+                .enumerate()
+                .for_each(|(i, write_columns)| {
+                    lookups.push(Logup {
+                        table_id: *table_id,
+                        numerator: -curr_cell(Column::LookupMultiplicity((table_id_u32, i))),
+                        value: write_columns.clone(),
+                    });
+                });
+        }
+
+        // We split in chunks of 6 (MAX_SUPPORTED_DEGREE - 2)
+        lookups.chunks(MAX_SUPPORTED_DEGREE - 2).for_each(|chunk| {
+            let col = Column::LookupPartialSum((table_id_u32, idx_partial_sum));
+            lookup_terms_cols.push(col);
+            constraints.push(combine_lookups(col, chunk.to_vec()));
+            idx_partial_sum += 1;
+        });
+    });
+
+    // Generic code over the partial sum
+    // Compute φ(ωX) - φ(X) - \sum_{i = 1}^{N} h_i(X)
+    {
+        let constraint =
+            next_cell(Column::LookupAggregation) - curr_cell(Column::LookupAggregation);
+        let constraint = lookup_terms_cols
+            .into_iter()
+            .fold(constraint, |acc, col| acc - curr_cell(col));
+        constraints.push(constraint);
+    }
+    constraints
+}
+
+pub mod prover {
+    use crate::{
+        logup::{Logup, LogupWitness, LookupTableID},
+        MAX_SUPPORTED_DEGREE,
+    };
+    use ark_ff::{FftField, Zero};
+    use ark_poly::{univariate::DensePolynomial, Evaluations, Radix2EvaluationDomain as D};
+    use kimchi::{circuits::domains::EvaluationDomains, curve::KimchiCurve};
+    use mina_poseidon::FqSponge;
+    use poly_commitment::{
+        commitment::{absorb_commitment, PolyComm},
+        OpenProof, SRS as _,
+    };
+    use rayon::iter::{IntoParallelIterator, ParallelIterator};
+    use std::collections::BTreeMap;
+
+    /// The structure used by the prover to compute the quotient polynomial.
+    /// The structure contains the evaluations of the inner sums, the
+    /// multiplicities, the aggregation and the fixed tables, over the domain d8.
+    pub struct QuotientPolynomialEnvironment<'a, F: FftField, ID: LookupTableID> {
+        /// The evaluations of the partial sums, over d8.
+        pub lookup_terms_evals_d8: &'a BTreeMap<ID, Vec<Evaluations<F, D<F>>>>,
+        /// The evaluations of the aggregation, over d8.
+        pub lookup_aggregation_evals_d8: &'a Evaluations<F, D<F>>,
+        /// The evaluations of the multiplicities, over d8, indexed by the table ID.
+        pub lookup_counters_evals_d8: &'a BTreeMap<ID, Vec<Evaluations<F, D<F>>>>,
+        /// The evaluations of the fixed tables, over d8, indexed by the table ID.
+        pub fixed_tables_evals_d8: &'a BTreeMap<ID, Evaluations<F, D<F>>>,
+    }
+
+    /// Represents the environment for the logup argument.
+    pub struct Env<G: KimchiCurve, ID: LookupTableID> {
+        /// The polynomial of the multiplicities, indexed by the table ID.
+        pub lookup_counters_poly_d1: BTreeMap<ID, Vec<DensePolynomial<G::ScalarField>>>,
+        /// The commitments to the multiplicities, indexed by the table ID.
+        pub lookup_counters_comm_d1: BTreeMap<ID, Vec<PolyComm<G>>>,
+
+        /// The polynomials of the inner sums.
+        pub lookup_terms_poly_d1: BTreeMap<ID, Vec<DensePolynomial<G::ScalarField>>>,
+        /// The commitments of the inner sums.
+        pub lookup_terms_comms_d1: BTreeMap<ID, Vec<PolyComm<G>>>,
+
+        /// The aggregation polynomial.
+        pub lookup_aggregation_poly_d1: DensePolynomial<G::ScalarField>,
+        /// The commitment to the aggregation polynomial.
+        pub lookup_aggregation_comm_d1: PolyComm<G>,
+
+        // Evaluating over d8 for the quotient polynomial
+        #[allow(clippy::type_complexity)]
+        pub lookup_counters_evals_d8:
+            BTreeMap<ID, Vec<Evaluations<G::ScalarField, D<G::ScalarField>>>>,
+        #[allow(clippy::type_complexity)]
+        pub lookup_terms_evals_d8:
+            BTreeMap<ID, Vec<Evaluations<G::ScalarField, D<G::ScalarField>>>>,
+        pub lookup_aggregation_evals_d8: Evaluations<G::ScalarField, D<G::ScalarField>>,
+
+        pub fixed_lookup_tables_poly_d1: BTreeMap<ID, DensePolynomial<G::ScalarField>>,
+        pub fixed_lookup_tables_comms_d1: BTreeMap<ID, PolyComm<G>>,
+        pub fixed_lookup_tables_evals_d8:
+            BTreeMap<ID, Evaluations<G::ScalarField, D<G::ScalarField>>>,
+
+        /// The combiner used for vector lookups
+        pub joint_combiner: G::ScalarField,
+
+        /// The evaluation point used for the lookup polynomials.
+        pub beta: G::ScalarField,
+    }
+
+    impl<G: KimchiCurve, ID: LookupTableID> Env<G, ID> {
+        /// Create an environment for the prover to build a proof for the Logup protocol.
+        /// The protocol assumes that the individual lookup terms are
+        /// committed as part of the columns.
+        /// Therefore, the protocol only focuses on committing to the "grand
+        /// product sum" and the "row-accumulated" values.
+        pub fn create<
+            OpeningProof: OpenProof<G>,
+            Sponge: FqSponge<G::BaseField, G, G::ScalarField>,
+        >(
+            lookups: BTreeMap<ID, LogupWitness<G::ScalarField, ID>>,
+            domain: EvaluationDomains<G::ScalarField>,
+            fq_sponge: &mut Sponge,
+            srs: &OpeningProof::SRS,
+        ) -> Self
+        where
+            OpeningProof::SRS: Sync,
+        {
+            // Use parallel iterators where possible.
+            // Polynomial m(X)
+            // FIXME/IMPROVEME: m(X) is only for fixed tables
+            let lookup_counters_evals_d1: BTreeMap<
+                ID,
+                Vec<Evaluations<G::ScalarField, D<G::ScalarField>>>,
+            > = {
+                lookups
+                    .clone()
+                    .into_par_iter()
+                    .map(|(table_id, logup_witness)| {
+                        (
+                            table_id,
+                            logup_witness.m.into_iter().map(|m|
+                                Evaluations::<G::ScalarField, D<G::ScalarField>>::from_vec_and_domain(
+                                    m.to_vec(),
+                                    domain.d1,
+                                )
+                            ).collect()
+                        )
+                    })
+                    .collect()
+            };
+
+            let lookup_counters_poly_d1: BTreeMap<ID, Vec<DensePolynomial<G::ScalarField>>> =
+                (&lookup_counters_evals_d1)
+                    .into_par_iter()
+                    .map(|(id, evals)| {
+                        (*id, evals.iter().map(|e| e.interpolate_by_ref()).collect())
+                    })
+                    .collect();
+
+            let lookup_counters_evals_d8: BTreeMap<
+                ID,
+                Vec<Evaluations<G::ScalarField, D<G::ScalarField>>>,
+            > = (&lookup_counters_poly_d1)
+                .into_par_iter()
+                .map(|(id, lookup)| {
+                    (
+                        *id,
+                        lookup
+                            .iter()
+                            .map(|l| l.evaluate_over_domain_by_ref(domain.d8))
+                            .collect(),
+                    )
+                })
+                .collect();
+
+            let lookup_counters_comm_d1: BTreeMap<ID, Vec<PolyComm<G>>> =
+                (&lookup_counters_evals_d1)
+                    .into_par_iter()
+                    .map(|(id, polys)| {
+                        (
+                            *id,
+                            polys
+                                .iter()
+                                .map(|poly| srs.commit_evaluations_non_hiding(domain.d1, poly))
+                                .collect(),
+                        )
+                    })
+                    .collect();
+
+            lookup_counters_comm_d1.values().for_each(|comms| {
+                comms
+                    .iter()
+                    .for_each(|comm| absorb_commitment(fq_sponge, comm))
+            });
+            // -- end of m(X)
+
+            // -- start computing the row sums h(X)
+            // It will be used to compute the running sum in lookup_aggregation.
+            // Coin a combiner to perform vector lookups.
+            // The row sums h are defined as
+            //              n           1                    1
+            //  h(ω^i) =    ∑   --------------------  -  --------------
+            //            j = 0   (β + f_{j}(ω^i))        (β + t(ω^i))
+            let vector_lookup_combiner = fq_sponge.challenge();
+
+            // Coin an evaluation point for the rational functions
+            let beta = fq_sponge.challenge();
+
+            //
+            // @volhovm TODO make sure this is h_i. It looks like f_i for fixed tables.
+            // It is an actual fixed column containing "fixed lookup data".
+            //
+            // Contains the evaluations of the h_i. 
We divide the looked-up values + // in chunks of (MAX_SUPPORTED_DEGREE - 2) + let mut fixed_lookup_tables: BTreeMap> = BTreeMap::new(); + + // @volhovm TODO These are h_i related chunks! + // + // We keep the lookup terms in a map, to process them in order in the constraints. + let mut lookup_terms_map: BTreeMap>> = BTreeMap::new(); + + // Where are commitments to the last element are made? First "read" columns we don't commit to, right?.. + + lookups.into_iter().for_each(|(table_id, logup_witness)| { + let LogupWitness { f, m: _ } = logup_witness; + // The number of functions to look up, including the fixed table. + let n = f.len(); + let n_partial_sums = if n % (MAX_SUPPORTED_DEGREE - 2) == 0 { + n / (MAX_SUPPORTED_DEGREE - 2) + } else { + n / (MAX_SUPPORTED_DEGREE - 2) + 1 + }; + + let mut partial_sums = + vec![ + Vec::::with_capacity(domain.d1.size as usize); + n_partial_sums + ]; + + // We compute first the denominators of all f_i and t. We gather them in + // a vector to perform a batch inversion. + let mut denominators = Vec::with_capacity(n * domain.d1.size as usize); + // Iterate over the rows + for j in 0..domain.d1.size { + // Iterate over individual columns (i.e. f_i and t) + for (i, f_i) in f.iter().enumerate() { + let Logup { + numerator: _, + table_id, + value, + } = &f_i[j as usize]; + // Compute x_{1} + r x_{2} + ... r^{N-1} x_{N} + // This is what we actually put into the `fixed_lookup_tables`. + let combined_value_pow0: G::ScalarField = + value.iter().rev().fold(G::ScalarField::zero(), |acc, y| { + acc * vector_lookup_combiner + y + }); + // Compute r * x_{1} + r^2 x_{2} + ... r^{N} x_{N} + let combined_value: G::ScalarField = + combined_value_pow0 * vector_lookup_combiner; + // add table id + let combined_value = combined_value + table_id.to_field::(); + + // If last element and fixed lookup tables, we keep + // the *combined* value of the table. + // + // Otherwise we're processing a runtime table + // with explicit writes, so we don't need to + // create any extra columns. + if (table_id.is_fixed() || table_id.runtime_create_column()) && i == (n - 1) + { + fixed_lookup_tables + .entry(*table_id) + .or_insert_with(Vec::new) + .push(combined_value_pow0); + } + + // β + a_{i} + let lookup_denominator = beta + combined_value; + denominators.push(lookup_denominator); + } + } + assert!(denominators.len() == n * domain.d1.size as usize); + + // Given a vector {β + a_i}, computes a vector {1/(β + a_i)} + ark_ff::fields::batch_inversion(&mut denominators); + + // Evals is the sum on the individual columns for each row + let mut denominator_index = 0; + + // We only need to add the numerator now + for j in 0..domain.d1.size { + let mut partial_sum_idx = 0; + let mut row_acc = G::ScalarField::zero(); + for (i, f_i) in f.iter().enumerate() { + let Logup { + numerator, + table_id: _, + value: _, + } = &f_i[j as usize]; + row_acc += *numerator * denominators[denominator_index]; + denominator_index += 1; + // We split in chunks of (MAX_SUPPORTED_DEGREE - 2) + // We reset the accumulator for the current partial + // sum after keeping it. + if (i + 1) % (MAX_SUPPORTED_DEGREE - 2) == 0 { + partial_sums[partial_sum_idx].push(row_acc); + row_acc = G::ScalarField::zero(); + partial_sum_idx += 1; + } + } + // Whatever leftover in `row_acc` left in the end of the iteration, we write it into + // `partial_sums` too. This is only done in case `n % (MAX_SUPPORTED_DEGREE - 2) != 0` + // which means that the similar addition to `partial_sums` a few lines above won't be triggered. 
+ // So we have this wrapping up call instead. + if n % (MAX_SUPPORTED_DEGREE - 2) != 0 { + partial_sums[partial_sum_idx].push(row_acc); + } + } + lookup_terms_map.insert(table_id, partial_sums); + }); + + // Sanity check to verify that the number of evaluations is correct + lookup_terms_map.values().for_each(|evals| { + evals + .iter() + .for_each(|eval| assert_eq!(eval.len(), domain.d1.size as usize)) + }); + + // Sanity check to verify that we have all the evaluations for the fixed lookup tables + fixed_lookup_tables + .values() + .for_each(|evals| assert_eq!(evals.len(), domain.d1.size as usize)); + + #[allow(clippy::type_complexity)] + let lookup_terms_evals_d1: BTreeMap< + ID, + Vec>>, + > = + (&lookup_terms_map) + .into_par_iter() + .map(|(id, lookup_terms)| { + let lookup_terms = lookup_terms.into_par_iter().map(|lookup_term| { + Evaluations::>::from_vec_and_domain( + lookup_term.to_vec(), domain.d1, + )}).collect::>(); + (*id, lookup_terms) + }) + .collect(); + + let fixed_lookup_tables_evals_d1: BTreeMap< + ID, + Evaluations>, + > = fixed_lookup_tables + .into_iter() + .map(|(id, evals)| { + ( + id, + Evaluations::>::from_vec_and_domain( + evals, domain.d1, + ), + ) + }) + .collect(); + + let lookup_terms_poly_d1: BTreeMap>> = + (&lookup_terms_evals_d1) + .into_par_iter() + .map(|(id, lookup_terms)| { + let lookup_terms: Vec> = lookup_terms + .into_par_iter() + .map(|evals| evals.interpolate_by_ref()) + .collect(); + (*id, lookup_terms) + }) + .collect(); + + let fixed_lookup_tables_poly_d1: BTreeMap> = + (&fixed_lookup_tables_evals_d1) + .into_par_iter() + .map(|(id, evals)| (*id, evals.interpolate_by_ref())) + .collect(); + + #[allow(clippy::type_complexity)] + let lookup_terms_evals_d8: BTreeMap< + ID, + Vec>>, + > = (&lookup_terms_poly_d1) + .into_par_iter() + .map(|(id, lookup_terms)| { + let lookup_terms: Vec>> = + lookup_terms + .into_par_iter() + .map(|lookup_term| lookup_term.evaluate_over_domain_by_ref(domain.d8)) + .collect(); + (*id, lookup_terms) + }) + .collect(); + + let fixed_lookup_tables_evals_d8: BTreeMap< + ID, + Evaluations>, + > = (&fixed_lookup_tables_poly_d1) + .into_par_iter() + .map(|(id, poly)| (*id, poly.evaluate_over_domain_by_ref(domain.d8))) + .collect(); + + let lookup_terms_comms_d1: BTreeMap>> = lookup_terms_evals_d1 + .iter() + .map(|(id, lookup_terms)| { + let lookup_terms = lookup_terms + .into_par_iter() + .map(|lookup_term| { + srs.commit_evaluations_non_hiding(domain.d1, lookup_term) + }) + .collect(); + (*id, lookup_terms) + }) + .collect(); + + let fixed_lookup_tables_comms_d1: BTreeMap> = + (&fixed_lookup_tables_evals_d1) + .into_par_iter() + .map(|(id, evals)| (*id, srs.commit_evaluations_non_hiding(domain.d1, evals))) + .collect(); + + lookup_terms_comms_d1.values().for_each(|comms| { + comms + .iter() + .for_each(|comm| absorb_commitment(fq_sponge, comm)) + }); + + fixed_lookup_tables_comms_d1 + .values() + .for_each(|comm| absorb_commitment(fq_sponge, comm)); + // -- end computing the row sums h + + // -- start computing the running sum in lookup_aggregation + // The running sum, φ, is defined recursively over the subgroup as followed: + // - φ(1) = 0 + // - φ(ω^{j + 1}) = φ(ω^j) + \ + // \sum_{i = 1}^{n} (1 / (β + f_i(ω^{j + 1}))) - \ + // (m(ω^{j + 1}) / (β + t(ω^{j + 1}))) + // - φ(ω^n) = 0 + let lookup_aggregation_evals_d1 = { + { + for table_id in ID::all_variants().into_iter() { + let mut acc = G::ScalarField::zero(); + for i in 0..domain.d1.size as usize { + // φ(1) = 0 + let lookup_terms = 
lookup_terms_evals_d1.get(&table_id).unwrap();
+                        acc = lookup_terms.iter().fold(acc, |acc, lte| acc + lte[i]);
+                    }
+                    // Sanity check to verify that the accumulator ends up being zero.
+                    assert_eq!(
+                        acc,
+                        G::ScalarField::zero(),
+                        "Logup accumulator for table {table_id:?} must be zero"
+                    );
+                }
+            }
+            let mut evals = Vec::with_capacity(domain.d1.size as usize);
+            {
+                let mut acc = G::ScalarField::zero();
+                for i in 0..domain.d1.size as usize {
+                    // φ(1) = 0
+                    evals.push(acc);
+                    lookup_terms_evals_d1.iter().for_each(|(_, lookup_terms)| {
+                        acc = lookup_terms.iter().fold(acc, |acc, lte| acc + lte[i]);
+                    })
+                }
+                // Sanity check to verify that the accumulator ends up being zero.
+                assert_eq!(
+                    acc,
+                    G::ScalarField::zero(),
+                    "Logup accumulator must be zero"
+                );
+            }
+            Evaluations::<G::ScalarField, D<G::ScalarField>>::from_vec_and_domain(
+                evals, domain.d1,
+            )
+        };
+
+            let lookup_aggregation_poly_d1 = lookup_aggregation_evals_d1.interpolate_by_ref();
+
+            let lookup_aggregation_evals_d8 =
+                lookup_aggregation_poly_d1.evaluate_over_domain_by_ref(domain.d8);
+
+            let lookup_aggregation_comm_d1 =
+                srs.commit_evaluations_non_hiding(domain.d1, &lookup_aggregation_evals_d1);
+
+            absorb_commitment(fq_sponge, &lookup_aggregation_comm_d1);
+            Self {
+                lookup_counters_poly_d1,
+                lookup_counters_comm_d1,
+
+                lookup_terms_poly_d1,
+                lookup_terms_comms_d1,
+
+                lookup_aggregation_poly_d1,
+                lookup_aggregation_comm_d1,
+
+                lookup_counters_evals_d8,
+                lookup_terms_evals_d8,
+                lookup_aggregation_evals_d8,
+
+                fixed_lookup_tables_poly_d1,
+                fixed_lookup_tables_comms_d1,
+                fixed_lookup_tables_evals_d8,
+
+                joint_combiner: vector_lookup_combiner,
+                beta,
+            }
+        }
+    }
+}
diff --git a/msm/src/lookups.rs b/msm/src/lookups.rs
new file mode 100644
index 0000000000..c488a2f3c1
--- /dev/null
+++ b/msm/src/lookups.rs
@@ -0,0 +1,199 @@
+//! Instantiate the Logup protocol for the MSM project.
+
+use crate::logup::{Logup, LogupWitness, LookupTableID};
+use ark_ff::{FftField, PrimeField};
+use kimchi::circuits::domains::EvaluationDomains;
+use rand::{seq::SliceRandom, thread_rng, Rng};
+use std::{cmp::Ord, iter};
+
+/// Dummy lookup table. For the cases when you don't need one -- a single dummy element 0.
+#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
+pub enum DummyLookupTable {
+    DummyLookupTable,
+}
+
+impl LookupTableID for DummyLookupTable {
+    fn to_u32(&self) -> u32 {
+        1
+    }
+
+    fn from_u32(id: u32) -> Self {
+        match id {
+            1 => DummyLookupTable::DummyLookupTable,
+            _ => panic!("Dummy lookup table has only index 1"),
+        }
+    }
+
+    fn length(&self) -> usize {
+        1
+    }
+
+    /// All tables are fixed tables.
+    fn is_fixed(&self) -> bool {
+        true
+    }
+
+    fn runtime_create_column(&self) -> bool {
+        panic!("No runtime tables specified");
+    }
+
+    fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize> {
+        if value[0] == F::zero() {
+            Some(0)
+        } else {
+            panic!("Invalid value for DummyLookupTable")
+        }
+    }
+
+    fn all_variants() -> Vec<Self> {
+        vec![DummyLookupTable::DummyLookupTable]
+    }
+}
+
+impl DummyLookupTable {
+    /// Provides a full list of entries for the given table.
+    pub fn entries<F: PrimeField>(&self, domain_d1_size: u64) -> Vec<F> {
+        // All zeroes
+        (0..domain_d1_size).map(|_| F::zero()).collect()
+    }
+}
+
+/// Lookup tables used in the MSM project
+// TODO: Add more built-in lookup tables
+#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
+pub enum LookupTableIDs {
+    RangeCheck16,
+    /// Custom lookup table
+    /// The index of the table is used as the ID, padded with the number of
+    /// built-in tables.
+    Custom(u32),
+}
+
+impl LookupTableID for LookupTableIDs {
+    fn to_u32(&self) -> u32 {
+        match self {
+            LookupTableIDs::RangeCheck16 => 1_u32,
+            LookupTableIDs::Custom(id) => id + 1,
+        }
+    }
+
+    fn from_u32(id: u32) -> Self {
+        match id {
+            1 => LookupTableIDs::RangeCheck16,
+            _ => LookupTableIDs::Custom(id - 1),
+        }
+    }
+
+    fn length(&self) -> usize {
+        match self {
+            LookupTableIDs::RangeCheck16 => 1 << 16,
+            LookupTableIDs::Custom(_) => todo!(),
+        }
+    }
+
+    /// All tables are fixed tables.
+    fn is_fixed(&self) -> bool {
+        true
+    }
+
+    fn runtime_create_column(&self) -> bool {
+        panic!("No runtime tables specified");
+    }
+
+    fn ix_by_value<F: PrimeField>(&self, _value: &[F]) -> Option<usize> {
+        todo!()
+    }
+
+    fn all_variants() -> Vec<Self> {
+        // TODO: in the future this must depend on some associated type
+        // that parameterises the lookup table.
+        vec![Self::RangeCheck16, Self::Custom(0)]
+    }
+}
+
+/// Additive lookups used in the MSM project, based on Logup
+pub type Lookup<F> = Logup<F, LookupTableIDs>;
+
+/// Represents a witness of one instance of the lookup argument of the MSM project
+pub type LookupWitness<F> = LogupWitness<F, LookupTableIDs>;
+
+// This should be used only for testing purposes.
+// It is not only in the test API because it is used at the moment in the
+// main.rs. It should be moved to the test API when main.rs is replaced with
+// real production code.
+impl<F: PrimeField> LookupWitness<F> {
+    /// Generates a random number of correct lookups in the table RangeCheck16
+    pub fn random(domain: EvaluationDomains<F>) -> (LookupTableIDs, Self) {
+        let mut rng = thread_rng();
+        // TODO: generate more random f
+        let table_size: u64 = rng.gen_range(1..domain.d1.size);
+        let table_id = rng.gen_range(1..1000);
+        // Build a table of values we can look up
+        let t: Vec<u64> = {
+            // Generate distinct values to avoid having to handle the
+            // normalized multiplicity polynomial
+            let mut n: Vec<u64> = (1..(table_size * 100)).collect();
+            n.shuffle(&mut rng);
+            n[0..table_size as usize].to_vec()
+        };
+        // permutation argument
+        let f = {
+            let mut f = t.clone();
+            f.shuffle(&mut rng);
+            f
+        };
+        let dummy_value = F::rand(&mut rng);
+        let repeated_dummy_value: Vec<F> = {
+            let r: Vec<F> = iter::repeat(dummy_value)
+                .take((domain.d1.size - table_size) as usize)
+                .collect();
+            r
+        };
+        let t_evals = {
+            let mut table = Vec::with_capacity(domain.d1.size as usize);
+            table.extend(t.iter().map(|v| Lookup {
+                table_id: LookupTableIDs::Custom(table_id),
+                numerator: -F::one(),
+                value: vec![F::from(*v)],
+            }));
+            table.extend(
+                repeated_dummy_value
+                    .iter()
+                    .map(|v| Lookup {
+                        table_id: LookupTableIDs::Custom(table_id),
+                        numerator: -F::one(),
+                        value: vec![*v],
+                    })
+                    .collect::<Vec<Lookup<F>>>(),
+            );
+            table
+        };
+        let f_evals: Vec<Lookup<F>> = {
+            let mut table = Vec::with_capacity(domain.d1.size as usize);
+            table.extend(f.iter().map(|v| Lookup {
+                table_id: LookupTableIDs::Custom(table_id),
+                numerator: F::one(),
+                value: vec![F::from(*v)],
+            }));
+            table.extend(
+                repeated_dummy_value
+                    .iter()
+                    .map(|v| Lookup {
+                        table_id: LookupTableIDs::Custom(table_id),
+                        numerator: F::one(),
+                        value: vec![*v],
+                    })
+                    .collect::<Vec<Lookup<F>>>(),
+            );
+            table
+        };
+        let m = (0..domain.d1.size).map(|_| F::one()).collect();
+        (
+            LookupTableIDs::Custom(table_id),
+            LookupWitness {
+                f: vec![f_evals, t_evals],
+                m: vec![m],
+            },
+        )
+    }
+}
diff --git a/msm/src/precomputed_srs.rs b/msm/src/precomputed_srs.rs
new file mode 100644
index 0000000000..f278914419
--- /dev/null
+++ b/msm/src/precomputed_srs.rs
@@ -0,0 +1,145 @@
+//! Clone of kimchi/precomputed_srs.rs but for MSM project with BN254
+
+use crate::{Fp, BN254, DOMAIN_SIZE};
+use ark_ec::pairing::Pairing;
+use ark_ff::UniformRand;
+use ark_serialize::Write;
+use kimchi::{circuits::domains::EvaluationDomains, precomputed_srs::TestSRS};
+use poly_commitment::{kzg::PairingSRS, SRS as _};
+use rand::{rngs::StdRng, SeedableRng};
+use serde::{Deserialize, Serialize};
+use std::{fs::File, io::BufReader, path::PathBuf};
+
+/// A clone of the `PairingSRS` that is serialized in a test-optimised way.
+#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
+pub struct TestPairingSRS<Pair: Pairing> {
+    pub full_srs: TestSRS<Pair::G1Affine>,
+    pub verifier_srs: TestSRS<Pair::G2Affine>,
+}
+
+impl<Pair: Pairing> From<PairingSRS<Pair>> for TestPairingSRS<Pair> {
+    fn from(value: PairingSRS<Pair>) -> Self {
+        TestPairingSRS {
+            full_srs: From::from(value.full_srs),
+            verifier_srs: From::from(value.verifier_srs),
+        }
+    }
+}
+
+impl<Pair: Pairing> From<TestPairingSRS<Pair>> for PairingSRS<Pair> {
+    fn from(value: TestPairingSRS<Pair>) -> Self {
+        PairingSRS {
+            full_srs: From::from(value.full_srs),
+            verifier_srs: From::from(value.verifier_srs),
+        }
+    }
+}
+
+/// Obtains an SRS for a specific curve from disk, or generates it if absent.
+pub fn get_bn254_srs(domain: EvaluationDomains<Fp>) -> PairingSRS<BN254> {
+    let srs = if domain.d1.size as usize == DOMAIN_SIZE {
+        read_bn254_srs_from_disk(get_bn254_srs_path())
+    } else {
+        PairingSRS::create(domain.d1.size as usize)
+    };
+    srs.full_srs.get_lagrange_basis(domain.d1); // not added if already present.
+    srs
+}
+
+/// The path of the serialized BN254 SRS, inside this repo.
+pub fn get_bn254_srs_path() -> PathBuf {
+    let base_path = env!("CARGO_MANIFEST_DIR");
+    if base_path.is_empty() {
+        println!("WARNING: BN254 precomputation: CARGO_MANIFEST_DIR is absent, can't determine working directory. It is sometimes absent in release mode.");
+    }
+    PathBuf::from(base_path).join("../srs/test_bn254.srs")
+}
+
+/// Tries to read the SRS from disk, otherwise panics. Returns the
+/// value without the Lagrange basis.
+fn read_bn254_srs_from_disk(srs_path: PathBuf) -> PairingSRS<BN254> {
+    let file =
+        File::open(srs_path.clone()).unwrap_or_else(|_| panic!("missing SRS file: {srs_path:?}"));
+    let reader = BufReader::new(file);
+    let srs: TestPairingSRS<BN254> = rmp_serde::from_read(reader).unwrap();
+    From::from(srs)
+}
+
+/// Creates a BN254 SRS. If the `force_overwrite` flag is on, or the
+/// `SRS_OVERWRITE` env variable is set, also writes it into the file.
+fn create_and_store_srs_with_path(
+    force_overwrite: bool,
+    domain_size: usize,
+    srs_path: PathBuf,
+) -> PairingSRS<BN254> {
+    // We generate with a fixed-seed RNG, only used for testing.
+    let mut rng = &mut StdRng::from_seed([42u8; 32]);
+    let trapdoor = Fp::rand(&mut rng);
+    let srs = PairingSRS::create_trusted_setup(trapdoor, domain_size);
+
+    for sub_domain_size in 1..=domain_size {
+        let domain = EvaluationDomains::<Fp>::create(sub_domain_size).unwrap();
+        srs.full_srs.get_lagrange_basis(domain.d1);
+    }
+
+    // overwrite the SRS if the flag or the env var is set
+    if force_overwrite || std::env::var("SRS_OVERWRITE").is_ok() {
+        // Create parent directories
+        std::fs::create_dir_all(srs_path.parent().unwrap()).unwrap();
+        // Open/create the file
+        let mut file = std::fs::OpenOptions::new()
+            .create(true)
+            .write(true)
+            .open(srs_path.clone())
+            .expect("failed to open SRS file");
+
+        let test_srs: TestPairingSRS<BN254> = From::from(srs.clone());
+        let srs_bytes = rmp_serde::to_vec(&test_srs).unwrap();
+        file.write_all(&srs_bytes).expect("failed to write file");
+        file.flush().expect("failed to flush file");
+    }
+
+    // get the SRS from disk
+    let srs_on_disk = read_bn254_srs_from_disk(srs_path);
+
+    // check that it matches what we just generated
+    assert_eq!(srs, srs_on_disk);
+
+    srs
+}
+
+/// Creates and writes the SRS into `get_bn254_srs_path()`.
+pub fn create_and_store_srs(force_overwrite: bool, domain_size: usize) -> PairingSRS<BN254> {
+    let srs_path = get_bn254_srs_path();
+    create_and_store_srs_with_path(force_overwrite, domain_size, srs_path)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    /// Creates or checks the real full-sized SRS. Only overwrites
+    /// when the env flag is set.
+    pub fn heavy_test_create_or_check_srs() {
+        let domain_size = DOMAIN_SIZE;
+        create_and_store_srs_with_path(false, domain_size, get_bn254_srs_path());
+    }
+
+    #[test]
+    /// Fast test for a small-sized SRS. Always writes & reads.
+    pub fn check_bn256_srs_serialization() {
+        let domain_size = 1 << 8;
+        let test_srs_path = PathBuf::from("/tmp/test_bn254.srs");
+        create_and_store_srs_with_path(true, domain_size, test_srs_path);
+    }
+
+    #[test]
+    /// Checks that `get_bn254_srs` does not fail. Can be used for
+    /// time-profiling.
+    pub fn check_get_bn254_srs() {
+        let domain_size = DOMAIN_SIZE;
+        let domain = EvaluationDomains::<Fp>::create(domain_size).unwrap();
+        get_bn254_srs(domain);
+    }
+}
diff --git a/msm/src/proof.rs b/msm/src/proof.rs
new file mode 100644
index 0000000000..4700450cf5
--- /dev/null
+++ b/msm/src/proof.rs
@@ -0,0 +1,175 @@
+use crate::{
+    logup::{LookupProof, LookupTableID},
+    lookups::{LookupTableIDs, LookupWitness},
+    witness::Witness,
+    LogupWitness, DOMAIN_SIZE,
+};
+use ark_ff::PrimeField;
+use kimchi::{
+    circuits::{
+        domains::EvaluationDomains,
+        expr::{ColumnEvaluations, ExprError},
+    },
+    curve::KimchiCurve,
+    proof::PointEvaluations,
+};
+use poly_commitment::{commitment::PolyComm, OpenProof};
+use rand::thread_rng;
+use std::collections::BTreeMap;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ProofInputs<const N_WIT: usize, F: PrimeField, ID: LookupTableID> {
+    /// Actual values w_i of the witness columns. "Evaluations" as in
+    /// evaluations of the polynomial P_w that interpolates w_i.
+    pub evaluations: Witness<N_WIT, Vec<F>>,
+    pub logups: BTreeMap<ID, LogupWitness<F, ID>>,
+}
+
+impl<const N_WIT: usize, F: PrimeField> ProofInputs<N_WIT, F, LookupTableIDs> {
+    // This should be used only for testing purposes.
+    // It is not only in the test API because it is used at the moment in the
+    // main.rs. It should be moved to the test API when main.rs is replaced with
+    // real production code.
+    pub fn random(domain: EvaluationDomains<F>) -> Self {
+        let mut rng = thread_rng();
+        let cols: Box<[Vec<F>; N_WIT]> = Box::new(std::array::from_fn(|_| {
+            (0..domain.d1.size as usize)
+                .map(|_| F::rand(&mut rng))
+                .collect::<Vec<F>>()
+        }));
+        let logups = {
+            let (table_id, item) = LookupWitness::<F>::random(domain);
+            BTreeMap::from([(table_id, item)])
+        };
+        ProofInputs {
+            evaluations: Witness { cols },
+            logups,
+        }
+    }
+}
+
+impl<const N_WIT: usize, F: PrimeField, ID: LookupTableID> Default for ProofInputs<N_WIT, F, ID> {
+    /// Creates a default proof instance. Note that such an empty "zero" instance will not satisfy any constraint.
+    /// E.g. some constraints that have constants inside of them (A - const = 0) cannot be satisfied by it.
+    fn default() -> Self {
+        ProofInputs {
+            evaluations: Witness {
+                cols: Box::new(std::array::from_fn(|_| {
+                    (0..DOMAIN_SIZE).map(|_| F::zero()).collect()
+                })),
+            },
+            logups: Default::default(),
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+// TODO: Should public input and fixed selectors evaluations be here?
+pub struct ProofEvaluations<
+    const N_WIT: usize,
+    const N_REL: usize,
+    const N_DSEL: usize,
+    const N_FSEL: usize,
+    F,
+    ID: LookupTableID,
+> {
+    /// Witness evaluations, including public inputs
+    pub(crate) witness_evals: Witness<N_WIT, PointEvaluations<F>>,
+    /// Evaluations of fixed selectors.
+    pub(crate) fixed_selectors_evals: Box<[PointEvaluations<F>; N_FSEL]>,
+    /// Logup argument evaluations
+    pub(crate) logup_evals: Option<LookupProof<PointEvaluations<F>, ID>>,
+    /// Evaluation of Z_H(ζ) · (t_0(X) + ζ^n t_1(X) + ...) at ζω.
+    pub(crate) ft_eval1: F,
+}
+
+/// The trait ColumnEvaluations is used by the verifier.
+/// It will return the evaluation of the corresponding column at the
+/// evaluation points coined by the verifier during the protocol.
+impl<
+        const N_WIT: usize,
+        const N_REL: usize,
+        const N_DSEL: usize,
+        const N_FSEL: usize,
+        F: Clone,
+        ID: LookupTableID,
+    > ColumnEvaluations<F> for ProofEvaluations<N_WIT, N_REL, N_DSEL, N_FSEL, F, ID>
+{
+    type Column = crate::columns::Column;
+
+    fn evaluate(&self, col: Self::Column) -> Result<PointEvaluations<F>, ExprError<Self::Column>> {
+        // TODO: substitute when non-literal generic constants are available
+        assert!(N_WIT == N_REL + N_DSEL);
+        let res = match col {
+            Self::Column::Relation(i) => {
+                assert!(i < N_REL, "Index out of bounds");
+                self.witness_evals[i].clone()
+            }
+            Self::Column::DynamicSelector(i) => {
+                assert!(i < N_DSEL, "Index out of bounds");
+                self.witness_evals[N_REL + i].clone()
+            }
+            Self::Column::FixedSelector(i) => {
+                assert!(i < N_FSEL, "Index out of bounds");
+                self.fixed_selectors_evals[i].clone()
+            }
+            Self::Column::LookupPartialSum((table_id, idx)) => {
+                if let Some(ref lookup) = self.logup_evals {
+                    lookup.h[&ID::from_u32(table_id)][idx].clone()
+                } else {
+                    panic!("No lookup provided")
+                }
+            }
+            Self::Column::LookupAggregation => {
+                if let Some(ref lookup) = self.logup_evals {
+                    lookup.sum.clone()
+                } else {
+                    panic!("No lookup provided")
+                }
+            }
+            Self::Column::LookupMultiplicity((table_id, idx)) => {
+                if let Some(ref lookup) = self.logup_evals {
+                    lookup.m[&ID::from_u32(table_id)][idx].clone()
+                } else {
+                    panic!("No lookup provided")
+                }
+            }
+            Self::Column::LookupFixedTable(table_id) => {
+                if let Some(ref lookup) = self.logup_evals {
+                    lookup.fixed_tables[&ID::from_u32(table_id)].clone()
+                } else {
+                    panic!("No lookup provided")
+                }
+            }
+        };
+        Ok(res)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct ProofCommitments<const N_WIT: usize, G: KimchiCurve, ID: LookupTableID> {
+    /// Commitments to the N columns of the circuits, also called the 'witnesses'.
+    /// If some columns are considered as public inputs, they are counted in the witness.
+    pub(crate) witness_comms: Witness<N_WIT, PolyComm<G>>,
+    /// Commitments to the polynomials used by the lookup argument, coined "logup".
+    /// The values contain the chunked polynomials.
+    pub(crate) logup_comms: Option<LookupProof<PolyComm<G>, ID>>,
+    /// Commitments to the quotient polynomial.
+    /// The value contains the chunked polynomials.
+    pub(crate) t_comm: PolyComm<G>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Proof<
+    const N_WIT: usize,
+    const N_REL: usize,
+    const N_DSEL: usize,
+    const N_FSEL: usize,
+    G: KimchiCurve,
+    OpeningProof: OpenProof<G>,
+    ID: LookupTableID,
+> {
+    pub(crate) proof_comms: ProofCommitments<N_WIT, G, ID>,
+    pub(crate) proof_evals: ProofEvaluations<N_WIT, N_REL, N_DSEL, N_FSEL, G::ScalarField, ID>,
+    pub(crate) opening_proof: OpeningProof,
+}
diff --git a/msm/src/prover.rs b/msm/src/prover.rs
new file mode 100644
index 0000000000..2ea2534da2
--- /dev/null
+++ b/msm/src/prover.rs
@@ -0,0 +1,587 @@
+#![allow(clippy::type_complexity)]
+#![allow(clippy::boxed_local)]
+
+use crate::{
+    column_env::ColumnEnvironment,
+    expr::E,
+    logup,
+    logup::{prover::Env, LookupProof, LookupTableID},
+    proof::{Proof, ProofCommitments, ProofEvaluations, ProofInputs},
+    witness::Witness,
+    MAX_SUPPORTED_DEGREE,
+};
+use ark_ff::{Field, One, Zero};
+use ark_poly::{
+    univariate::DensePolynomial, EvaluationDomain, Evaluations, Polynomial,
+    Radix2EvaluationDomain as R2D,
+};
+use kimchi::{
+    circuits::{
+        berkeley_columns::BerkeleyChallenges,
+        domains::EvaluationDomains,
+        expr::{l0_1, Constants, Expr},
+    },
+    curve::KimchiCurve,
+    groupmap::GroupMap,
+    plonk_sponge::FrSponge,
+    proof::PointEvaluations,
+};
+use mina_poseidon::{sponge::ScalarChallenge, FqSponge};
+use o1_utils::ExtendedDensePolynomial;
+use poly_commitment::{
+    commitment::{absorb_commitment, PolyComm},
+    utils::DensePolynomialOrEvaluations,
+    OpenProof, SRS,
+};
+use rand::{CryptoRng, RngCore};
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+use thiserror::Error;
+
+/// Errors that can arise when creating a proof
+#[derive(Error, Debug, Clone)]
+pub enum ProverError {
+    #[error("the proof could not be constructed: {0}")]
+    Generic(&'static str),
+
+    #[error("the provided (witness) constraints were not satisfied: {0}")]
+    ConstraintNotSatisfied(String),
+
+    #[error("the provided (witness) constraint has degree {0} > allowed {1}; expr: {2}")]
+    ConstraintDegreeTooHigh(u64, u64, String),
+}
+
+pub fn prove<
+    G: KimchiCurve,
+    OpeningProof: OpenProof<G>,
+    EFqSponge: Clone + FqSponge<G::BaseField, G, G::ScalarField>,
+    EFrSponge: FrSponge<G::ScalarField>,
+    RNG,
+    const N_WIT: usize,
+    const N_REL: usize,
+    const N_DSEL: usize,
+    const N_FSEL: usize,
+    ID: LookupTableID,
+>(
+    domain: EvaluationDomains<G::ScalarField>,
+    srs: &OpeningProof::SRS,
+    constraints: &Vec<E<G::ScalarField>>,
+    fixed_selectors: Box<[Vec<G::ScalarField>; N_FSEL]>,
+    inputs: ProofInputs<N_WIT, G::ScalarField, ID>,
+    rng: &mut RNG,
+) -> Result<Proof<N_WIT, N_REL, N_DSEL, N_FSEL, G, OpeningProof, ID>, ProverError>
+where
+    OpeningProof::SRS: Sync,
+    RNG: RngCore + CryptoRng,
+{
+    ////////////////////////////////////////////////////////////////////////////
+    // Setting up the protocol
+    ////////////////////////////////////////////////////////////////////////////
+
+    let group_map = G::Map::setup();
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Round 1: Creating and absorbing column commitments
+    ////////////////////////////////////////////////////////////////////////////
+
+    let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params());
+
+    let fixed_selectors_evals_d1: Box<[Evaluations<G::ScalarField, R2D<G::ScalarField>>; N_FSEL]> =
+        o1_utils::array::vec_to_boxed_array(
+            fixed_selectors
+                .into_par_iter()
+                .map(|evals| Evaluations::from_vec_and_domain(evals, domain.d1))
+                .collect(),
+        );
+
+    let
fixed_selectors_polys: Box<[DensePolynomial; N_FSEL]> = + o1_utils::array::vec_to_boxed_array( + fixed_selectors_evals_d1 + .into_par_iter() + .map(|evals| evals.interpolate()) + .collect(), + ); + + let fixed_selectors_comms: Box<[PolyComm; N_FSEL]> = { + let comm = |poly: &DensePolynomial| srs.commit_non_hiding(poly, 1); + o1_utils::array::vec_to_boxed_array( + fixed_selectors_polys + .as_ref() + .into_par_iter() + .map(comm) + .collect(), + ) + }; + + // Do not use parallelism + (fixed_selectors_comms) + .into_iter() + .for_each(|comm| absorb_commitment(&mut fq_sponge, &comm)); + + // Interpolate all columns on d1, using trait Into. + let witness_evals_d1: Witness>> = inputs + .evaluations + .into_par_iter() + .map(|evals| { + Evaluations::>::from_vec_and_domain( + evals, domain.d1, + ) + }) + .collect::>>>(); + + let witness_polys: Witness> = { + let interpolate = + |evals: Evaluations>| evals.interpolate(); + witness_evals_d1 + .into_par_iter() + .map(interpolate) + .collect::>>() + }; + + let witness_comms: Witness> = { + let blinders = PolyComm { + chunks: vec![G::ScalarField::one()], + }; + let comm = { + |poly: &DensePolynomial| { + // In case the column polynomial is all zeroes, we want to mask the commitment + let comm = srs.commit_custom(poly, 1, &blinders).unwrap(); + comm.commitment + } + }; + (&witness_polys) + .into_par_iter() + .map(comm) + .collect::>>() + }; + + // Do not use parallelism + (&witness_comms) + .into_iter() + .for_each(|comm| absorb_commitment(&mut fq_sponge, comm)); + + // -- Start Logup + let lookup_env = if !inputs.logups.is_empty() { + Some(Env::create::( + inputs.logups, + domain, + &mut fq_sponge, + srs, + )) + } else { + None + }; + + let max_degree = { + if lookup_env.is_none() { + constraints + .iter() + .map(|expr| expr.degree(1, 0)) + .max() + .unwrap_or(0) + } else { + 8 + } + }; + + // Don't need to be absorbed. Already absorbed in logup::prover::Env::create + // FIXME: remove clone + let logup_comms = Option::map(lookup_env.as_ref(), |lookup_env| LookupProof { + m: lookup_env.lookup_counters_comm_d1.clone(), + h: lookup_env.lookup_terms_comms_d1.clone(), + sum: lookup_env.lookup_aggregation_comm_d1.clone(), + fixed_tables: lookup_env.fixed_lookup_tables_comms_d1.clone(), + }); + + // -- end computing the running sum in lookup_aggregation + // -- End of Logup + + let domain_eval = if max_degree <= 4 { + domain.d4 + } else if max_degree as usize <= MAX_SUPPORTED_DEGREE { + domain.d8 + } else { + panic!("We do support constraints up to {:?}", MAX_SUPPORTED_DEGREE) + }; + + let witness_evals: Witness>> = { + (&witness_polys) + .into_par_iter() + .map(|evals| evals.evaluate_over_domain_by_ref(domain_eval)) + .collect::>>>() + }; + + let fixed_selectors_evals: Box<[Evaluations>; N_FSEL]> = { + o1_utils::array::vec_to_boxed_array( + (fixed_selectors_polys.as_ref()) + .into_par_iter() + .map(|evals| evals.evaluate_over_domain_by_ref(domain_eval)) + .collect(), + ) + }; + + //////////////////////////////////////////////////////////////////////////// + // Round 2: Creating and committing to the quotient polynomial + //////////////////////////////////////////////////////////////////////////// + + let (_, endo_r) = G::endos(); + + // Sample α with the Fq-Sponge. 
+ let alpha: G::ScalarField = fq_sponge.challenge(); + + let zk_rows = 0; + let column_env: ColumnEnvironment<'_, N_WIT, N_REL, N_DSEL, N_FSEL, _, _> = { + let challenges = BerkeleyChallenges { + alpha, + // NB: as there is no permutation argument, we do use the beta + // field instead of a new one for the evaluation point. + beta: Option::map(lookup_env.as_ref(), |x| x.beta).unwrap_or(G::ScalarField::zero()), + gamma: G::ScalarField::zero(), + joint_combiner: Option::map(lookup_env.as_ref(), |x| x.joint_combiner) + .unwrap_or(G::ScalarField::zero()), + }; + ColumnEnvironment { + constants: Constants { + endo_coefficient: *endo_r, + mds: &G::sponge_params().mds, + zk_rows, + }, + challenges, + witness: &witness_evals, + fixed_selectors: &fixed_selectors_evals, + l0_1: l0_1(domain.d1), + lookup: Option::map(lookup_env.as_ref(), |lookup_env| { + logup::prover::QuotientPolynomialEnvironment { + lookup_terms_evals_d8: &lookup_env.lookup_terms_evals_d8, + lookup_aggregation_evals_d8: &lookup_env.lookup_aggregation_evals_d8, + lookup_counters_evals_d8: &lookup_env.lookup_counters_evals_d8, + fixed_tables_evals_d8: &lookup_env.fixed_lookup_tables_evals_d8, + } + }), + domain, + } + }; + + let quotient_poly: DensePolynomial = { + let mut last_constraint_failed = None; + // Only for debugging purposes + for expr in constraints.iter() { + let fail_q_division = + ProverError::ConstraintNotSatisfied(format!("Unsatisfied expression: {:}", expr)); + // Check this expression are witness satisfied + let (_, res) = expr + .evaluations(&column_env) + .interpolate_by_ref() + .divide_by_vanishing_poly(domain.d1) + .ok_or(fail_q_division.clone())?; + if !res.is_zero() { + eprintln!("Unsatisfied expression: {}", expr); + //return Err(fail_q_division); + last_constraint_failed = Some(expr.clone()); + } + } + if let Some(expr) = last_constraint_failed { + return Err(ProverError::ConstraintNotSatisfied(format!( + "Unsatisfied expression: {:}", + expr + ))); + } + + // Compute ∑ α^i constraint_i as an expression + let combined_expr = + Expr::combine_constraints(0..(constraints.len() as u32), constraints.clone()); + + // We want to compute the quotient polynomial, i.e. + // t(X) = (∑ α^i constraint_i(X)) / Z_H(X). + // The sum of the expressions is called the "constraint polynomial". + // We will use the evaluations points of the individual witness and + // lookup columns. + // Note that as the constraints might be of higher degree than N, the + // size of the set H we want the constraints to be verified on, we must + // have more than N evaluations points for each columns. This is handled + // in the ColumnEnvironment structure. + // Reminder: to compute P(X) = P_{1}(X) * P_{2}(X), from the evaluations + // of P_{1} and P_{2}, with deg(P_{1}) = deg(P_{2}(X)) = N, we must have + // 2N evaluation points to compute P as deg(P(X)) <= 2N. + let expr_evaluation: Evaluations> = + combined_expr.evaluations(&column_env); + + // And we interpolate using the evaluations + let expr_evaluation_interpolated = expr_evaluation.interpolate(); + + let fail_final_q_division = || { + panic!("Division by vanishing poly must not fail at this point, we checked it before") + }; + // We compute the polynomial t(X) by dividing the constraints polynomial + // by the vanishing polynomial, i.e. Z_H(X). 
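+        // Illustrative degree count, assuming the maximum supported
+        // constraint degree of 8: each combined constraint evaluates to a
+        // polynomial of degree at most 8 * (n - 1) (columns have degree
+        // n - 1), so deg(t) <= 8 * (n - 1) - n < 7n. This is consistent with
+        // the `num_chunks = max_degree - 1 = 7` chunked commitments below.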
+        let (quotient, res) = expr_evaluation_interpolated
+            .divide_by_vanishing_poly(domain.d1)
+            .unwrap_or_else(fail_final_q_division);
+        // As the constraints must be verified on H, the rest of the division
+        // must be equal to 0 as the constraints polynomial and Z_H(X) are both
+        // equal on H.
+        if !res.is_zero() {
+            fail_final_q_division();
+        }
+
+        quotient
+    };
+
+    let num_chunks: usize = if max_degree == 1 {
+        1
+    } else {
+        (max_degree - 1) as usize
+    };
+
+    //~ 1. commit to the quotient polynomial $t$.
+    let t_comm = srs.commit_non_hiding(&quotient_poly, num_chunks);
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Round 3: Evaluations at ζ and ζω
+    ////////////////////////////////////////////////////////////////////////////
+
+    //~ 1. Absorb the commitment of the quotient polynomial with the Fq-Sponge.
+    absorb_commitment(&mut fq_sponge, &t_comm);
+
+    //~ 1. Sample ζ with the Fq-Sponge.
+    let zeta_chal = ScalarChallenge(fq_sponge.challenge());
+
+    let zeta = zeta_chal.to_field(endo_r);
+
+    let omega = domain.d1.group_gen;
+    // We will also evaluate at ζω as lookups do require to go to the next row.
+    let zeta_omega = zeta * omega;
+
+    // Evaluate the polynomials at ζ and ζω -- Columns
+    let witness_evals: Witness<N_WIT, PointEvaluations<G::ScalarField>> = {
+        let eval = |p: &DensePolynomial<_>| PointEvaluations {
+            zeta: p.evaluate(&zeta),
+            zeta_omega: p.evaluate(&zeta_omega),
+        };
+        (&witness_polys)
+            .into_par_iter()
+            .map(eval)
+            .collect::<Witness<N_WIT, PointEvaluations<_>>>()
+    };
+
+    let fixed_selectors_evals: Box<[PointEvaluations<G::ScalarField>; N_FSEL]> = {
+        let eval = |p: &DensePolynomial<_>| PointEvaluations {
+            zeta: p.evaluate(&zeta),
+            zeta_omega: p.evaluate(&zeta_omega),
+        };
+        o1_utils::array::vec_to_boxed_array(
+            fixed_selectors_polys
+                .as_ref()
+                .into_par_iter()
+                .map(eval)
+                .collect::<_>(),
+        )
+    };
+
+    // IMPROVEME: move this into the logup module
+    let logup_evals = lookup_env.as_ref().map(|lookup_env| LookupProof {
+        m: lookup_env
+            .lookup_counters_poly_d1
+            .iter()
+            .map(|(id, polys)| {
+                (
+                    *id,
+                    polys
+                        .iter()
+                        .map(|poly| {
+                            let zeta = poly.evaluate(&zeta);
+                            let zeta_omega = poly.evaluate(&zeta_omega);
+                            PointEvaluations { zeta, zeta_omega }
+                        })
+                        .collect(),
+                )
+            })
+            .collect(),
+        h: lookup_env
+            .lookup_terms_poly_d1
+            .iter()
+            .map(|(id, polys)| {
+                let polys_evals: Vec<_> = polys
+                    .iter()
+                    .map(|poly| PointEvaluations {
+                        zeta: poly.evaluate(&zeta),
+                        zeta_omega: poly.evaluate(&zeta_omega),
+                    })
+                    .collect();
+                (*id, polys_evals)
+            })
+            .collect(),
+        sum: PointEvaluations {
+            zeta: lookup_env.lookup_aggregation_poly_d1.evaluate(&zeta),
+            zeta_omega: lookup_env.lookup_aggregation_poly_d1.evaluate(&zeta_omega),
+        },
+        fixed_tables: {
+            lookup_env
+                .fixed_lookup_tables_poly_d1
+                .iter()
+                .map(|(id, poly)| {
+                    let zeta = poly.evaluate(&zeta);
+                    let zeta_omega = poly.evaluate(&zeta_omega);
+                    (*id, PointEvaluations { zeta, zeta_omega })
+                })
+                .collect()
+        },
+    });
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Round 4: Opening proof w/o linearization polynomial
+    ////////////////////////////////////////////////////////////////////////////
+
+    // Fiat-Shamir - absorbing evaluations
+    let fq_sponge_before_evaluations = fq_sponge.clone();
+    let mut fr_sponge = EFrSponge::new(G::sponge_params());
+    fr_sponge.absorb(&fq_sponge.digest());
+
+    for PointEvaluations { zeta, zeta_omega } in (&witness_evals).into_iter() {
+        fr_sponge.absorb(zeta);
+        fr_sponge.absorb(zeta_omega);
+    }
+
+    for PointEvaluations { zeta, zeta_omega } in
fixed_selectors_evals.as_ref().iter() { + fr_sponge.absorb(zeta); + fr_sponge.absorb(zeta_omega); + } + + if lookup_env.is_some() { + for PointEvaluations { zeta, zeta_omega } in logup_evals.as_ref().unwrap().into_iter() { + fr_sponge.absorb(zeta); + fr_sponge.absorb(zeta_omega); + } + } + + // Compute ft(X) = \ + // (1 - ζ^n) \ + // (t_0(X) + ζ^n t_1(X) + ... + ζ^{kn} t_{k}(X)) + // where \sum_i t_i(X) X^{i n} = t(X), and t(X) is the quotient polynomial. + // At the end, we get the (partial) evaluation of the constraint polynomial + // in ζ. + let ft: DensePolynomial = { + let evaluation_point_to_domain_size = zeta.pow([domain.d1.size]); + // Compute \sum_i t_i(X) ζ^{i n} + // First we split t in t_i, and we reduce to degree (n - 1) after using `linearize` + let t_chunked: DensePolynomial = quotient_poly + .to_chunked_polynomial(num_chunks, domain.d1.size as usize) + .linearize(evaluation_point_to_domain_size); + // -Z_H = (1 - ζ^n) + let minus_vanishing_poly_at_zeta = -domain.d1.vanishing_polynomial().evaluate(&zeta); + // Multiply the polynomial \sum_i t_i(X) ζ^{i n} by -Z_H(ζ) + // (the evaluation in ζ of the vanishing polynomial) + t_chunked.scale(minus_vanishing_poly_at_zeta) + }; + + // We only evaluate at ζω as the verifier can compute the + // evaluation at ζ from the independent evaluations at ζ of the + // witness columns because ft(X) is the constraint polynomial, built from + // the public constraints. + // We evaluate at ζω because the lookup argument requires to compute + // \phi(Xω) - \phi(X). + let ft_eval1 = ft.evaluate(&zeta_omega); + + // Absorb ft(ζω) + fr_sponge.absorb(&ft_eval1); + + let v_chal = fr_sponge.challenge(); + let v = v_chal.to_field(endo_r); + let u_chal = fr_sponge.challenge(); + let u = u_chal.to_field(endo_r); + + let coefficients_form = DensePolynomialOrEvaluations::DensePolynomial; + let non_hiding = |n_chunks| PolyComm { + chunks: vec![G::ScalarField::zero(); n_chunks], + }; + let hiding = |n_chunks| PolyComm { + chunks: vec![G::ScalarField::one(); n_chunks], + }; + + // Gathering all polynomials to use in the opening proof + let mut polynomials: Vec<_> = (&witness_polys) + .into_par_iter() + .map(|poly| (coefficients_form(poly), hiding(1))) + .collect(); + + // @volhovm: I'm not sure we need to prove opening of fixed + // selectors in the commitment. 
+    polynomials.extend(
+        fixed_selectors_polys
+            .as_ref()
+            .into_par_iter()
+            .map(|poly| (coefficients_form(poly), non_hiding(1)))
+            .collect::<Vec<_>>(),
+    );
+
+    // Adding Logup
+    if let Some(ref lookup_env) = lookup_env {
+        // -- first m(X)
+        polynomials.extend(
+            lookup_env
+                .lookup_counters_poly_d1
+                .values()
+                .flat_map(|polys| {
+                    polys
+                        .iter()
+                        .map(|poly| (coefficients_form(poly), non_hiding(1)))
+                        .collect::<Vec<_>>()
+                })
+                .collect::<Vec<_>>(),
+        );
+        // -- after that the partial sums
+        polynomials.extend({
+            let polys = lookup_env.lookup_terms_poly_d1.values().map(|polys| {
+                polys
+                    .iter()
+                    .map(|poly| (coefficients_form(poly), non_hiding(1)))
+                    .collect::<Vec<_>>()
+            });
+            let polys: Vec<_> = polys.flatten().collect();
+            polys
+        });
+        // -- after that the running sum
+        polynomials.push((
+            coefficients_form(&lookup_env.lookup_aggregation_poly_d1),
+            non_hiding(1),
+        ));
+        // -- Adding fixed lookup tables
+        polynomials.extend(
+            lookup_env
+                .fixed_lookup_tables_poly_d1
+                .values()
+                .map(|poly| (coefficients_form(poly), non_hiding(1)))
+                .collect::<Vec<_>>(),
+        );
+    }
+    polynomials.push((coefficients_form(&ft), non_hiding(1)));
+
+    let opening_proof = OpenProof::open::<_, _, R2D<G::ScalarField>>(
+        srs,
+        &group_map,
+        polynomials.as_slice(),
+        &[zeta, zeta_omega],
+        v,
+        u,
+        fq_sponge_before_evaluations,
+        rng,
+    );
+
+    let proof_evals: ProofEvaluations<N_WIT, N_REL, N_DSEL, N_FSEL, G::ScalarField, ID> = {
+        ProofEvaluations {
+            witness_evals,
+            fixed_selectors_evals,
+            logup_evals,
+            ft_eval1,
+        }
+    };
+
+    Ok(Proof {
+        proof_comms: ProofCommitments {
+            witness_comms,
+            logup_comms,
+            t_comm,
+        },
+        proof_evals,
+        opening_proof,
+    })
+}
diff --git a/msm/src/serialization/README.md b/msm/src/serialization/README.md
new file mode 100644
index 0000000000..a70109c934
--- /dev/null
+++ b/msm/src/serialization/README.md
@@ -0,0 +1,34 @@
+## Kimchi Foreign Field gate serialization subcircuit
+
+The 15/16 bulletproof challenges will be given as 88-bit limbs, the encoding used by Kimchi.
+A circuit is required to convert them into 15-bit limbs, which will be used by the MSM algorithm.
+We will use one row = one decomposition.
+
+We have the following circuit shape:
+
+| $b_{0}$ | $b_{1}$ | $b_{2}$ | $c_{0}$ | $c_{1}$ | ... | $c_{16}$ | $b_{2, 0}$ | $b_{2, 1}$ | ... | $b_{2, 19}$ |
+| ------- | ------- | ------- | ------- | ------- | --- | -------- | ---------- | ---------- | --- | ----------- |
+| ...     | ...     | ...     | ...     | ...     | ... | ...      | ...        | ...        | ... | ...         |
+| ...     | ...     | ...     | ...     | ...     | ... | ...      | ...        | ...        | ... | ...         |
+
+We can suppose that $b_{2}$ is only 80 bits long, as the input is at most
+$BN254(\mathbb{F}_{scalar})$, which is 254 bits long.
+We will decompose $b_{2}$ in chunks of 4 bits:
+
+$$b_{2} = \sum_{i = 0}^{19} b_{2, i} 2^{4 i}$$
+
+And we will add the following constraints:
+
+1. For the first 180 bits:
+
+$$b_{0} + b_{1} 2^{88} + b_{2, 0} 2^{88 \cdot 2} - \sum_{j = 0}^{11} c_{j} 2^{15 j} = 0$$
+
+2. For the remaining 75 bits:
+
+$$c_{12} + c_{13} 2^{15} + c_{14} 2^{15 \cdot 2} + c_{15} 2^{15 \cdot 3} + c_{16} 2^{15 \cdot 4} = \sum_{j = 1}^{19} b_{2, j} 2^{4 (j - 1)}$$
+
+with additional lookups.
+
+$b_{0}$, $b_{1}$ and $b_{2}$ are the decomposition in 88-bit limbs given by the
+foreign field gate in Kimchi. The values $c_{0}$, $\cdots$, $c_{16}$ are the limbs
+required for the MSM circuit. Each limb $c_{i}$ is 15 bits long.
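+
+As a hedged sketch (the helper below is hypothetical and not part of the
+crate), the recombination and 15-bit slicing can be written with `num_bigint`:
+
+```rust
+use num_bigint::BigUint;
+
+/// Recombine three 88-bit limbs into x = b_0 + b_1 * 2^88 + b_2 * 2^176,
+/// then slice x into 17 limbs of 15 bits (17 * 15 = 255 >= 254 bits).
+fn to_15bit_limbs(b: [BigUint; 3]) -> Vec<BigUint> {
+    let x = &b[0] + (&b[1] << 88usize) + (&b[2] << 176usize);
+    let mask = BigUint::from((1u32 << 15) - 1);
+    (0..17usize).map(|i| (&x >> (15 * i)) & &mask).collect()
+}
+```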
diff --git a/msm/src/serialization/column.rs b/msm/src/serialization/column.rs new file mode 100644 index 0000000000..fb30bb1c66 --- /dev/null +++ b/msm/src/serialization/column.rs @@ -0,0 +1,90 @@ +use crate::{ + columns::{Column, ColumnIndexer}, + serialization::{interpreter::N_LIMBS_LARGE, N_INTERMEDIATE_LIMBS}, + N_LIMBS, +}; + +/// Total number of columns in the serialization circuit, including fixed selectors. +pub const N_COL_SER: usize = N_INTERMEDIATE_LIMBS + 6 * N_LIMBS + N_LIMBS_LARGE + 12; + +/// Number of fixed selectors for serialization circuit. +pub const N_FSEL_SER: usize = 2; + +/// Columns used by the serialization subcircuit. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum SerializationColumn { + /// A fixed selector column that gives one the current row, starting with 0. + CurrentRow, + /// For current row i, this is i - 2^{ceil(log(i)) - 1} + PreviousCoeffRow, + /// 3 88-bit inputs. For the row #i this represents the IPA challenge xi_{log(i)}. + ChalKimchi(usize), + /// N_INTERMEDIATE_LIMBS intermediate values, 4 bits long. Represent parts of the IPA challenge. + ChalIntermediate(usize), + /// N_LIMBS values, representing the converted IPA challenge. + ChalConverted(usize), + /// Previous coefficient C_j, this one is looked up. For the row i, the expected + /// value is (C_i >> 1). + CoeffInput(usize), + /// Trusted (for range) foreign field modulus, in 4 big limbs. + FFieldModulus(usize), + /// Quotient limbs (small) + QuotientSmall(usize), + /// Quotient limbs (large) + QuotientLarge(usize), + /// Sign of the quotient, one bit + QuotientSign, + /// Carry limbs + Carry(usize), + /// The resulting coefficient C_i = C_{i - 2^{ceil(log i) - 1}} * xi_{log(i)}. In small limbs. + CoeffResult(usize), +} + +impl ColumnIndexer for SerializationColumn { + const N_COL: usize = N_COL_SER; + fn to_column(self) -> Column { + match self { + Self::CurrentRow => Column::FixedSelector(0), + Self::PreviousCoeffRow => Column::FixedSelector(1), + Self::ChalKimchi(j) => { + assert!(j < 3); + Column::Relation(j) + } + Self::ChalIntermediate(j) => { + assert!(j < N_INTERMEDIATE_LIMBS); + Column::Relation(3 + j) + } + Self::ChalConverted(j) => { + assert!(j < N_LIMBS); + Column::Relation(N_INTERMEDIATE_LIMBS + 3 + j) + } + Self::CoeffInput(j) => { + assert!(j < N_LIMBS); + Column::Relation(N_INTERMEDIATE_LIMBS + N_LIMBS + 3 + j) + } + Self::FFieldModulus(j) => { + assert!(j < 4); + Column::Relation(N_INTERMEDIATE_LIMBS + 2 * N_LIMBS + 3 + j) + } + Self::QuotientSmall(j) => { + assert!(j < N_LIMBS); + Column::Relation(N_INTERMEDIATE_LIMBS + 2 * N_LIMBS + 7 + j) + } + Self::QuotientLarge(j) => { + assert!(j < N_LIMBS_LARGE); + Column::Relation(N_INTERMEDIATE_LIMBS + 3 * N_LIMBS + 7 + j) + } + Self::QuotientSign => { + Column::Relation(N_INTERMEDIATE_LIMBS + 3 * N_LIMBS + N_LIMBS_LARGE + 7) + } + Self::Carry(j) => { + assert!(j < 2 * N_LIMBS + 2); + Column::Relation(N_INTERMEDIATE_LIMBS + 3 * N_LIMBS + N_LIMBS_LARGE + 8 + j) + } + Self::CoeffResult(j) => { + assert!(j < N_LIMBS); + Column::Relation(N_INTERMEDIATE_LIMBS + 5 * N_LIMBS + N_LIMBS_LARGE + 10 + j) + } + } + } +} diff --git a/msm/src/serialization/interpreter.rs b/msm/src/serialization/interpreter.rs new file mode 100644 index 0000000000..d4353a1863 --- /dev/null +++ b/msm/src/serialization/interpreter.rs @@ -0,0 +1,1031 @@ +use ark_ff::{PrimeField, Zero}; +use num_bigint::{BigInt, BigUint, ToBigInt}; +use num_integer::Integer; +use num_traits::{sign::Signed, Euclid}; +use std::marker::PhantomData; + +use 
crate::{
+    circuit_design::{
+        capabilities::write_column_const, ColAccessCap, ColWriteCap, HybridCopyCap, LookupCap,
+        MultiRowReadCap,
+    },
+    columns::ColumnIndexer,
+    logup::LookupTableID,
+    serialization::{
+        column::{SerializationColumn, N_FSEL_SER},
+        lookups::LookupTable,
+        N_INTERMEDIATE_LIMBS,
+    },
+    LIMB_BITSIZE, N_LIMBS,
+};
+use kimchi::circuits::{
+    expr::{Expr, ExprInner, Variable},
+    gate::CurrOrNext,
+};
+use o1_utils::{field_helpers::FieldHelpers, foreign_field::ForeignElement};
+
+// Such "helpers" defeat the whole purpose of the interpreter.
+// TODO remove
+pub trait HybridSerHelpers<F: PrimeField, CIx: ColumnIndexer, LT: LookupTableID> {
+    /// Returns the bits between [highest_bit, lowest_bit] of the variable `x`,
+    /// and copies the result into the column `position`.
+    /// The value `x` is expected to be encoded in big-endian.
+    fn bitmask_be(
+        &mut self,
+        x: &<Self as ColAccessCap<F, CIx>>::Variable,
+        highest_bit: u32,
+        lowest_bit: u32,
+        position: CIx,
+    ) -> Self::Variable
+    where
+        Self: ColAccessCap<F, CIx>;
+}
+
+impl<F: PrimeField, CIx: ColumnIndexer, LT: LookupTableID> HybridSerHelpers<F, CIx, LT>
+    for crate::circuit_design::ConstraintBuilderEnv<F, LT>
+{
+    fn bitmask_be(
+        &mut self,
+        _x: &<Self as ColAccessCap<F, CIx>>::Variable,
+        _highest_bit: u32,
+        _lowest_bit: u32,
+        position: CIx,
+    ) -> <Self as ColAccessCap<F, CIx>>::Variable {
+        // No constraint added. It is supposed that the caller will later
+        // constrain the returned variable and/or do a range check.
+        Expr::Atom(ExprInner::Cell(Variable {
+            col: position.to_column(),
+            row: CurrOrNext::Curr,
+        }))
+    }
+}
+
+impl<
+        F: PrimeField,
+        CIx: ColumnIndexer,
+        const N_COL: usize,
+        const N_REL: usize,
+        const N_DSEL: usize,
+        const N_FSEL: usize,
+        LT: LookupTableID,
+    > HybridSerHelpers<F, CIx, LT>
+    for crate::circuit_design::WitnessBuilderEnv<F, CIx, N_COL, N_REL, N_DSEL, N_FSEL, LT>
+{
+    fn bitmask_be(
+        &mut self,
+        x: &<Self as ColAccessCap<F, CIx>>::Variable,
+        highest_bit: u32,
+        lowest_bit: u32,
+        position: CIx,
+    ) -> <Self as ColAccessCap<F, CIx>>::Variable {
+        // FIXME: we can assume bitmask_be will be called only on values with
+        // at most 128 bits. We use bitmask_be only for the limbs.
+        let x_bytes_u8 = &x.to_bytes()[0..16];
+        let x_u128 = u128::from_le_bytes(x_bytes_u8.try_into().unwrap());
+        let res = (x_u128 >> lowest_bit) & ((1 << (highest_bit - lowest_bit)) - 1);
+        let res_fp: F = res.into();
+        self.write_column_raw(position.to_column(), res_fp);
+        res_fp
+    }
+}
+
+/// Alias for LIMB_BITSIZE, used for convenience.
+pub const LIMB_BITSIZE_SMALL: usize = LIMB_BITSIZE;
+/// Alias for N_LIMBS, used for convenience.
+pub const N_LIMBS_SMALL: usize = N_LIMBS;
+
+/// In FEC addition we use bigger limbs, of 75 bits, that are still
+/// nicely decomposable into smaller 15-bit ones for range checking.
+pub const LIMB_BITSIZE_LARGE: usize = LIMB_BITSIZE_SMALL * 5; // 75 bits
+pub const N_LIMBS_LARGE: usize = 4;
+
+/// Returns the highest limb of the foreign field modulus. Is used by the lookups.
+pub fn ff_modulus_highest_limb<Ff: PrimeField>() -> BigUint {
+    let f_bui: BigUint = TryFrom::try_from(<Ff as PrimeField>::MODULUS).unwrap();
+    f_bui >> ((N_LIMBS - 1) * LIMB_BITSIZE)
+}
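For intuition, the witness-side `bitmask_be` above is a plain shift-and-mask on the low 128 bits of the value. A minimal standalone sketch of the same arithmetic on a `u128` (editorial, not part of the diff):

```rust
// Minimal sketch of the bit extraction performed by the witness-side
// `bitmask_be`: keep bits [lowest_bit, highest_bit) of x, shifted down.
fn bitmask(x: u128, highest_bit: u32, lowest_bit: u32) -> u128 {
    (x >> lowest_bit) & ((1u128 << (highest_bit - lowest_bit)) - 1)
}

fn main() {
    let x = 0b1011_0110u128;
    // Bits [4, 8) of x are 0b1011.
    assert_eq!(bitmask(x, 8, 4), 0b1011);
    // The 4-bit chunks used for ChalIntermediate(j) are bitmask(x, 4*(j+1), 4*j).
    assert_eq!(bitmask(x, 4, 0), 0b0110);
}
```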
+
+/// Deserialize a field element of the scalar field of Vesta or Pallas given as
+/// a sequence of 3 limbs of 88 bits.
+/// It will deserialize into limbs of 15 bits.
+/// Given a scalar field element of Vesta or Pallas, here is the decomposition:
+/// ```text
+/// limbs = [limbs0, limbs1, limbs2]
+/// |  limbs0  |   limbs1   |   limbs2   |
+/// | 0 ... 87 | 88 ... 175 | 176 .. 264 |
+///    ----        ----         ----
+///   /    \      /    \       /    \
+///     (1)         (2)          (3)
+/// (1): c0 = 0...14, c1 = 15..29, c2 = 30..44, c3 = 45..59, c4 = 60..74
+/// (1) and (2): c5 = limbs0[75]..limbs0[87] || limbs1[0]..limbs1[1]
+/// (2): c6 = 2...16, c7 = 17..31, c8 = 32..46, c9 = 47..61, c10 = 62..76
+/// (2) and (3): c11 = limbs1[77]..limbs1[87] || limbs2[0]..limbs2[3]
+/// (3): c12 = 4...18, c13 = 19..33, c14 = 34..48, c15 = 49..63, c16 = 64..78
+/// ```
+/// And we can ignore the last 10 bits (i.e. `limbs2[78..87]`) as a field element
+/// is 254 bits long.
+pub fn deserialize_field_element<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, SerializationColumn>
+        + LookupCap<F, SerializationColumn, LookupTable<Ff>>
+        + HybridCopyCap<F, SerializationColumn>
+        + HybridSerHelpers<F, SerializationColumn, LookupTable<Ff>>,
+>(
+    env: &mut Env,
+    limbs: [BigUint; 3],
+) {
+    let input_limb0 = Env::constant(F::from(limbs[0].clone()));
+    let input_limb1 = Env::constant(F::from(limbs[1].clone()));
+    let input_limb2 = Env::constant(F::from(limbs[2].clone()));
+    let input_limbs = [
+        input_limb0.clone(),
+        input_limb1.clone(),
+        input_limb2.clone(),
+    ];
+
+    // FIXME: should we assert this in the circuit?
+    assert!(limbs[0] < BigUint::from(2u128.pow(88)));
+    assert!(limbs[1] < BigUint::from(2u128.pow(88)));
+    assert!(limbs[2] < BigUint::from(2u128.pow(79)));
+
+    let limb0_var = env.hcopy(&input_limb0, SerializationColumn::ChalKimchi(0));
+    let limb1_var = env.hcopy(&input_limb1, SerializationColumn::ChalKimchi(1));
+    let limb2_var = env.hcopy(&input_limb2, SerializationColumn::ChalKimchi(2));
+
+    let mut limb2_vars = vec![];
+
+    // Compute the individual 4-bit limbs of b2
+    {
+        let mut constraint = limb2_var.clone();
+        for j in 0..N_INTERMEDIATE_LIMBS {
+            let var = env.bitmask_be(
+                &input_limb2,
+                4 * (j + 1) as u32,
+                4 * j as u32,
+                SerializationColumn::ChalIntermediate(j),
+            );
+            limb2_vars.push(var.clone());
+            let pow: u128 = 1 << (4 * j);
+            let pow = Env::constant(pow.into());
+            constraint = constraint - var * pow;
+        }
+        env.assert_zero(constraint)
+    }
+    // Range check on each limb
+    limb2_vars
+        .iter()
+        .for_each(|v| env.lookup(LookupTable::RangeCheck4, vec![v.clone()]));
+
+    let mut fifteen_bits_vars = vec![];
+
+    for j in 0..3 {
+        for i in 0..5 {
+            let ci_var = env.bitmask_be(
+                &input_limbs[j],
+                15 * (i + 1) + 2 * j as u32,
+                15 * i + 2 * j as u32,
+                SerializationColumn::ChalConverted(6 * j + i as usize),
+            );
+            fifteen_bits_vars.push(ci_var)
+        }
+
+        if j < 2 {
+            let shift = 2 * (j + 1); // ∈ [2, 4]
+            let res = (limbs[j].clone() >> (73 + shift))
+                & BigUint::from((1u128 << (88 - 73 + shift)) - 1);
+            let res_prime = limbs[j + 1].clone() & BigUint::from((1u128 << shift) - 1);
+            let res: BigUint = res + (res_prime << (15 - shift));
+            let res = Env::constant(F::from(res));
+            let c5_var = env.hcopy(&res, SerializationColumn::ChalConverted(6 * j + 5));
+            fifteen_bits_vars.push(c5_var);
+        }
+    }
+
+    // Range check on each limb
+    fifteen_bits_vars
+        .iter()
+        .for_each(|v| env.lookup(LookupTable::RangeCheck15, vec![v.clone()]));
+
+    let shl_88_var = Env::constant(F::from(1u128 << 88u128));
+    let shl_15_var = Env::constant(F::from(1u128 << 15u128));
+
+    // -- Start second constraint
+    {
+        // b0 + b1 * 2^88 + b2 * 2^176
+        let constraint = {
+            limb0_var
+                + limb1_var * shl_88_var.clone()
+                + shl_88_var.clone() * shl_88_var.clone() * limb2_vars[0].clone()
+        };
+
+        // Subtracting the 15-bit values
+        let (constraint, _) = (0..=11).fold(
+            (constraint, Env::constant(F::one())),
+            |(acc, shl_var), i| {
+                (
+                    acc - fifteen_bits_vars[i].clone() * shl_var.clone(),
+                    shl_15_var.clone() * shl_var.clone(),
+                )
+            },
+        );
+        env.assert_zero(constraint);
+    }
+
+    // -- Start third constraint
+    {
+        // Computing
+        // c12 + c13 * 2^15 + c14 * 2^30 + c15 * 2^45 + c16 * 2^60
+        let constraint = fifteen_bits_vars[12].clone();
+        let constraint = (1..=4).fold(constraint, |acc, i| {
+            acc + fifteen_bits_vars[12 + i].clone() * Env::constant(F::from(1u128 << (15 * i)))
+        });
+
+        let constraint = (1..=19).fold(constraint, |acc, i| {
+            let var = limb2_vars[i].clone() * Env::constant(F::from(1u128 << (4 * (i - 1))));
+            acc - var
+        });
+        env.assert_zero(constraint);
+    }
+}
+
+/// Interprets bigint `input` as an element of a field modulo `f_bi`,
+/// converts it to the `[0, f_bi)` range, and outputs a corresponding
+/// biguint representation.
+pub fn bigint_to_biguint_f(input: BigInt, f_bi: &BigInt) -> BigUint {
+    let corrected_import: BigInt = if input.is_negative() && input > -f_bi {
+        &input + f_bi
+    } else if input.is_negative() {
+        Euclid::rem_euclid(&input, f_bi)
+    } else {
+        input
+    };
+    corrected_import.to_biguint().unwrap()
+}
+
+/// Decomposes a biguint into `N` limbs of bit size `B`.
+pub fn limb_decompose_biguint<F: PrimeField, const B: usize, const N: usize>(
+    input: BigUint,
+) -> [F; N] {
+    let ff_el: ForeignElement<F, B, N> = ForeignElement::from_biguint(input);
+    ff_el.limbs
+}
+
+/// Decomposes a foreign field element into `N` limbs of bit size `B`.
+pub fn limb_decompose_ff<F: PrimeField, Ff: PrimeField, const B: usize, const N: usize>(
+    input: &Ff,
+) -> [F; N] {
+    let input_bi: BigUint = FieldHelpers::to_biguint(input);
+    limb_decompose_biguint::<F, B, N>(input_bi)
+}
+
+/// Returns all `(i, j)` with `i, j ∈ [0, list_len)` such that `i + j = n`.
+fn choice2(list_len: usize, n: usize) -> Vec<(usize, usize)> {
+    use itertools::Itertools;
+    let indices = Vec::from_iter(0..list_len);
+    indices
+        .clone()
+        .into_iter()
+        .cartesian_product(indices)
+        .filter(|(i1, i2)| i1 + i2 == n)
+        .collect()
+}
+
+/// A convenience helper: given a `list_len` and `n` (arguments of
+/// `choice2`), it creates an array consisting of `f(i, j)` where `i, j
+/// ∈ [0, list_len)` such that `i + j = n`, and then sums all the
+/// elements in this array.
+pub fn fold_choice2<Var, Foo>(list_len: usize, n: usize, f: Foo) -> Var
+where
+    Foo: Fn(usize, usize) -> Var,
+    Var: Clone + std::ops::Add<Var, Output = Var> + From<u64>,
+{
+    let chosen = choice2(list_len, n);
+    chosen
+        .into_iter()
+        .map(|(j, k)| f(j, k))
+        .fold(Var::from(0u64), |acc, v| acc + v)
+}
+
+/// Helper function for limb recombination.
+///
+/// Combines an array of `M` elements (think `N_LIMBS_SMALL`) into an
+/// array of `N` elements (think `N_LIMBS_LARGE`) by taking chunks
+/// `a_i` of size `K = BITSIZE_N / BITSIZE_M` from the first, and
+/// recombining them as `a_i * 2^{i * BITSIZE_M}`.
+pub fn combine_limbs_m_to_n<
+    const M: usize,
+    const N: usize,
+    const BITSIZE_M: usize,
+    const BITSIZE_N: usize,
+    F: PrimeField,
+    V: std::ops::Add<V, Output = V>
+        + std::ops::Sub<V, Output = V>
+        + std::ops::Mul<V, Output = V>
+        + std::ops::Neg<Output = V>
+        + From<u64>
+        + Clone,
+    Func: Fn(F) -> V,
+>(
+    from_field: Func,
+    x: [V; M],
+) -> [V; N] {
+    assert!(BITSIZE_N % BITSIZE_M == 0);
+    let k = BITSIZE_N / BITSIZE_M;
+    let constant_bui = |x: BigUint| from_field(F::from(x));
+    let disparity: usize = M % k;
+    std::array::from_fn(|i| {
+        // We have fewer small limbs in the last large limb
+        let upper_bound = if disparity != 0 && i == N - 1 {
+            disparity
+        } else {
+            k
+        };
+        (0..upper_bound)
+            .map(|j| x[k * i + j].clone() * constant_bui(BigUint::from(1u128) << (j * BITSIZE_M)))
+            .fold(V::from(0u64), |acc, v| acc + v)
+    })
+}
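As an editorial illustration of `combine_limbs_m_to_n` with the concrete sizes used here (M = 20 small 15-bit limbs, N = 4 large 75-bit limbs, K = 5), the following standalone sketch (not part of the diff) performs the same recombination on plain `u128` values in place of circuit variables:

```rust
// Standalone sketch: recombine twenty 15-bit limbs into four 75-bit limbs,
// mirroring combine_limbs_m_to_n with M = 20, N = 4, K = 75 / 15 = 5.
fn main() {
    let value: u128 = 0x1f2e_3d4c_5b6a_7988_0123_4567_89ab_cdef;
    // The i-th 15-bit limb of `value` (limbs above bit 127 are zero here).
    let limb = |bits_from: u32| value.checked_shr(bits_from).unwrap_or(0) & ((1 << 15) - 1);
    let small: Vec<u128> = (0..20u32).map(|i| limb(15 * i)).collect();
    // Recombine groups of K = 5: large_i = sum_j small[5i + j] * 2^(15 j).
    let large: Vec<u128> = (0..4usize)
        .map(|i| (0..5usize).map(|j| small[5 * i + j] << (15 * j)).sum())
        .collect();
    // The first two large limbs are exactly bits [0, 75) and [75, 150) of value.
    assert_eq!(large[0], value & ((1 << 75) - 1));
    assert_eq!(large[1], (value >> 75) & ((1 << 75) - 1));
}
```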
+
+/// Helper function for limb recombination.
+///
+/// Combines small limbs into big limbs.
+pub fn combine_small_to_large<F: PrimeField, CIx: ColumnIndexer, Env: ColAccessCap<F, CIx>>(
+    x: [Env::Variable; N_LIMBS_SMALL],
+) -> [Env::Variable; N_LIMBS_LARGE] {
+    combine_limbs_m_to_n::<
+        N_LIMBS_SMALL,
+        N_LIMBS_LARGE,
+        LIMB_BITSIZE_SMALL,
+        LIMB_BITSIZE_LARGE,
+        F,
+        Env::Variable,
+        _,
+    >(|f| Env::constant(f), x)
+}
+
+/// Helper function for limb recombination for carries specifically.
+/// Each big carry limb is stored as 6 (not 5!) small elements. We
+/// accept 36 small limbs, and return 6 large ones.
+pub fn combine_carry<F: PrimeField, CIx: ColumnIndexer, Env: ColAccessCap<F, CIx>>(
+    x: [Env::Variable; 2 * N_LIMBS_SMALL + 2],
+) -> [Env::Variable; 2 * N_LIMBS_LARGE - 2] {
+    let constant_u128 = |x: u128| Env::constant(From::from(x));
+    std::array::from_fn(|i| {
+        (0..6)
+            .map(|j| x[6 * i + j].clone() * constant_u128(1u128 << (j * (LIMB_BITSIZE_SMALL - 1))))
+            .fold(Env::Variable::from(0u64), |acc, v| acc + v)
+    })
+}
+
+/// This constrains the multiplication part of the circuit.
+pub fn constrain_multiplication<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, SerializationColumn> + LookupCap<F, SerializationColumn, LookupTable<Ff>>,
+>(
+    env: &mut Env,
+) {
+    let chal_converted_limbs_small: [_; N_LIMBS_SMALL] =
+        core::array::from_fn(|i| env.read_column(SerializationColumn::ChalConverted(i)));
+    let coeff_input_limbs_small: [_; N_LIMBS_SMALL] =
+        core::array::from_fn(|i| env.read_column(SerializationColumn::CoeffInput(i)));
+    let coeff_result_limbs_small: [_; N_LIMBS_SMALL] =
+        core::array::from_fn(|i| env.read_column(SerializationColumn::CoeffResult(i)));
+
+    let ffield_modulus_limbs_large: [_; N_LIMBS_LARGE] =
+        core::array::from_fn(|i| env.read_column(SerializationColumn::FFieldModulus(i)));
+    let quotient_limbs_small: [_; N_LIMBS_SMALL] =
+        core::array::from_fn(|i| env.read_column(SerializationColumn::QuotientSmall(i)));
+    let quotient_limbs_large: [_; N_LIMBS_LARGE] =
+        core::array::from_fn(|i| env.read_column(SerializationColumn::QuotientLarge(i)));
+    let carry_limbs_small: [_; 2 * N_LIMBS_SMALL + 2] =
+        core::array::from_fn(|i| env.read_column(SerializationColumn::Carry(i)));
+
+    let quotient_sign = env.read_column(SerializationColumn::QuotientSign);
+
+    // u128 covers our limb size shifts, which is good
+    let constant_u128 = |x: u128| -> <Env as ColAccessCap<F, SerializationColumn>>::Variable {
+        Env::constant(From::from(x))
+    };
+
+    {
+        let current_row = env.read_column(SerializationColumn::CurrentRow);
+        let previous_coeff_row = env.read_column(SerializationColumn::PreviousCoeffRow);
+
+        // TODO: We don't have to write the top half of the table since it's never read.
+
+        // Writing the output
+        // (cur_i, [VEC])
+        let mut vec_output: Vec<_> = coeff_result_limbs_small.clone().to_vec();
+        vec_output.insert(0, current_row);
+        env.lookup_runtime_write(LookupTable::MultiplicationBus, vec_output);
+
+        //// Writing the constant: it's only read once
+        //// (0, [VEC representing 0])
+        env.lookup_runtime_write(
+            LookupTable::MultiplicationBus,
+            vec![Env::constant(F::zero()); N_LIMBS_SMALL + 1],
+        );
+
+        // Reading the input:
+        // (prev_i, [VEC])
+        let mut vec_input: Vec<_> = coeff_input_limbs_small.clone().to_vec();
+        vec_input.insert(0, previous_coeff_row);
+        env.lookup(LookupTable::MultiplicationBus, vec_input.clone());
+    }
+
+    // Quotient sign must be -1 or 1.
+    env.assert_zero(quotient_sign.clone() * quotient_sign.clone() - Env::constant(F::one()));
+
+    // The result variable must be in the field.
+    for (i, x) in coeff_result_limbs_small.iter().enumerate() {
+        if i % N_LIMBS_SMALL == N_LIMBS_SMALL - 1 {
+            // If it's the highest limb, we need to check that it represents a field element.
+            env.lookup(
+                LookupTable::RangeCheckFfHighest(PhantomData),
+                vec![x.clone()],
+            );
+        } else {
+            env.lookup(LookupTable::RangeCheck15, vec![x.clone()]);
+        }
+    }
+
+    // Quotient limbs must fit into 15 bits, but we don't care if they're in the field.
+    for x in quotient_limbs_small.iter() {
+        env.lookup(LookupTable::RangeCheck15, vec![x.clone()]);
+    }
+
+    // Carry limbs need to be in particular ranges.
+    for (i, x) in carry_limbs_small.iter().enumerate() {
+        if i % 6 == 5 {
+            // This should be a different range check depending on which big limb we're processing?
+            // So instead of one type of lookup we will have 5 different ones?
+            env.lookup(LookupTable::RangeCheck9Abs, vec![x.clone()]); // 4 + 5 ?
+        } else {
+            // TODO add actual lookup
+            env.lookup(LookupTable::RangeCheck14Abs, vec![x.clone()]);
+            //env.range_check_abs15(x);
+            // assert!(x < F::from(1u64 << 15) || x >= F::zero() - F::from(1u64 << 15));
+        }
+    }
+
+    // FIXME: Some of these /have/ to be in [0, F), and carries have very specific ranges!
+
+    let chal_converted_limbs_large =
+        combine_small_to_large::<_, _, Env>(chal_converted_limbs_small.clone());
+    let coeff_input_limbs_large =
+        combine_small_to_large::<_, _, Env>(coeff_input_limbs_small.clone());
+    let coeff_result_limbs_large =
+        combine_small_to_large::<_, _, Env>(coeff_result_limbs_small.clone());
+    let quotient_limbs_large_abs_expected =
+        combine_small_to_large::<_, _, Env>(quotient_limbs_small.clone());
+    for j in 0..N_LIMBS_LARGE {
+        env.assert_zero(
+            quotient_limbs_large[j].clone()
+                - quotient_sign.clone() * quotient_limbs_large_abs_expected[j].clone(),
+        );
+    }
+    let carry_limbs_large: [_; 2 * N_LIMBS_LARGE - 2] =
+        combine_carry::<_, _, Env>(carry_limbs_small.clone());
+
+    let limb_size_large = constant_u128(1u128 << LIMB_BITSIZE_LARGE);
+    let add_extra_carries = |i: usize,
+                             carry_limbs_large: &[<Env as ColAccessCap<F, SerializationColumn>>::Variable;
+                                 2 * N_LIMBS_LARGE - 2]|
+     -> <Env as ColAccessCap<F, SerializationColumn>>::Variable {
+        if i == 0 {
+            -(carry_limbs_large[0].clone() * limb_size_large.clone())
+        } else if i < 2 * N_LIMBS_LARGE - 2 {
+            carry_limbs_large[i - 1].clone()
+                - carry_limbs_large[i].clone() * limb_size_large.clone()
+        } else if i == 2 * N_LIMBS_LARGE - 2 {
+            carry_limbs_large[i - 1].clone()
+        } else {
+            panic!("add_extra_carries: the index {i:?} is too high")
+        }
+    };
+
+    // Equation 1
+    // General form, writing the result limbs as r_i, the quotient limbs as
+    // q_i and the carry limbs as carry_i:
+    // \sum_{k,j | k+j = i} xi_j cprev_k - r_i - \sum_{k,j | k+j = i} q_k f_j
+    //   - carry_i * 2^B + carry_{i-1} = 0
+    #[allow(clippy::needless_range_loop)]
+    for i in 0..2 * N_LIMBS_LARGE - 1 {
+        let mut constraint = fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            chal_converted_limbs_large[j].clone() * coeff_input_limbs_large[k].clone()
+        });
+        if i < N_LIMBS_LARGE {
+            constraint = constraint - coeff_result_limbs_large[i].clone();
+        }
+        constraint = constraint
+            - fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+                quotient_limbs_large[j].clone() * ffield_modulus_limbs_large[k].clone()
+            });
+        constraint = constraint + add_extra_carries(i, &carry_limbs_large);
+
+        env.assert_zero(constraint);
+    }
+}
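Spelled out in the editor's notation (not from the diff): with $\chi$ the converted challenge, $c$ the input coefficient, $r$ the result, $q$ the signed quotient, $f$ the foreign modulus, $\kappa$ the carries, $B = 75$ the large-limb bitsize and $N = 4$ the number of large limbs, Equation 1 states, for $0 \le i \le 2N - 2$:

$$\sum_{j + k = i} \chi_j\, c_k \;-\; r_i \;-\; \sum_{j + k = i} q_j\, f_k \;-\; \kappa_i\, 2^{B} \;+\; \kappa_{i - 1} \;=\; 0$$

with the conventions $r_i = 0$ for $i \ge N$ and $\kappa_{-1} = \kappa_{2N - 2} = 0$. Multiplying the $i$-th equation by $2^{iB}$ and summing telescopes the carry terms, leaving exactly $\chi \cdot c - r - q \cdot f = 0$ over the integers, i.e. $r \equiv \chi \cdot c \pmod{f}$.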
+
+/// Multiplication sub-circuit of the serialization/bootstrap
+/// procedure. Takes challenge x_{log i} and coefficient c_prev_i as input,
+/// returns the next coefficient c_i.
+pub fn multiplication_circuit<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColWriteCap<F, SerializationColumn> + LookupCap<F, SerializationColumn, LookupTable<Ff>>,
+>(
+    env: &mut Env,
+    chal: Ff,
+    coeff_input: Ff,
+    write_chal_converted: bool,
+) -> Ff {
+    let coeff_result = chal * coeff_input;
+
+    let two_bi: BigInt = TryFrom::try_from(2).unwrap();
+
+    let large_limb_size: F = From::from(1u128 << LIMB_BITSIZE_LARGE);
+
+    // Foreign field modulus
+    let f_bui: BigUint = TryFrom::try_from(Ff::MODULUS).unwrap();
+    let f_bi: BigInt = f_bui.to_bigint().unwrap();
+
+    // Native field modulus (prime)
+    let n_bui: BigUint = TryFrom::try_from(F::MODULUS).unwrap();
+    let n_bi: BigInt = n_bui.to_bigint().unwrap();
+    let n_half_bi = &n_bi / &two_bi;
+
+    let chal_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&chal);
+    let chal_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&chal);
+    let coeff_input_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&coeff_input);
+    let coeff_result_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&coeff_result);
+    let ff_modulus_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(f_bui.clone());
+
+    let coeff_input_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&coeff_input);
+    let coeff_result_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&coeff_result);
+
+    // No generics for closures
+    let write_array_small =
+        |env: &mut Env,
+         input: [F; N_LIMBS_SMALL],
+         f_column: &dyn Fn(usize) -> SerializationColumn| {
+            input.iter().enumerate().for_each(|(i, var)| {
+                env.write_column(f_column(i), &Env::constant(*var));
+            })
+        };
+
+    let write_array_large =
+        |env: &mut Env,
+         input: [F; N_LIMBS_LARGE],
+         f_column: &dyn Fn(usize) -> SerializationColumn| {
+            input.iter().enumerate().for_each(|(i, var)| {
+                env.write_column(f_column(i), &Env::constant(*var));
+            })
+        };
+
+    if write_chal_converted {
+        write_array_small(env, chal_limbs_small, &|i| {
+            SerializationColumn::ChalConverted(i)
+        });
+    }
+    write_array_small(env, coeff_input_limbs_small, &|i| {
+        SerializationColumn::CoeffInput(i)
+    });
+    write_array_small(env, coeff_result_limbs_small, &|i| {
+        SerializationColumn::CoeffResult(i)
+    });
+    write_array_large(env, ff_modulus_limbs_large, &|i| {
+        SerializationColumn::FFieldModulus(i)
+    });
+
+    let chal_bi: BigInt = FieldHelpers::to_bigint_positive(&chal);
+    let coeff_input_bi: BigInt = FieldHelpers::to_bigint_positive(&coeff_input);
+    let coeff_result_bi: BigInt = FieldHelpers::to_bigint_positive(&coeff_result);
+
+    let (quotient_bi, r_bi) = (&chal_bi * coeff_input_bi - coeff_result_bi).div_rem(&f_bi);
+    assert!(r_bi.is_zero());
+    let (quotient_bi, quotient_sign): (BigInt, F) = if quotient_bi.is_negative() {
+        (-quotient_bi, -F::one())
+    } else {
+        (quotient_bi, F::one())
+    };
+
+    // Written into the columns
+    let quotient_limbs_small: [F; N_LIMBS_SMALL] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(
+            quotient_bi.to_biguint().unwrap(),
+        );
+
+    // Used for witness computation
+    let quotient_limbs_large: [F; N_LIMBS_LARGE] =
+        limb_decompose_biguint::<F, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(
+            quotient_bi.to_biguint().unwrap(),
+        )
+        .into_iter()
+        .map(|v| v * quotient_sign)
+        .collect::<Vec<F>>()
+        .try_into()
+        .unwrap();
+
+    write_array_small(env, quotient_limbs_small, &|i| {
+        SerializationColumn::QuotientSmall(i)
+    });
+    write_array_large(env, quotient_limbs_large, &|i| {
+        SerializationColumn::QuotientLarge(i)
+    });
+
+    write_column_const(env, SerializationColumn::QuotientSign, &quotient_sign);
+
+    let mut carry: F = From::from(0u64);
+
+    #[allow(clippy::needless_range_loop)]
+    for i in 0..N_LIMBS_LARGE * 2 - 1 {
+        let compute_carry = |res: F| -> F {
+            // TODO enforce this as an integer division
+            let mut res_bi = res.to_bigint_positive();
+            if res_bi > n_half_bi {
+                res_bi -= &n_bi;
+            }
+            let (div, rem) = res_bi.div_rem(&large_limb_size.to_bigint_positive());
+            assert!(
+                rem.is_zero(),
+                "Cannot compute carry for step {i:?}: div {div:?}, rem {rem:?}"
+            );
+            let carry_f: BigUint = bigint_to_biguint_f(div, &n_bi);
+            F::from_biguint(&carry_f).unwrap()
+        };
+
+        let assign_carry = |env: &mut Env, newcarry: F, carryvar: &mut F| {
+            // The last carry should be zero, otherwise we record it
+            if i < N_LIMBS_LARGE * 2 - 2 {
+                // Carries will often not fit into 5 limbs, but they /should/ fit in 6 limbs I think.
+                let newcarry_sign = if newcarry.to_bigint_positive() > n_half_bi {
+                    F::zero() - F::one()
+                } else {
+                    F::one()
+                };
+                let newcarry_abs_bui = (newcarry * newcarry_sign).to_biguint();
+                // Our big carries are at most 79 bits, so we need 6 small limbs per each.
+                // However we split them into 14-bit chunks -- each chunk is signed, so in the end
+                // the layout is [14bitabs,14bitabs,14bitabs,14bitabs,14bitabs,9bitabs]
+                // altogether giving a 79-bit number (signed).
+                let newcarry_limbs: [F; 6] =
+                    limb_decompose_biguint::<F, { LIMB_BITSIZE_SMALL - 1 }, 6>(
+                        newcarry_abs_bui.clone(),
+                    );
+
+                for (j, limb) in newcarry_limbs.iter().enumerate() {
+                    env.write_column(
+                        SerializationColumn::Carry(6 * i + j),
+                        &Env::constant(newcarry_sign * limb),
+                    );
+                }
+
+                *carryvar = newcarry;
+            } else {
+                // should this be in-circuit?
+                assert!(newcarry.is_zero(), "Last carry is non-zero");
+            }
+        };
+
+        let mut res = fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            chal_limbs_large[j] * coeff_input_limbs_large[k]
+        });
+        if i < N_LIMBS_LARGE {
+            res -= &coeff_result_limbs_large[i];
+        }
+        res -= fold_choice2(N_LIMBS_LARGE, i, |j, k| {
+            quotient_limbs_large[j] * ff_modulus_limbs_large[k]
+        });
+        res += carry;
+        let newcarry = compute_carry(res);
+        assign_carry(env, newcarry, &mut carry);
+    }
+
+    constrain_multiplication::<F, Ff, Env>(env);
+    coeff_result
+}
+
+/// Full serialization circuit.
+pub fn serialization_circuit<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColWriteCap<F, SerializationColumn>
+        + LookupCap<F, SerializationColumn, LookupTable<Ff>>
+        + HybridCopyCap<F, SerializationColumn>
+        + HybridSerHelpers<F, SerializationColumn, LookupTable<Ff>>
+        + MultiRowReadCap<F, SerializationColumn>,
+>(
+    env: &mut Env,
+    input_chal: Ff,
+    field_elements: Vec<[F; 3]>,
+    domain_size: usize,
+) {
+    // A vector containing the results of the multiplications, one per row
+    let mut prev_rows: Vec<Ff> = vec![];
+
+    for (i, limbs) in field_elements.iter().enumerate() {
+        // Witness
+
+        let coeff_input = if i == 0 {
+            Ff::zero()
+        } else {
+            prev_rows[i - (1 << (i.ilog2()))]
+        };
+
+        deserialize_field_element(env, limbs.map(Into::into));
+        let mul_result = multiplication_circuit(env, input_chal, coeff_input, false);
+
+        prev_rows.push(mul_result);
+
+        // Don't reset on the last iteration.
+        if i < domain_size {
+            env.next_row()
+        }
+    }
+}
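The `prev_rows` indexing above and the `PreviousCoeffRow` fixed selector both follow the map i ↦ i − 2^⌊log₂ i⌋. This editorial sketch (not part of the diff) reproduces the second selector column of the table below:

```rust
// Sketch: reproduce the second fixed selector (see the table below), which
// maps row i to i - 2^(floor(log2 i)), the row holding the previous coefficient.
fn main() {
    for i in 0u64..16 {
        let sel2 = if i < 2 { 0 } else { i - (1 << i.ilog2()) };
        println!("{i:04b} -> {sel2}");
    }
}
```

+
+/// Builds fixed selectors for serialization circuit.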
+/// +/// +/// | i | sel1 | sel2 | +/// |------|------|------| +/// | 0000 | 0 | 0 | +/// | 0001 | 1 | 0 | +/// | 0010 | 2 | 0 | +/// | 0011 | 3 | 1 | +/// | 0100 | 4 | 0 | +/// | 0101 | 5 | 1 | +/// | 0110 | 6 | 2 | +/// | 0111 | 7 | 3 | +/// | 1000 | 8 | 0 | +/// | 1001 | 9 | 1 | +/// | 1010 | 10 | 2 | +/// | 1011 | 11 | 3 | +/// | 1100 | 12 | 4 | +/// | 1101 | 13 | 5 | +/// | 1110 | 14 | 6 | +/// | 1111 | 15 | 7 | +pub fn build_selectors(domain_size: usize) -> [Vec; N_FSEL_SER] { + let sel1 = (0..domain_size).map(|i| F::from(i as u64)).collect(); + let sel2 = (0..domain_size) + .map(|i| { + if i < 2 { + F::zero() + } else { + F::from((i - (1 << (i.ilog2()))) as u64) + } + }) + .collect(); + + [sel1, sel2] +} + +#[cfg(test)] +mod tests { + use crate::{ + circuit_design::{ColAccessCap, WitnessBuilderEnv}, + columns::ColumnIndexer, + serialization::{ + column::SerializationColumn, + interpreter::{ + build_selectors, deserialize_field_element, limb_decompose_ff, + multiplication_circuit, serialization_circuit, + }, + lookups::LookupTable, + N_INTERMEDIATE_LIMBS, + }, + Ff1, LIMB_BITSIZE, N_LIMBS, + }; + use ark_ff::{BigInteger, One, PrimeField, UniformRand, Zero}; + use num_bigint::BigUint; + use o1_utils::{tests::make_test_rng, FieldHelpers}; + use rand::{CryptoRng, Rng, RngCore}; + use std::str::FromStr; + + // In this test module we assume native = foreign = scalar field of Vesta. + type Fp = Ff1; + + type SerializationWitnessBuilderEnv = WitnessBuilderEnv< + Fp, + SerializationColumn, + { ::N_COL }, + { ::N_COL }, + 0, + 0, + LookupTable, + >; + + fn test_decomposition_generic(x: Fp) { + let bits = x.to_bits(); + let limb0: u128 = { + let limb0_le_bits: &[bool] = &bits.clone().into_iter().take(88).collect::>(); + let limb0 = Fp::from_bits(limb0_le_bits).unwrap(); + limb0.to_biguint().try_into().unwrap() + }; + let limb1: u128 = { + let limb0_le_bits: &[bool] = &bits + .clone() + .into_iter() + .skip(88) + .take(88) + .collect::>(); + let limb0 = Fp::from_bits(limb0_le_bits).unwrap(); + limb0.to_biguint().try_into().unwrap() + }; + let limb2: u128 = { + let limb0_le_bits: &[bool] = &bits + .clone() + .into_iter() + .skip(2 * 88) + .take(79) + .collect::>(); + let limb0 = Fp::from_bits(limb0_le_bits).unwrap(); + limb0.to_biguint().try_into().unwrap() + }; + let mut dummy_env = SerializationWitnessBuilderEnv::create(); + deserialize_field_element( + &mut dummy_env, + [ + BigUint::from(limb0), + BigUint::from(limb1), + BigUint::from(limb2), + ], + ); + + // Check limb are copied into the environment + let limbs_to_assert = [limb0, limb1, limb2]; + for (i, limb) in limbs_to_assert.iter().enumerate() { + assert_eq!( + Fp::from(*limb), + dummy_env.read_column(SerializationColumn::ChalKimchi(i)) + ); + } + + // Check intermediate limbs + { + let bits = Fp::from(limb2).to_bits(); + for j in 0..N_INTERMEDIATE_LIMBS { + let le_bits: &[bool] = &bits + .clone() + .into_iter() + .skip(j * 4) + .take(4) + .collect::>(); + let t = Fp::from_bits(le_bits).unwrap(); + let intermediate_v = + dummy_env.read_column(SerializationColumn::ChalIntermediate(j)); + assert_eq!( + t, + intermediate_v, + "{}", + format_args!( + "Intermediate limb {j}. 
Exp value is {:?}, computed is {:?}", + t.to_biguint(), + intermediate_v.to_biguint() + ) + ) + } + } + + // Checking msm limbs + for i in 0..N_LIMBS { + let le_bits: &[bool] = &bits + .clone() + .into_iter() + .skip(i * LIMB_BITSIZE) + .take(LIMB_BITSIZE) + .collect::>(); + let t = Fp::from_bits(le_bits).unwrap(); + let converted_v = dummy_env.read_column(SerializationColumn::ChalConverted(i)); + assert_eq!( + t, + converted_v, + "{}", + format_args!( + "MSM limb {i}. Exp value is {:?}, computed is {:?}", + t.to_biguint(), + converted_v.to_biguint() + ) + ) + } + } + + #[test] + fn test_decomposition_zero() { + test_decomposition_generic(Fp::zero()); + } + + #[test] + fn test_decomposition_one() { + test_decomposition_generic(Fp::one()); + } + + #[test] + fn test_decomposition_random_first_limb_only() { + let mut rng = make_test_rng(None); + let x = rng.gen_range(0..2u128.pow(88) - 1); + test_decomposition_generic(Fp::from(x)); + } + + #[test] + fn test_decomposition_second_limb_only() { + test_decomposition_generic(Fp::from(2u128.pow(88))); + test_decomposition_generic(Fp::from(2u128.pow(88) + 1)); + test_decomposition_generic(Fp::from(2u128.pow(88) + 2)); + test_decomposition_generic(Fp::from(2u128.pow(88) + 16)); + test_decomposition_generic(Fp::from(2u128.pow(88) + 23234)); + } + + #[test] + fn test_decomposition_random_second_limb_only() { + let mut rng = make_test_rng(None); + let x = rng.gen_range(0..2u128.pow(88) - 1); + test_decomposition_generic(Fp::from(2u128.pow(88) + x)); + } + + #[test] + fn test_decomposition_random() { + let mut rng = make_test_rng(None); + test_decomposition_generic(Fp::rand(&mut rng)); + } + + #[test] + fn test_decomposition_order_minus_one() { + let x = BigUint::from_bytes_be(&::MODULUS.to_bytes_be()) + - BigUint::from_str("1").unwrap(); + + test_decomposition_generic(Fp::from(x)); + } + + fn build_serialization_mul_circuit( + rng: &mut RNG, + domain_size: usize, + ) -> SerializationWitnessBuilderEnv { + let mut witness_env = WitnessBuilderEnv::create(); + + // To support less rows than domain_size we need to have selectors. + //let row_num = rng.gen_range(0..domain_size); + + let fixed_selectors = build_selectors(domain_size); + witness_env.set_fixed_selectors(fixed_selectors.to_vec()); + + for row_i in 0..domain_size { + let input_chal: Ff1 = ::rand(rng); + let coeff_input: Ff1 = ::rand(rng); + multiplication_circuit(&mut witness_env, input_chal, coeff_input, true); + + if row_i < domain_size - 1 { + witness_env.next_row(); + } + } + + witness_env + } + + #[test] + /// Builds the FF addition circuit with random values. The witness + /// environment enforces the constraints internally, so it is + /// enough to just build the circuit to ensure it is satisfied. + pub fn test_serialization_mul_circuit() { + let mut rng = o1_utils::tests::make_test_rng(None); + build_serialization_mul_circuit(&mut rng, 1 << 4); + } + + #[test] + /// Builds the whole serialization circuit and checks it internally. + pub fn test_serialization_full_circuit() { + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_size = 1 << 15; + + let mut witness_env: SerializationWitnessBuilderEnv = WitnessBuilderEnv::create(); + + let fixed_selectors = build_selectors(domain_size); + witness_env.set_fixed_selectors(fixed_selectors.to_vec()); + + let mut field_elements = vec![]; + + // FIXME: we do use always the same values here, because we have a + // constant check (X - c), different for each row. 
And there is no
+        // constant support/public input yet in the quotient polynomial.
+        let input_chal: Ff1 = <Ff1 as UniformRand>::rand(&mut rng);
+        let [input1, input2, input3]: [Fp; 3] =
+            limb_decompose_ff::<Fp, Ff1, LIMB_BITSIZE, N_LIMBS>(&input_chal);
+        for _ in 0..domain_size {
+            field_elements.push([input1, input2, input3])
+        }
+
+        serialization_circuit(
+            &mut witness_env,
+            input_chal,
+            field_elements.clone(),
+            domain_size,
+        );
+    }
+}
diff --git a/msm/src/serialization/lookups.rs b/msm/src/serialization/lookups.rs
new file mode 100644
index 0000000000..7dc2592f22
--- /dev/null
+++ b/msm/src/serialization/lookups.rs
@@ -0,0 +1,212 @@
+use crate::{logup::LookupTableID, Logup, LIMB_BITSIZE, N_LIMBS};
+use ark_ff::PrimeField;
+use num_bigint::BigUint;
+use o1_utils::FieldHelpers;
+use std::marker::PhantomData;
+use strum_macros::EnumIter;
+
+/// Enumeration of concrete lookup tables used in the serialization circuit.
+#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd, EnumIter)]
+pub enum LookupTable<Ff> {
+    /// x ∈ [0, 2^15)
+    RangeCheck15,
+    /// x ∈ [0, 2^4)
+    RangeCheck4,
+    /// x ∈ [-2^14, 2^14-1]
+    RangeCheck14Abs,
+    /// x ∈ [-2^9, 2^9-1]
+    RangeCheck9Abs,
+    /// x ∈ [0, ff_highest] where ff_highest is the highest 15-bit
+    /// limb of the modulus of the foreign field `Ff`.
+    RangeCheckFfHighest(PhantomData<Ff>),
+    /// Communication bus for the multiplication circuit.
+    MultiplicationBus,
+}
+
+impl<Ff: PrimeField> LookupTableID for LookupTable<Ff> {
+    fn to_u32(&self) -> u32 {
+        match self {
+            Self::RangeCheck15 => 1,
+            Self::RangeCheck4 => 2,
+            Self::RangeCheck14Abs => 3,
+            Self::RangeCheck9Abs => 4,
+            Self::RangeCheckFfHighest(_) => 5,
+            Self::MultiplicationBus => 6,
+        }
+    }
+
+    fn from_u32(value: u32) -> Self {
+        match value {
+            1 => Self::RangeCheck15,
+            2 => Self::RangeCheck4,
+            3 => Self::RangeCheck14Abs,
+            4 => Self::RangeCheck9Abs,
+            5 => Self::RangeCheckFfHighest(PhantomData),
+            6 => Self::MultiplicationBus,
+            _ => panic!("Invalid lookup table id"),
+        }
+    }
+
+    fn is_fixed(&self) -> bool {
+        match self {
+            Self::RangeCheck15 => true,
+            Self::RangeCheck4 => true,
+            Self::RangeCheck14Abs => true,
+            Self::RangeCheck9Abs => true,
+            Self::RangeCheckFfHighest(_) => true,
+            Self::MultiplicationBus => false,
+        }
+    }
+
+    fn runtime_create_column(&self) -> bool {
+        match self {
+            Self::MultiplicationBus => false,
+            _ => panic!("runtime_create_column was called on a non-runtime table"),
+        }
+    }
+
+    fn length(&self) -> usize {
+        match self {
+            Self::RangeCheck15 => 1 << 15,
+            Self::RangeCheck4 => 1 << 4,
+            Self::RangeCheck14Abs => 1 << 15,
+            Self::RangeCheck9Abs => 1 << 10,
+            Self::RangeCheckFfHighest(_) => TryFrom::try_from(
+                crate::serialization::interpreter::ff_modulus_highest_limb::<Ff>(),
+            )
+            .unwrap(),
+            Self::MultiplicationBus => 1 << 15,
+        }
+    }
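For the signed tables, the value-to-index map wraps negatives around the table size; an editorial sketch (not part of the diff) of the `RangeCheck14Abs` case with plain integers in place of field elements:

```rust
// Sketch of the signed-table indexing used by ix_by_value for RangeCheck14Abs:
// a value x in [-2^14, 2^14) is stored at index x (if x >= 0) or x + 2^15
// (if x < 0), i.e. at x mod 2^15 -- matching the layout produced by entries().
fn index_14_abs(x: i64) -> u64 {
    assert!(-(1 << 14) <= x && x < (1 << 14));
    if x >= 0 { x as u64 } else { (x + (1 << 15)) as u64 }
}

fn main() {
    assert_eq!(index_14_abs(0), 0);
    assert_eq!(index_14_abs((1 << 14) - 1), (1 << 14) - 1);
    assert_eq!(index_14_abs(-1), (1 << 15) - 1);
    assert_eq!(index_14_abs(-(1 << 14)), 1 << 14);
}
```

+
+    /// Converts a value to its index in the fixed table.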
+    fn ix_by_value<F: PrimeField>(&self, value: &[F]) -> Option<usize> {
+        let value = value[0];
+        if self.is_fixed() {
+            assert!(self.is_member(value).unwrap());
+        }
+
+        match self {
+            Self::RangeCheck15 => Some(TryFrom::try_from(value.to_biguint()).unwrap()),
+            Self::RangeCheck4 => Some(TryFrom::try_from(value.to_biguint()).unwrap()),
+            Self::RangeCheck14Abs => {
+                if value < F::from(1u64 << 14) {
+                    Some(TryFrom::try_from(value.to_biguint()).unwrap())
+                } else {
+                    Some(
+                        TryFrom::try_from((value + F::from(2 * (1u64 << 14))).to_biguint())
+                            .unwrap(),
+                    )
+                }
+            }
+            Self::RangeCheck9Abs => {
+                if value < F::from(1u64 << 9) {
+                    Some(TryFrom::try_from(value.to_biguint()).unwrap())
+                } else {
+                    Some(
+                        TryFrom::try_from((value + F::from(2 * (1u64 << 9))).to_biguint()).unwrap(),
+                    )
+                }
+            }
+            Self::RangeCheckFfHighest(_) => Some(TryFrom::try_from(value.to_biguint()).unwrap()),
+            Self::MultiplicationBus => None,
+        }
+    }
+
+    fn all_variants() -> Vec<Self> {
+        vec![
+            Self::RangeCheck15,
+            Self::RangeCheck4,
+            Self::RangeCheck14Abs,
+            Self::RangeCheck9Abs,
+            Self::RangeCheckFfHighest(PhantomData),
+            Self::MultiplicationBus,
+        ]
+    }
+}
+
+impl<Ff: PrimeField> LookupTable<Ff> {
+    fn entries_ff_highest<F: PrimeField>(domain_d1_size: u64) -> Vec<F> {
+        let top_modulus_f =
+            F::from_biguint(&crate::serialization::interpreter::ff_modulus_highest_limb::<Ff>())
+                .unwrap();
+        (0..domain_d1_size)
+            .map(|i| {
+                if F::from(i) < top_modulus_f {
+                    F::from(i)
+                } else {
+                    F::zero()
+                }
+            })
+            .collect()
+    }
+
+    /// Provides the full list of entries for the given table.
+    pub fn entries<F: PrimeField>(&self, domain_d1_size: u64) -> Option<Vec<F>> {
+        assert!(domain_d1_size >= (1 << 15));
+        match self {
+            Self::RangeCheck15 => Some((0..domain_d1_size).map(|i| F::from(i)).collect()),
+            Self::RangeCheck4 => Some(
+                (0..domain_d1_size)
+                    .map(|i| if i < (1 << 4) { F::from(i) } else { F::zero() })
+                    .collect(),
+            ),
+            Self::RangeCheck14Abs => Some(
+                (0..domain_d1_size)
+                    .map(|i| {
+                        if i < (1 << 14) {
+                            // [0, 1, 2, ..., (1 << 14) - 1]
+                            F::from(i)
+                        } else if i < 2 * (1 << 14) {
+                            // [-(1 << 14), ..., -2, -1]
+                            F::from(i) - F::from(2u64 * (1 << 14))
+                        } else {
+                            F::zero()
+                        }
+                    })
+                    .collect(),
+            ),
+
+            Self::RangeCheck9Abs => Some(
+                (0..domain_d1_size)
+                    .map(|i| {
+                        if i < (1 << 9) {
+                            // [0, 1, 2, ..., (1 << 9) - 1]
+                            F::from(i)
+                        } else if i < 2 * (1 << 9) {
+                            // [-(1 << 9), ..., -2, -1]
+                            F::from(i) - F::from(2u64 * (1 << 9))
+                        } else {
+                            F::zero()
+                        }
+                    })
+                    .collect(),
+            ),
+            Self::RangeCheckFfHighest(_) => Some(Self::entries_ff_highest::<F>(domain_d1_size)),
+            _ => panic!("not possible"),
+        }
+    }
+
+    /// Checks if a value is in a given table.
+    pub fn is_member<F: PrimeField>(&self, value: F) -> Option<bool> {
+        match self {
+            Self::RangeCheck15 => Some(value.to_biguint() < BigUint::from(2u128.pow(15))),
+            Self::RangeCheck4 => Some(value.to_biguint() < BigUint::from(2u128.pow(4))),
+            Self::RangeCheck14Abs => {
+                Some(value < F::from(1u64 << 14) || value >= F::zero() - F::from(1u64 << 14))
+            }
+            Self::RangeCheck9Abs => {
+                Some(value < F::from(1u64 << 9) || value >= F::zero() - F::from(1u64 << 9))
+            }
+            Self::RangeCheckFfHighest(_) => {
+                let f_bui: BigUint = TryFrom::try_from(Ff::MODULUS).unwrap();
+                let top_modulus_f: F =
+                    F::from_biguint(&(f_bui >> ((N_LIMBS - 1) * LIMB_BITSIZE))).unwrap();
+                Some(value < top_modulus_f)
+            }
+            Self::MultiplicationBus => None,
+        }
+    }
+}
+
+pub type Lookup<F, Ff> = Logup<F, LookupTable<Ff>>;
diff --git a/msm/src/serialization/mod.rs b/msm/src/serialization/mod.rs
new file mode 100644
index 0000000000..2cdabdc9f0
--- /dev/null
+++ b/msm/src/serialization/mod.rs
@@ -0,0 +1,125 @@
+pub mod column;
+pub mod interpreter;
+pub mod lookups;
+
+/// The number of intermediate limbs of 4 bits required for the circuit
+pub const N_INTERMEDIATE_LIMBS: usize = 20;
+
+#[cfg(test)]
+mod tests {
+    use ark_ff::UniformRand;
+    use std::collections::BTreeMap;
+
+    use crate::{
+        circuit_design::{constraints::ConstraintBuilderEnv, witness::WitnessBuilderEnv},
+        columns::ColumnIndexer,
+        logup::LookupTableID,
+        serialization::{
+            column::{SerializationColumn, N_COL_SER, N_FSEL_SER},
+            interpreter::{
+                build_selectors, constrain_multiplication, deserialize_field_element,
+                limb_decompose_ff, serialization_circuit,
+            },
+            lookups::LookupTable,
+        },
+        Ff1, Fp, LIMB_BITSIZE, N_LIMBS,
+    };
+
+    type SerializationWitnessBuilderEnv = WitnessBuilderEnv<
+        Fp,
+        SerializationColumn,
+        { <SerializationColumn as ColumnIndexer>::N_COL - N_FSEL_SER },
+        { <SerializationColumn as ColumnIndexer>::N_COL - N_FSEL_SER },
+        0,
+        N_FSEL_SER,
+        LookupTable<Ff1>,
+    >;
+
+    #[test]
+    fn heavy_test_completeness() {
+        let mut rng = o1_utils::tests::make_test_rng(None);
+        // Must be at least 1 << 15 to support rangecheck15
+        let domain_size: usize = 1 << 15;
+
+        let mut witness_env = SerializationWitnessBuilderEnv::create();
+
+        let fixed_selectors = build_selectors(domain_size);
+        witness_env.set_fixed_selectors(fixed_selectors.to_vec());
+
+        let mut field_elements = vec![];
+
+        // FIXME: we always use the same values here, because we have a
+        // constant check (X - c), different for each row. And there is no
+        // constant support/public input yet in the quotient polynomial.
+        let input_chal: Ff1 = <Ff1 as UniformRand>::rand(&mut rng);
+        let [input1, input2, input3]: [Fp; 3] =
+            limb_decompose_ff::<Fp, Ff1, LIMB_BITSIZE, N_LIMBS>(&input_chal);
+        for _ in 0..domain_size {
+            field_elements.push([input1, input2, input3])
+        }
+
+        let constraints = {
+            let mut constraints_env = ConstraintBuilderEnv::<Fp, LookupTable<Ff1>>::create();
+            deserialize_field_element(&mut constraints_env, field_elements[0].map(Into::into));
+            constrain_multiplication(&mut constraints_env);
+
+            // Sanity checks.
+            assert!(constraints_env.lookup_reads[&LookupTable::RangeCheck15].len() == (3 * 17 - 1));
+            assert!(constraints_env.lookup_reads[&LookupTable::RangeCheck4].len() == 20);
+            assert!(constraints_env.lookup_reads[&LookupTable::RangeCheck9Abs].len() == 6);
+            assert!(
+                constraints_env.lookup_reads
+                    [&LookupTable::RangeCheckFfHighest(std::marker::PhantomData)]
+                    .len()
+                    == 1
+            );
+
+            constraints_env.get_constraints()
+        };
+
+        serialization_circuit(&mut witness_env, input_chal, field_elements, domain_size);
+
+        let runtime_tables: BTreeMap<_, Vec<Vec<Vec<Fp>>>> =
+            witness_env.get_runtime_tables(domain_size);
+
+        // TODO remove this clone
+        // Fixed tables can be generated inside lookup_tables_data. Runtime tables should be generated here.
+        let multiplication_bus: Vec<Vec<Vec<Fp>>> = runtime_tables
+            .get(&LookupTable::MultiplicationBus)
+            .unwrap()
+            .clone();
+
+        //assert!(multiplication_bus.len() == 2);
+
+        let mut lookup_tables_data: BTreeMap<LookupTable<Ff1>, Vec<Vec<Vec<Fp>>>> = BTreeMap::new();
+        for table_id in LookupTable::<Ff1>::all_variants().into_iter() {
+            if table_id.is_fixed() {
+                lookup_tables_data.insert(
+                    table_id,
+                    vec![table_id
+                        .entries(domain_size as u64)
+                        .unwrap()
+                        .into_iter()
+                        .map(|x| vec![x])
+                        .collect()],
+                );
+            }
+        }
+        lookup_tables_data.insert(LookupTable::MultiplicationBus, multiplication_bus);
+        let proof_inputs = witness_env.get_proof_inputs(domain_size, lookup_tables_data);
+
+        crate::test::test_completeness_generic::<
+            { N_COL_SER - N_FSEL_SER },
+            { N_COL_SER - N_FSEL_SER },
+            0,
+            N_FSEL_SER,
+            LookupTable<Ff1>,
+            _,
+        >(
+            constraints,
+            Box::new(fixed_selectors),
+            proof_inputs,
+            domain_size,
+            &mut rng,
+        );
+    }
+}
diff --git a/msm/src/test/generic.rs b/msm/src/test/generic.rs
new file mode 100644
index 0000000000..ce09e8ae32
--- /dev/null
+++ b/msm/src/test/generic.rs
@@ -0,0 +1,272 @@
+//! Generic test runners for the prover/verifier.
+use crate::{
+    expr::E, logup::LookupTableID, lookups::LookupTableIDs, proof::ProofInputs, prover::prove,
+    verifier::verify, witness::Witness, BaseSponge, Fp, OpeningProof, ScalarSponge, BN254,
+};
+use ark_ec::AffineRepr;
+use kimchi::circuits::domains::EvaluationDomains;
+use poly_commitment::kzg::PairingSRS;
+use rand::{CryptoRng, RngCore};
+
+/// No lookups, no selectors, only witness columns. `N_WIT == N_REL`.
+pub fn test_completeness_generic_only_relation<const N_WIT: usize, RNG>(
+    constraints: Vec<E<Fp>>,
+    evaluations: Witness<N_WIT, Vec<Fp>>,
+    domain_size: usize,
+    rng: &mut RNG,
+) where
+    RNG: RngCore + CryptoRng,
+{
+    test_completeness_generic_no_lookups::<N_WIT, N_WIT, 0, 0, RNG>(
+        constraints,
+        Box::new([]),
+        evaluations,
+        domain_size,
+        rng,
+    )
+}
+
+pub fn test_completeness_generic_no_lookups<
+    const N_WIT: usize,
+    const N_REL: usize,
+    const N_DSEL: usize,
+    const N_FSEL: usize,
+    RNG,
+>(
+    constraints: Vec<E<Fp>>,
+    fixed_selectors: Box<[Vec<Fp>; N_FSEL]>,
+    evaluations: Witness<N_WIT, Vec<Fp>>,
+    domain_size: usize,
+    rng: &mut RNG,
+) where
+    RNG: RngCore + CryptoRng,
+{
+    let proof_inputs = ProofInputs {
+        evaluations,
+        logups: Default::default(),
+    };
+    test_completeness_generic::<N_WIT, N_REL, N_DSEL, N_FSEL, LookupTableIDs, RNG>(
+        constraints,
+        fixed_selectors,
+        proof_inputs,
+        domain_size,
+        rng,
+    )
+}
+
+// Generic function to test different circuits with the generic prover/verifier.
+// It doesn't use the interpreter to build the witness and compute the constraints.
+pub fn test_completeness_generic< + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + RNG, +>( + constraints: Vec>, + fixed_selectors: Box<[Vec; N_FSEL]>, + proof_inputs: ProofInputs, + domain_size: usize, + rng: &mut RNG, +) where + RNG: RngCore + CryptoRng, +{ + let domain = EvaluationDomains::::create(domain_size).unwrap(); + + let srs: PairingSRS = crate::precomputed_srs::get_bn254_srs(domain); + + let proof = + prove::<_, OpeningProof, BaseSponge, ScalarSponge, _, N_WIT, N_REL, N_DSEL, N_FSEL, LT>( + domain, + &srs, + &constraints, + fixed_selectors.clone(), + proof_inputs.clone(), + rng, + ) + .unwrap(); + + { + // Checking the proof size. We should have: + // - N commitments for the columns + // - N evaluations for the columns + // - MAX_DEGREE - 1 commitments for the constraints (quotient polynomial) + // TODO: add lookups + + // We check there is always only one commitment chunk + (&proof.proof_comms.witness_comms) + .into_iter() + .for_each(|x| assert_eq!(x.len(), 1)); + // This equality is therefore trivial, but still doing it. + assert!( + (&proof.proof_comms.witness_comms) + .into_iter() + .fold(0, |acc, x| acc + x.len()) + == N_WIT + ); + // Checking that none of the commitments are zero + (&proof.proof_comms.witness_comms) + .into_iter() + .for_each(|v| v.chunks.iter().for_each(|x| assert!(!x.is_zero()))); + + // Checking the number of chunks of the quotient polynomial + let max_degree = { + if proof_inputs.logups.is_empty() { + constraints + .iter() + .map(|expr| expr.degree(1, 0)) + .max() + .unwrap_or(0) + } else { + 8 + } + }; + + if max_degree == 1 { + assert_eq!(proof.proof_comms.t_comm.len(), 1); + } else { + assert_eq!(proof.proof_comms.t_comm.len(), max_degree as usize - 1); + } + } + + let verifies = + verify::<_, OpeningProof, BaseSponge, ScalarSponge, N_WIT, N_REL, N_DSEL, N_FSEL, 0, LT>( + domain, + &srs, + &constraints, + fixed_selectors, + &proof, + Witness::zero_vec(domain_size), + ); + assert!(verifies) +} + +pub fn test_soundness_generic< + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + LT: LookupTableID, + RNG, +>( + constraints: Vec>, + fixed_selectors: Box<[Vec; N_FSEL]>, + proof_inputs: ProofInputs, + proof_inputs_prime: ProofInputs, + domain_size: usize, + rng: &mut RNG, +) where + RNG: RngCore + CryptoRng, +{ + let domain = EvaluationDomains::::create(domain_size).unwrap(); + + let srs: PairingSRS = crate::precomputed_srs::get_bn254_srs(domain); + + // generate the proof + let proof = + prove::<_, OpeningProof, BaseSponge, ScalarSponge, _, N_WIT, N_REL, N_DSEL, N_FSEL, LT>( + domain, + &srs, + &constraints, + fixed_selectors.clone(), + proof_inputs, + rng, + ) + .unwrap(); + + // generate another (prime) proof + let proof_prime = + prove::<_, OpeningProof, BaseSponge, ScalarSponge, _, N_WIT, N_REL, N_DSEL, N_FSEL, LT>( + domain, + &srs, + &constraints, + fixed_selectors.clone(), + proof_inputs_prime, + rng, + ) + .unwrap(); + + // Swap the opening proof. The verification should fail. 
+ { + let mut proof_clone = proof.clone(); + proof_clone.opening_proof = proof_prime.opening_proof; + let verifies = verify::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, + N_WIT, + N_REL, + N_DSEL, + N_FSEL, + 0, + LT, + >( + domain, + &srs, + &constraints, + fixed_selectors.clone(), + &proof_clone, + Witness::zero_vec(domain_size), + ); + assert!(!verifies, "Proof with a swapped opening must fail"); + } + + // Changing at least one commitment in the proof should fail the verification. + // TODO: improve me by swapping only one commitments. It should be + // easier when an index trait is implemented. + { + let mut proof_clone = proof.clone(); + proof_clone.proof_comms = proof_prime.proof_comms; + let verifies = verify::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, + N_WIT, + N_REL, + N_DSEL, + N_FSEL, + 0, + LT, + >( + domain, + &srs, + &constraints, + fixed_selectors.clone(), + &proof_clone, + Witness::zero_vec(domain_size), + ); + assert!(!verifies, "Proof with a swapped commitment must fail"); + } + + // Changing at least one evaluation at zeta in the proof should fail + // the verification. + // TODO: improve me by swapping only one evaluation at \zeta. It should be + // easier when an index trait is implemented. + { + let mut proof_clone = proof.clone(); + proof_clone.proof_evals.witness_evals = proof_prime.proof_evals.witness_evals; + let verifies = verify::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, + N_WIT, + N_REL, + N_DSEL, + N_FSEL, + 0, + LT, + >( + domain, + &srs, + &constraints, + fixed_selectors, + &proof_clone, + Witness::zero_vec(domain_size), + ); + assert!(!verifies, "Proof with a swapped witness eval must fail"); + } +} diff --git a/msm/src/test/logup.rs b/msm/src/test/logup.rs new file mode 100644 index 0000000000..0312594b41 --- /dev/null +++ b/msm/src/test/logup.rs @@ -0,0 +1,79 @@ +#[cfg(test)] +mod tests { + use crate::{ + lookups::{Lookup, LookupTableIDs}, + proof::ProofInputs, + prover::prove, + verifier::verify, + witness::Witness, + BaseSponge, Fp, OpeningProof, ScalarSponge, BN254, + }; + use ark_ff::UniformRand; + use kimchi::circuits::domains::EvaluationDomains; + use poly_commitment::{kzg::PairingSRS, SRS as _}; + + // Number of columns + const LOOKUP_TEST_N_COL: usize = 10; + + #[test] + #[ignore] + fn test_soundness_logup() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // We generate two different witness and two different proofs. 
+ let domain_size = 1 << 8; + let domain = EvaluationDomains::::create(domain_size).unwrap(); + + let srs: PairingSRS = { PairingSRS::create(domain.d1.size as usize) }; + srs.full_srs.get_lagrange_basis(domain.d1); + + let mut inputs = ProofInputs::random(domain); + let constraints = vec![]; + // Take one random f_i (FIXME: taking first one for now) + let test_table_id = *inputs.logups.first_key_value().unwrap().0; + let looked_up_values = inputs.logups.get_mut(&test_table_id).unwrap().f[0].clone(); + // We change a random looked up element (FIXME: first one for now) + let wrong_looked_up_value = Lookup { + table_id: looked_up_values[0].table_id, + numerator: looked_up_values[0].numerator, + value: vec![Fp::rand(&mut rng)], + }; + // Overwriting the first looked up value + inputs.logups.get_mut(&test_table_id).unwrap().f[0][0] = wrong_looked_up_value; + // generate the proof + let proof = prove::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, + _, + LOOKUP_TEST_N_COL, + LOOKUP_TEST_N_COL, + 0, + 0, + LookupTableIDs, + >(domain, &srs, &constraints, Box::new([]), inputs, &mut rng) + .unwrap(); + let verifies = verify::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, + LOOKUP_TEST_N_COL, + LOOKUP_TEST_N_COL, + 0, + 0, + 0, + LookupTableIDs, + >( + domain, + &srs, + &constraints, + Box::new([]), + &proof, + Witness::zero_vec(domain_size), + ); + // FIXME: At the moment, it does verify. It should not. We are missing constraints. + assert!(!verifies); + } +} diff --git a/msm/src/test/mod.rs b/msm/src/test/mod.rs new file mode 100644 index 0000000000..ddb85a9ac0 --- /dev/null +++ b/msm/src/test/mod.rs @@ -0,0 +1,7 @@ +pub mod generic; +pub mod logup; +pub mod proof_system; +pub mod test_circuit; + +// Reexport all the generic test functions. +pub use generic::*; diff --git a/msm/src/test/proof_system.rs b/msm/src/test/proof_system.rs new file mode 100644 index 0000000000..a9ea335cb1 --- /dev/null +++ b/msm/src/test/proof_system.rs @@ -0,0 +1,402 @@ +/// Tests for the proof system itself, targeting prover and verifier. 
+ +#[cfg(test)] +mod tests { + + use crate::{ + columns::Column, + expr::{ + E, {self}, + }, + test::test_completeness_generic_only_relation, + witness::Witness, + Fp, + }; + use ark_ff::{Field, One, UniformRand}; + use kimchi::circuits::expr::{ConstantExpr, ConstantTerm}; + + #[cfg(dead_code)] + fn test_soundness_generic( + constraints: Vec>, + evaluations: Witness>, + domain_size: usize, + rng: &mut RNG, + ) where + RNG: RngCore + CryptoRng, + { + let domain = EvaluationDomains::::create(domain_size).unwrap(); + + let mut srs: PairingSRS = { + // Trusted setup toxic waste + let x = Fp::rand(rng); + PairingSRS::create(x, domain.d1.size as usize) + }; + srs.get_lagrange_basis(domain.d1); + + let mut evaluations_prime = evaluations.clone(); + { + let i = rng.gen_range(0..N); + let j = rng.gen_range(0..domain_size); + evaluations_prime.cols[i][j] = Fp::rand(rng); + } + + let proof_inputs = ProofInputs { + evaluations: evaluations_prime, + logups: vec![], + }; + + let proof = + prove::<_, OpeningProof, BaseSponge, ScalarSponge, Column, _, N, LookupTableIDs>( + domain, + &srs, + &constraints, + proof_inputs, + rng, + ) + .unwrap(); + let verifies = verify::<_, OpeningProof, BaseSponge, ScalarSponge, N, 0, LookupTableIDs>( + domain, + &srs, + &constraints, + &proof, + Witness::zero_vec(domain_size), + ); + assert!(!verifies) + } + + // Test a constraint of degree one: X_{0} - X_{1} + #[test] + fn test_completeness_degree_one() { + let mut rng = o1_utils::tests::make_test_rng(None); + const N: usize = 2; + let domain_size = 1 << 8; + + let constraints = { + let x0 = expr::curr_cell::(Column::Relation(0)); + let x1 = expr::curr_cell::(Column::Relation(1)); + vec![x0.clone() - x1] + }; + + let random_x0s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x1 = random_x0s.clone(); + let witness: Witness> = Witness { + cols: Box::new([random_x0s, exp_x1]), + }; + + test_completeness_generic_only_relation::( + constraints.clone(), + witness.clone(), + domain_size, + &mut rng, + ); + // FIXME: we would want to allow the prover to make a proof, but the verification must fail. + // TODO: Refactorize code in prover to handle a degug or add an adversarial prover. + // test_soundness_generic(constraints, witness, domain_size, &mut rng); + } + + // Test a constraint of degree two: X_{0} * X_{0} - X_{1} - X_{2} + #[test] + fn test_completeness_degree_two() { + let mut rng = o1_utils::tests::make_test_rng(None); + const N: usize = 3; + let domain_size = 1 << 8; + + let constraints = { + let x0 = expr::curr_cell::(Column::Relation(0)); + let x1 = expr::curr_cell::(Column::Relation(1)); + let x2 = expr::curr_cell::(Column::Relation(2)); + vec![x0.clone() * x0.clone() - x1.clone() - x2.clone()] + }; + + let random_x0s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x1s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x2 = random_x0s + .iter() + .zip(random_x1s.iter()) + .map(|(x0, x1)| (*x0) * (*x0) - x1) + .collect::>(); + let witness: Witness> = Witness { + cols: Box::new([random_x0s, random_x1s, exp_x2]), + }; + + test_completeness_generic_only_relation::( + constraints.clone(), + witness.clone(), + domain_size, + &mut rng, + ); + // FIXME: we would want to allow the prover to make a proof, but the verification must fail. + // TODO: Refactorize code in prover to handle a degug or add an adversarial prover. 
+ // test_soundness_generic(constraints, witness, domain_size, &mut rng); + } + + // Test a constraint of degree three: + // X_{0} * X_{0} * X_{0} \ + // - 42 * X_{1} * X_{2} \ + // + X_{3} + #[test] + fn test_completeness_degree_three() { + let mut rng = o1_utils::tests::make_test_rng(None); + const N: usize = 4; + let domain_size = 1 << 8; + + let constraints = { + let x0 = expr::curr_cell::(Column::Relation(0)); + let x1 = expr::curr_cell::(Column::Relation(1)); + let x2 = expr::curr_cell::(Column::Relation(2)); + let x3 = expr::curr_cell::(Column::Relation(3)); + vec![ + x0.clone() * x0.clone() * x0.clone() - E::from(42) * x1.clone() * x2.clone() + + x3.clone(), + ] + }; + + let random_x0s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x1s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x2s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x3 = random_x0s + .iter() + .zip(random_x1s.iter()) + .zip(random_x2s.iter()) + .map(|((x0, x1), x2)| -((*x0) * (*x0) * (*x0) - Fp::from(42) * (*x1) * (*x2))) + .collect::>(); + let witness: Witness> = Witness { + cols: Box::new([random_x0s, random_x1s, random_x2s, exp_x3]), + }; + + test_completeness_generic_only_relation::( + constraints.clone(), + witness.clone(), + domain_size, + &mut rng, + ); + // FIXME: we would want to allow the prover to make a proof, but the verification must fail. + // TODO: Refactorize code in prover to handle a degug or add an adversarial prover. + // test_soundness_generic(constraints, witness, domain_size, &mut rng); + } + + #[test] + // X_{0} * (X_{1} * X_{2} * X_{3} + 1) + fn test_completeness_degree_four() { + let mut rng = o1_utils::tests::make_test_rng(None); + const N: usize = 4; + let domain_size = 1 << 8; + + let constraints = { + let x0 = expr::curr_cell::(Column::Relation(0)); + let x1 = expr::curr_cell::(Column::Relation(1)); + let x2 = expr::curr_cell::(Column::Relation(2)); + let x3 = expr::curr_cell::(Column::Relation(3)); + let one = ConstantExpr::from(ConstantTerm::Literal(Fp::one())); + vec![x0.clone() * (x1.clone() * x2.clone() * x3.clone() + E::constant(one))] + }; + + let random_x0s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x1s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x2s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x3 = random_x1s + .iter() + .zip(random_x2s.iter()) + .map(|(x1, x2)| -Fp::one() / (*x1 * *x2)) + .collect::>(); + let witness: Witness> = Witness { + cols: Box::new([random_x0s, random_x1s, random_x2s, exp_x3]), + }; + + test_completeness_generic_only_relation::( + constraints.clone(), + witness.clone(), + domain_size, + &mut rng, + ); + // FIXME: we would want to allow the prover to make a proof, but the verification must fail. + // TODO: Refactorize code in prover to handle a degug or add an adversarial prover. 
+ // test_soundness_generic(constraints, witness, domain_size, &mut rng); + } + + #[test] + // X_{0}^5 + X_{1} + fn test_completeness_degree_five() { + let mut rng = o1_utils::tests::make_test_rng(None); + const N: usize = 2; + let domain_size = 1 << 8; + + let constraints = { + let x0 = expr::curr_cell::(Column::Relation(0)); + let x1 = expr::curr_cell::(Column::Relation(1)); + vec![x0.clone() * x0.clone() * x0.clone() * x0.clone() * x0.clone() + x1.clone()] + }; + + let random_x0s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x1 = random_x0s + .iter() + .map(|x0| -*x0 * *x0 * *x0 * *x0 * *x0) + .collect::>(); + let witness: Witness> = Witness { + cols: Box::new([random_x0s, exp_x1]), + }; + + test_completeness_generic_only_relation::( + constraints.clone(), + witness.clone(), + domain_size, + &mut rng, + ); + // FIXME: we would want to allow the prover to make a proof, but the verification must fail. + // TODO: Refactorize code in prover to handle a degug or add an adversarial prover. + // test_soundness_generic(constraints, witness, domain_size, &mut rng); + } + + #[test] + // X_{0}^3 + X_{1} AND X_{2}^2 - 3 X_{3} + fn test_completeness_different_constraints_different_degrees() { + let mut rng = o1_utils::tests::make_test_rng(None); + const N: usize = 4; + let domain_size = 1 << 8; + + let constraints = { + let x0 = expr::curr_cell::(Column::Relation(0)); + let x1 = expr::curr_cell::(Column::Relation(1)); + let cst1 = x0.clone() * x0.clone() * x0.clone() + x1.clone(); + let x2 = expr::curr_cell::(Column::Relation(2)); + let x3 = expr::curr_cell::(Column::Relation(3)); + let three = ConstantExpr::from(ConstantTerm::Literal(Fp::from(3))); + let cst2 = x2.clone() * x2.clone() - E::constant(three) * x3.clone(); + vec![cst1, cst2] + }; + + let random_x0s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x1 = random_x0s + .iter() + .map(|x0| -*x0 * *x0 * *x0) + .collect::>(); + let random_x2s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x3: Vec = random_x2s + .iter() + .map(|x2| (Fp::one() / Fp::from(3)) * x2 * x2) + .collect::>(); + let witness: Witness> = Witness { + cols: Box::new([random_x0s, exp_x1, random_x2s, exp_x3]), + }; + + test_completeness_generic_only_relation::( + constraints.clone(), + witness.clone(), + domain_size, + &mut rng, + ); + // FIXME: we would want to allow the prover to make a proof, but the verification must fail. + // TODO: Refactorize code in prover to handle a degug or add an adversarial prover. 
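+        // Each constraint is satisfied independently here: x1 = -x0^3 clears
+        // X_{0}^3 + X_{1}, and x3 = x2^2 / 3 clears X_{2}^2 - 3 X_{3}. The
+        // proving pipeline then combines the two constraints (of degrees 3
+        // and 2) into a single expression using powers of the challenge
+        // alpha before computing the quotient polynomial.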
+ // test_soundness_generic(constraints, witness, domain_size, &mut rng); + } + + #[test] + // X_{0}^6 + X_{1}^4 - X_{2}^3 - 2 X_{3} + fn test_completeness_degree_six() { + let mut rng = o1_utils::tests::make_test_rng(None); + const N: usize = 4; + let domain_size = 1 << 8; + + let constraints = { + let x0 = expr::curr_cell::(Column::Relation(0)); + let x1 = expr::curr_cell::(Column::Relation(1)); + let x2 = expr::curr_cell::(Column::Relation(2)); + let x3 = expr::curr_cell::(Column::Relation(3)); + let x0_square = x0.clone() * x0.clone(); + let x1_square = x1.clone() * x1.clone(); + let x2_square = x2.clone() * x2.clone(); + let two = ConstantExpr::from(ConstantTerm::Literal(Fp::from(2))); + vec![ + x0_square.clone() * x0_square.clone() * x0_square.clone() + + x1_square.clone() * x1_square.clone() + - x2_square.clone() * x2.clone() + - E::constant(two) * x3.clone(), + ] + }; + + let random_x0s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x1s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x2s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x3 = random_x0s + .iter() + .zip(random_x1s.iter()) + .zip(random_x2s.iter()) + .map(|((x0, x1), x2)| { + let x0_square = x0.clone().square(); + let x1_square = x1.clone().square(); + let x2_square = x2.clone().square(); + let x0_six = x0_square * x0_square * x0_square; + let x1_four = x1_square * x1_square; + (Fp::one() / Fp::from(2)) * (x0_six + x1_four - x2_square * x2) + }) + .collect::>(); + let witness: Witness> = Witness { + cols: Box::new([random_x0s, random_x1s, random_x2s, exp_x3]), + }; + + test_completeness_generic_only_relation::( + constraints.clone(), + witness.clone(), + domain_size, + &mut rng, + ); + // FIXME: we would want to allow the prover to make a proof, but the verification must fail. + // TODO: Refactorize code in prover to handle a degug or add an adversarial prover. 
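+        // Solving X_{0}^6 + X_{1}^4 - X_{2}^3 - 2 X_{3} = 0 for the last
+        // column gives x3 = (x0^6 + x1^4 - x2^3) / 2, which is exactly what
+        // the closure above computes as
+        // (1/2) * (x0_six + x1_four - x2_square * x2).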
+ // test_soundness_generic(constraints, witness, domain_size, &mut rng); + } + + #[test] + // X_{0}^7 + X_{1}^4 - X_{2}^3 - 2 X_{3} + fn test_completeness_degree_seven() { + let mut rng = o1_utils::tests::make_test_rng(None); + const N: usize = 4; + let domain_size = 1 << 8; + + let constraints = { + let x0 = expr::curr_cell::(Column::Relation(0)); + let x1 = expr::curr_cell::(Column::Relation(1)); + let x2 = expr::curr_cell::(Column::Relation(2)); + let x3 = expr::curr_cell::(Column::Relation(3)); + let x0_square = x0.clone() * x0.clone(); + let x1_square = x1.clone() * x1.clone(); + let x2_square = x2.clone() * x2.clone(); + let two = ConstantExpr::from(ConstantTerm::Literal(Fp::from(2))); + vec![ + x0_square.clone() * x0_square.clone() * x0_square.clone() * x0.clone() + + x1_square.clone() * x1_square.clone() + - x2_square.clone() * x2.clone() + - E::constant(two) * x3.clone(), + ] + }; + + let random_x0s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x1s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let random_x2s: Vec = (0..domain_size).map(|_| Fp::rand(&mut rng)).collect(); + let exp_x3 = random_x0s + .iter() + .zip(random_x1s.iter()) + .zip(random_x2s.iter()) + .map(|((x0, x1), x2)| { + let x0_square = x0.clone().square(); + let x1_square = x1.clone().square(); + let x2_square = x2.clone().square(); + let x0_six = x0_square * x0_square * x0_square * x0; + let x1_four = x1_square * x1_square; + (Fp::one() / Fp::from(2)) * (x0_six + x1_four - x2_square * x2) + }) + .collect::>(); + let witness: Witness> = Witness { + cols: Box::new([random_x0s, random_x1s, random_x2s, exp_x3]), + }; + + test_completeness_generic_only_relation::( + constraints.clone(), + witness.clone(), + domain_size, + &mut rng, + ); + // FIXME: we would want to allow the prover to make a proof, but the verification must fail. + // TODO: Refactorize code in prover to handle a degug or add an adversarial prover. + // test_soundness_generic(constraints, witness, domain_size, &mut rng); + } +} diff --git a/msm/src/test/test_circuit/columns.rs b/msm/src/test/test_circuit/columns.rs new file mode 100644 index 0000000000..a417dba1e5 --- /dev/null +++ b/msm/src/test/test_circuit/columns.rs @@ -0,0 +1,41 @@ +use crate::{ + columns::{Column, ColumnIndexer}, + N_LIMBS, +}; + +/// Number of columns in the test circuits, including fixed selectors. +pub const N_COL_TEST: usize = 4 * N_LIMBS + N_FSEL_TEST; + +/// Number of fixed selectors in the test circuit. +pub const N_FSEL_TEST: usize = 3; + +/// Test column indexer. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum TestColumn {
+    A(usize),
+    B(usize),
+    C(usize),
+    D(usize),
+    FixedSel1,
+    FixedSel2,
+    FixedSel3,
+}
+
+impl ColumnIndexer for TestColumn {
+    const N_COL: usize = N_COL_TEST;
+    fn to_column(self) -> Column {
+        let to_column_inner = |offset, i| {
+            assert!(i < N_LIMBS);
+            Column::Relation(N_LIMBS * offset + i)
+        };
+        match self {
+            TestColumn::A(i) => to_column_inner(0, i),
+            TestColumn::B(i) => to_column_inner(1, i),
+            TestColumn::C(i) => to_column_inner(2, i),
+            TestColumn::D(i) => to_column_inner(3, i),
+            TestColumn::FixedSel1 => Column::FixedSelector(0),
+            TestColumn::FixedSel2 => Column::FixedSelector(1),
+            TestColumn::FixedSel3 => Column::FixedSelector(2),
+        }
+    }
+}
diff --git a/msm/src/test/test_circuit/interpreter.rs b/msm/src/test/test_circuit/interpreter.rs
new file mode 100644
index 0000000000..afb72942fd
--- /dev/null
+++ b/msm/src/test/test_circuit/interpreter.rs
@@ -0,0 +1,370 @@
+use crate::{
+    circuit_design::{ColAccessCap, ColWriteCap, DirectWitnessCap, LookupCap},
+    serialization::interpreter::limb_decompose_ff,
+    test::test_circuit::{
+        columns::{TestColumn, N_FSEL_TEST},
+        lookups::LookupTable,
+    },
+    LIMB_BITSIZE, N_LIMBS,
+};
+use ark_ff::{Field, PrimeField, Zero};
+
+fn fill_limbs_a_b<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, TestColumn> + ColWriteCap<F, TestColumn>,
+>(
+    env: &mut Env,
+    a: Ff,
+    b: Ff,
+) -> ([Env::Variable; N_LIMBS], [Env::Variable; N_LIMBS]) {
+    let a_limbs: [Env::Variable; N_LIMBS] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE, N_LIMBS>(&a).map(Env::constant);
+    let b_limbs: [Env::Variable; N_LIMBS] =
+        limb_decompose_ff::<F, Ff, LIMB_BITSIZE, N_LIMBS>(&b).map(Env::constant);
+    a_limbs.iter().enumerate().for_each(|(i, var)| {
+        env.write_column(TestColumn::A(i), var);
+    });
+    b_limbs.iter().enumerate().for_each(|(i, var)| {
+        env.write_column(TestColumn::B(i), var);
+    });
+    (a_limbs, b_limbs)
+}
+
+/// A constraint function for A + B - C that reads values from limbs A
+/// and B, and constrains them against the result limbs C.
+pub fn constrain_addition<F: PrimeField, Env: ColAccessCap<F, TestColumn>>(env: &mut Env) {
+    let a_limbs: [Env::Variable; N_LIMBS] =
+        core::array::from_fn(|i| Env::read_column(env, TestColumn::A(i)));
+    let b_limbs: [Env::Variable; N_LIMBS] =
+        core::array::from_fn(|i| Env::read_column(env, TestColumn::B(i)));
+    // FIXME: fix cloning
+    let c_limbs: [Env::Variable; N_LIMBS] =
+        core::array::from_fn(|i| Env::read_column(env, TestColumn::C(i)));
+    let equation: [Env::Variable; N_LIMBS] =
+        core::array::from_fn(|i| a_limbs[i].clone() + b_limbs[i].clone() - c_limbs[i].clone());
+    equation.iter().for_each(|var| {
+        env.assert_zero(var.clone());
+    });
+}
+
+/// Circuit generator function for A + B - C, with D = 0.
+pub fn test_addition<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, TestColumn> + ColWriteCap<F, TestColumn>,
+>(
+    env: &mut Env,
+    a: Ff,
+    b: Ff,
+) {
+    let (a_limbs, b_limbs) = fill_limbs_a_b(env, a, b);
+
+    (0..N_LIMBS).for_each(|i| {
+        env.write_column(TestColumn::C(i), &(a_limbs[i].clone() + b_limbs[i].clone()));
+        env.write_column(TestColumn::D(i), &Env::constant(Zero::zero()));
+    });
+
+    constrain_addition(env);
+}
+
+/// A constraint function for A * B - D that reads values from limbs A
+/// and B, and constrains them against the result limbs D.
+pub fn constrain_multiplication<F: PrimeField, Env: ColAccessCap<F, TestColumn>>(env: &mut Env) {
+    let a_limbs: [Env::Variable; N_LIMBS] =
+        core::array::from_fn(|i| Env::read_column(env, TestColumn::A(i)));
+    let b_limbs: [Env::Variable; N_LIMBS] =
+        core::array::from_fn(|i| Env::read_column(env, TestColumn::B(i)));
+    // FIXME: fix cloning
+    let d_limbs: [Env::Variable; N_LIMBS] =
+        core::array::from_fn(|i| Env::read_column(env, TestColumn::D(i)));
+    let equation: [Env::Variable; N_LIMBS] =
+        core::array::from_fn(|i| a_limbs[i].clone() * b_limbs[i].clone() - d_limbs[i].clone());
+    equation.iter().for_each(|var| {
+        env.assert_zero(var.clone());
+    });
+}
+
+/// Circuit generator function for A * B - D, with C = 0.
+pub fn test_multiplication<
+    F: PrimeField,
+    Ff: PrimeField,
+    Env: ColAccessCap<F, TestColumn> + ColWriteCap<F, TestColumn>,
+>(
+    env: &mut Env,
+    a: Ff,
+    b: Ff,
+) {
+    let (a_limbs, b_limbs) = fill_limbs_a_b(env, a, b);
+
+    (0..N_LIMBS).for_each(|i| {
+        env.write_column(TestColumn::D(i), &(a_limbs[i].clone() * b_limbs[i].clone()));
+        env.write_column(TestColumn::C(i), &Env::constant(Zero::zero()));
+    });
+
+    constrain_multiplication(env);
+}
+
+/// A constraint function for A_0 * B_0 - const: it reads the first limbs
+/// of A and B and constrains their product against the given constant.
+pub fn constrain_test_const<F: PrimeField, Env: ColAccessCap<F, TestColumn>>(
+    env: &mut Env,
+    constant: F,
+) {
+    let a0 = Env::read_column(env, TestColumn::A(0));
+    let b0 = Env::read_column(env, TestColumn::B(0));
+    let equation = a0.clone() * b0.clone() - Env::constant(constant);
+    env.assert_zero(equation.clone());
+}
+
+/// Circuit generator function for A_0 * B_0 - const, with every other column = 0
+pub fn test_const<F: PrimeField, Env: ColAccessCap<F, TestColumn> + ColWriteCap<F, TestColumn>>(
+    env: &mut Env,
+    a: F,
+    b: F,
+    constant: F,
+) {
+    env.write_column(TestColumn::A(0), &Env::constant(a));
+    env.write_column(TestColumn::B(0), &Env::constant(b));
+
+    constrain_test_const(env, constant);
+}
+
+/// A constraint function for A_0 + B_0 - FIXED_SEL_1
+pub fn constrain_test_fixed_sel<F: PrimeField, Env: ColAccessCap<F, TestColumn>>(env: &mut Env) {
+    let a0 = Env::read_column(env, TestColumn::A(0));
+    let b0 = Env::read_column(env, TestColumn::B(0));
+    let fixed_e = Env::read_column(env, TestColumn::FixedSel1);
+    let equation = a0.clone() + b0.clone() - fixed_e;
+    env.assert_zero(equation.clone());
+}
+
+/// A constraint function for A_0^7 + B_0 - FIXED_SEL_1
+pub fn constrain_test_fixed_sel_degree_7<F: PrimeField, Env: ColAccessCap<F, TestColumn>>(
+    env: &mut Env,
+) {
+    let a0 = Env::read_column(env, TestColumn::A(0));
+    let b0 = Env::read_column(env, TestColumn::B(0));
+    let fixed_e = Env::read_column(env, TestColumn::FixedSel1);
+    let a0_2 = a0.clone() * a0.clone();
+    let a0_4 = a0_2.clone() * a0_2.clone();
+    let a0_6 = a0_4.clone() * a0_2.clone();
+    let a0_7 = a0_6.clone() * a0.clone();
+    let equation = a0_7.clone() + b0.clone() - fixed_e;
+    env.assert_zero(equation.clone());
+}
+
+/// A constraint function for 3 * A_0^7 + 42 * B_0 - FIXED_SEL_1
+pub fn constrain_test_fixed_sel_degree_7_with_constants<
+    F: PrimeField,
+    Env: ColAccessCap<F, TestColumn>,
+>(
+    env: &mut Env,
+) {
+    let a0 = Env::read_column(env, TestColumn::A(0));
+    let forty_two = Env::constant(F::from(42u32));
+    let three = Env::constant(F::from(3u32));
+    let b0 = Env::read_column(env, TestColumn::B(0));
+    let fixed_e = Env::read_column(env, TestColumn::FixedSel1);
+    let a0_2 = a0.clone() * a0.clone();
+    let a0_4 = a0_2.clone() * a0_2.clone();
+    let a0_6 = a0_4.clone() * a0_2.clone();
+    let a0_7 = a0_6.clone() * a0.clone();
+    let equation = three * a0_7.clone() + forty_two * b0.clone() - fixed_e;
+    env.assert_zero(equation.clone());
+}
+
+// NB: Assumes non-standard selectors
+/// A constraint
function for 3 * A_0^7 + B_0 * FIXED_SEL_3 +pub fn constrain_test_fixed_sel_degree_7_mul_witness< + F: PrimeField, + Env: ColAccessCap, +>( + env: &mut Env, +) { + let a0 = Env::read_column(env, TestColumn::A(0)); + let three = Env::constant(F::from(3u32)); + let b0 = Env::read_column(env, TestColumn::B(0)); + let fixed_e = Env::read_column(env, TestColumn::FixedSel3); + let a0_2 = a0.clone() * a0.clone(); + let a0_4 = a0_2.clone() * a0_2.clone(); + let a0_6 = a0_4.clone() * a0_2.clone(); + let a0_7 = a0_6.clone() * a0.clone(); + let equation = three * a0_7.clone() + b0.clone() * fixed_e; + env.assert_zero(equation.clone()); +} + +/// Circuit generator function for A_0 + B_0 - FIXED_SEL_1. +pub fn test_fixed_sel< + F: PrimeField, + Env: ColAccessCap + ColWriteCap, +>( + env: &mut Env, + a: F, +) { + env.write_column(TestColumn::A(0), &Env::constant(a)); + let fixed_e = env.read_column(TestColumn::FixedSel1); + env.write_column(TestColumn::B(0), &(fixed_e - Env::constant(a))); + + constrain_test_fixed_sel(env); +} + +/// Circuit generator function for A_0^7 + B_0 - FIXED_SEL_1. +pub fn test_fixed_sel_degree_7< + F: PrimeField, + Env: ColAccessCap + ColWriteCap, +>( + env: &mut Env, + a: F, +) { + env.write_column(TestColumn::A(0), &Env::constant(a)); + let a_2 = a * a; + let a_4 = a_2 * a_2; + let a_6 = a_4 * a_2; + let a_7 = a_6 * a; + let fixed_e = env.read_column(TestColumn::FixedSel1); + env.write_column(TestColumn::B(0), &(fixed_e - Env::constant(a_7))); + constrain_test_fixed_sel_degree_7(env); +} + +/// Circuit generator function for 3 * A_0^7 + 42 * B_0 - FIXED_SEL_1. +pub fn test_fixed_sel_degree_7_with_constants< + F: PrimeField, + Env: ColAccessCap + ColWriteCap, +>( + env: &mut Env, + a: F, +) { + env.write_column(TestColumn::A(0), &Env::constant(a)); + let a_2 = a * a; + let a_4 = a_2 * a_2; + let a_6 = a_4 * a_2; + let a_7 = a_6 * a; + let fixed_e = env.read_column(TestColumn::FixedSel1); + let inv_42 = F::from(42u32).inverse().unwrap(); + let three = F::from(3u32); + env.write_column( + TestColumn::B(0), + &((fixed_e - Env::constant(three) * Env::constant(a_7)) * Env::constant(inv_42)), + ); + constrain_test_fixed_sel_degree_7_with_constants(env); +} + +// NB: Assumes non-standard selectors +/// Circuit generator function for 3 * A_0^7 + B_0 * FIXED_SEL_3. 
+pub fn test_fixed_sel_degree_7_mul_witness< + F: PrimeField, + Env: ColAccessCap + ColWriteCap + DirectWitnessCap, +>( + env: &mut Env, + a: F, +) { + env.write_column(TestColumn::A(0), &Env::constant(a)); + let a_2 = a * a; + let a_4 = a_2 * a_2; + let a_6 = a_4 * a_2; + let a_7 = a_6 * a; + let fixed_e = env.read_column(TestColumn::FixedSel3); + let three = F::from(3u32); + let val_fixed_e: F = Env::variable_to_field(fixed_e); + let inv_fixed_e: F = val_fixed_e.inverse().unwrap(); + let res = -three * a_7 * inv_fixed_e; + let res_var = Env::constant(res); + env.write_column(TestColumn::B(0), &res_var); + constrain_test_fixed_sel_degree_7_mul_witness(env); +} + +pub fn constrain_lookups< + F: PrimeField, + Env: ColAccessCap + LookupCap, +>( + env: &mut Env, +) { + let a0 = Env::read_column(env, TestColumn::A(0)); + let a1 = Env::read_column(env, TestColumn::A(1)); + + env.lookup(LookupTable::RangeCheck15, vec![a0.clone()]); + env.lookup(LookupTable::RangeCheck15, vec![a1.clone()]); + + let cur_index = Env::read_column(env, TestColumn::FixedSel1); + let prev_index = Env::read_column(env, TestColumn::FixedSel2); + let next_index = Env::read_column(env, TestColumn::FixedSel3); + + env.lookup( + LookupTable::RuntimeTable1, + vec![ + cur_index.clone(), + prev_index.clone(), + next_index.clone(), + Env::constant(F::from(4u64)), + ], + ); + + // For now we only allow one read per runtime table with runtime_create_column = true. + //env.lookup(LookupTable::RuntimeTable1, vec![a0.clone(), a1.clone()]); + + env.lookup_runtime_write( + LookupTable::RuntimeTable2, + vec![ + Env::constant(F::from(1u64 << 26)), + Env::constant(F::from(5u64)), + ], + ); + env.lookup_runtime_write( + LookupTable::RuntimeTable2, + vec![cur_index, Env::constant(F::from(5u64))], + ); + env.lookup( + LookupTable::RuntimeTable2, + vec![ + Env::constant(F::from(1u64 << 26)), + Env::constant(F::from(5u64)), + ], + ); + env.lookup( + LookupTable::RuntimeTable2, + vec![prev_index, Env::constant(F::from(5u64))], + ); +} + +pub fn lookups_circuit< + F: PrimeField, + Env: ColAccessCap + + ColWriteCap + + DirectWitnessCap + + LookupCap, +>( + env: &mut Env, + domain_size: usize, +) { + for row_i in 0..domain_size { + env.write_column(TestColumn::A(0), &Env::constant(F::from(11u64))); + env.write_column(TestColumn::A(1), &Env::constant(F::from(17u64))); + + constrain_lookups(env); + + if row_i < domain_size - 1 { + env.next_row(); + } + } +} + +/// Fixed selectors for the test circuit. +pub fn build_fixed_selectors(domain_size: usize) -> Box<[Vec; N_FSEL_TEST]> { + // 0 1 2 3 4 ... + let sel1 = (0..domain_size).map(|i| F::from(i as u64)).collect(); + // 0 0 1 2 3 4 ... + let sel2 = (0..domain_size) + .map(|i| { + if i == 0 { + F::zero() + } else { + F::from((i as u64) - 1) + } + }) + .collect(); + // 1 2 3 4 5 ... + let sel3 = (0..domain_size).map(|i| F::from((i + 1) as u64)).collect(); + + Box::new([sel1, sel2, sel3]) +} diff --git a/msm/src/test/test_circuit/lookups.rs b/msm/src/test/test_circuit/lookups.rs new file mode 100644 index 0000000000..51435d6359 --- /dev/null +++ b/msm/src/test/test_circuit/lookups.rs @@ -0,0 +1,96 @@ +use crate::logup::LookupTableID; +use ark_ff::PrimeField; +use num_bigint::BigUint; +use o1_utils::FieldHelpers; +use strum_macros::EnumIter; + +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd, EnumIter)] +pub enum LookupTable { + /// Fixed table, x ∈ [0, 2^15]. + RangeCheck15, + /// A runtime table, with no explicit writes. + RuntimeTable1, + /// A runtime table, with explicit writes. 
+ RuntimeTable2, +} + +impl LookupTableID for LookupTable { + fn to_u32(&self) -> u32 { + match self { + Self::RangeCheck15 => 1, + Self::RuntimeTable1 => 2, + Self::RuntimeTable2 => 3, + } + } + + fn from_u32(value: u32) -> Self { + match value { + 1 => Self::RangeCheck15, + 2 => Self::RuntimeTable1, + 3 => Self::RuntimeTable2, + _ => panic!("Invalid lookup table id"), + } + } + + fn is_fixed(&self) -> bool { + match self { + Self::RangeCheck15 => true, + Self::RuntimeTable1 => false, + Self::RuntimeTable2 => false, + } + } + + fn runtime_create_column(&self) -> bool { + match self { + Self::RuntimeTable1 => true, + Self::RuntimeTable2 => false, + _ => panic!("runtime_create_column was called on a non-runtime table"), + } + } + + fn length(&self) -> usize { + match self { + Self::RangeCheck15 => 1 << 15, + Self::RuntimeTable1 => 1 << 15, + Self::RuntimeTable2 => 1 << 15, + } + } + + /// Converts a value to its index in the fixed table. + fn ix_by_value(&self, value: &[F]) -> Option { + let value = value[0]; + if self.is_fixed() { + assert!(self.is_member(value).unwrap()); + } + + match self { + Self::RangeCheck15 => Some(TryFrom::try_from(value.to_biguint()).unwrap()), + Self::RuntimeTable1 => None, + Self::RuntimeTable2 => None, + } + } + + fn all_variants() -> Vec { + vec![Self::RangeCheck15, Self::RuntimeTable1, Self::RuntimeTable2] + } +} + +impl LookupTable { + /// Provides a full list of entries for the given table. + pub fn entries(&self, domain_d1_size: u64) -> Option> { + assert!(domain_d1_size >= (1 << 15)); + match self { + Self::RangeCheck15 => Some((0..domain_d1_size).map(|i| F::from(i)).collect()), + _ => panic!("not possible"), + } + } + + /// Checks if a value is in a given table. + pub fn is_member(&self, value: F) -> Option { + match self { + Self::RangeCheck15 => Some(value.to_biguint() < BigUint::from(2u128.pow(15))), + Self::RuntimeTable1 => None, + Self::RuntimeTable2 => None, + } + } +} diff --git a/msm/src/test/test_circuit/mod.rs b/msm/src/test/test_circuit/mod.rs new file mode 100644 index 0000000000..0f0ed76c81 --- /dev/null +++ b/msm/src/test/test_circuit/mod.rs @@ -0,0 +1,560 @@ +pub mod columns; +pub mod interpreter; +pub mod lookups; + +#[cfg(test)] +mod tests { + use crate::{ + circuit_design::{ConstraintBuilderEnv, WitnessBuilderEnv}, + logup::LookupTableID, + lookups::DummyLookupTable, + test::test_circuit::{ + columns::{TestColumn, N_COL_TEST, N_FSEL_TEST}, + interpreter as test_interpreter, + lookups::LookupTable as TestLookupTable, + }, + Ff1, Fp, + }; + use ark_ff::UniformRand; + use rand::{CryptoRng, Rng, RngCore}; + use std::collections::BTreeMap; + + type TestWitnessBuilderEnv = WitnessBuilderEnv< + Fp, + TestColumn, + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + LT, + >; + + fn build_test_fixed_sel_circuit( + rng: &mut RNG, + domain_size: usize, + ) -> TestWitnessBuilderEnv { + let mut witness_env = WitnessBuilderEnv::create(); + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + witness_env.set_fixed_selectors(fixed_selectors.to_vec()); + + for row_i in 0..domain_size { + let a: Fp = ::rand(rng); + test_interpreter::test_fixed_sel(&mut witness_env, a); + + if row_i < domain_size - 1 { + witness_env.next_row(); + } + } + + witness_env + } + + #[test] + pub fn test_build_test_fixed_sel_circuit() { + let mut rng = o1_utils::tests::make_test_rng(None); + build_test_fixed_sel_circuit::<_, DummyLookupTable>(&mut rng, 1 << 4); + } + + #[test] + fn test_completeness_fixed_sel() { + let 
mut rng = o1_utils::tests::make_test_rng(None); + + // Include tests for completeness for Logup as the random witness + // includes all arguments + let domain_size = 1 << 8; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_test_fixed_sel::(&mut constraint_env); + // Don't use lookups for now + let constraints = constraint_env.get_relation_constraints(); + + let witness_env = + build_test_fixed_sel_circuit::<_, DummyLookupTable>(&mut rng, domain_size); + let relation_witness = witness_env.get_relation_witness(domain_size); + + crate::test::test_completeness_generic_no_lookups::< + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } + + fn build_test_fixed_sel_degree_7_circuit( + rng: &mut RNG, + domain_size: usize, + ) -> TestWitnessBuilderEnv { + let mut witness_env = WitnessBuilderEnv::create(); + + let fixed_sel: Vec = (0..domain_size).map(|i| Fp::from(i as u64)).collect(); + witness_env.set_fixed_selector_cix(TestColumn::FixedSel1, fixed_sel); + + for row_i in 0..domain_size { + let a: Fp = ::rand(rng); + test_interpreter::test_fixed_sel_degree_7(&mut witness_env, a); + + if row_i < domain_size - 1 { + witness_env.next_row(); + } + } + + witness_env + } + + #[test] + fn test_completeness_fixed_sel_degree_7() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Include tests for completeness for Logup as the random witness + // includes all arguments + let domain_size = 1 << 8; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_test_fixed_sel_degree_7::(&mut constraint_env); + // Don't use lookups for now + let constraints = constraint_env.get_relation_constraints(); + + let witness_env = + build_test_fixed_sel_degree_7_circuit::<_, DummyLookupTable>(&mut rng, domain_size); + let relation_witness = witness_env.get_relation_witness(domain_size); + + crate::test::test_completeness_generic_no_lookups::< + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } + + fn build_test_fixed_sel_degree_7_circuit_with_constants< + RNG: RngCore + CryptoRng, + LT: LookupTableID, + >( + rng: &mut RNG, + domain_size: usize, + ) -> TestWitnessBuilderEnv { + let mut witness_env = WitnessBuilderEnv::create(); + + let fixed_sel: Vec = (0..domain_size).map(|i| Fp::from(i as u64)).collect(); + witness_env.set_fixed_selector_cix(TestColumn::FixedSel1, fixed_sel); + + for row_i in 0..domain_size { + let a: Fp = ::rand(rng); + test_interpreter::test_fixed_sel_degree_7_with_constants(&mut witness_env, a); + + if row_i < domain_size - 1 { + witness_env.next_row(); + } + } + + witness_env + } + + #[test] + fn test_completeness_fixed_sel_degree_7_with_constants() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Include tests for completeness for Logup as the random witness + // includes all arguments + let domain_size = 1 << 8; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_test_fixed_sel_degree_7_with_constants::( + &mut constraint_env, + ); + // Don't use lookups for now + 
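+        // Note the symmetry of the circuit_design API: the witness builder
+        // above already ran constrain_test_fixed_sel_degree_7_with_constants
+        // (inside test_fixed_sel_degree_7_with_constants) as in-circuit
+        // assertions, while here the very same function records the
+        // constraint expressions into a ConstraintBuilderEnv, so the witness
+        // and the constraints cannot drift apart.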
let constraints = constraint_env.get_relation_constraints(); + + let witness_env = build_test_fixed_sel_degree_7_circuit_with_constants::<_, DummyLookupTable>( + &mut rng, + domain_size, + ); + let relation_witness = witness_env.get_relation_witness(domain_size); + + crate::test::test_completeness_generic_no_lookups::< + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } + + fn build_test_fixed_sel_degree_7_circuit_mul_witness< + RNG: RngCore + CryptoRng, + LT: LookupTableID, + >( + rng: &mut RNG, + domain_size: usize, + ) -> TestWitnessBuilderEnv { + let mut witness_env = WitnessBuilderEnv::create(); + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + witness_env.set_fixed_selectors(fixed_selectors.to_vec()); + + for row_i in 0..domain_size { + let a: Fp = ::rand(rng); + test_interpreter::test_fixed_sel_degree_7_mul_witness(&mut witness_env, a); + + if row_i < domain_size - 1 { + witness_env.next_row(); + } + } + + witness_env + } + + #[test] + fn test_completeness_fixed_sel_degree_7_mul_witness() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Include tests for completeness for Logup as the random witness + // includes all arguments + let domain_size = 1 << 8; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_test_fixed_sel_degree_7_mul_witness::( + &mut constraint_env, + ); + // Don't use lookups for now + let constraints = constraint_env.get_relation_constraints(); + + let witness_env = build_test_fixed_sel_degree_7_circuit_mul_witness::<_, DummyLookupTable>( + &mut rng, + domain_size, + ); + let relation_witness = witness_env.get_relation_witness(domain_size); + + crate::test::test_completeness_generic_no_lookups::< + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } + + fn build_test_fixed_sel_degree_7_circuit_fixed_values< + RNG: RngCore + CryptoRng, + LT: LookupTableID, + >( + rng: &mut RNG, + domain_size: usize, + ) -> TestWitnessBuilderEnv { + let mut witness_env = WitnessBuilderEnv::create(); + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + witness_env.set_fixed_selectors(fixed_selectors.to_vec()); + + for row_i in 0..domain_size { + let a: Fp = ::rand(rng); + test_interpreter::test_fixed_sel_degree_7(&mut witness_env, a); + + if row_i < domain_size - 1 { + witness_env.next_row(); + } + } + + witness_env + } + + #[test] + fn test_completeness_fixed_sel_degree_7_fixed_values() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Include tests for completeness for Logup as the random witness + // includes all arguments + let domain_size = 1 << 8; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_test_fixed_sel_degree_7::(&mut constraint_env); + // Don't use lookups for now + let constraints = constraint_env.get_relation_constraints(); + + let witness_env = build_test_fixed_sel_degree_7_circuit_fixed_values::<_, DummyLookupTable>( + &mut rng, + domain_size, + ); + let relation_witness = witness_env.get_relation_witness(domain_size); + + crate::test::test_completeness_generic_no_lookups::< + { 
N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } + + fn build_test_const_circuit( + rng: &mut RNG, + domain_size: usize, + ) -> (TestWitnessBuilderEnv, Fp) { + let mut witness_env = WitnessBuilderEnv::create(); + + // To support less rows than domain_size we need to have selectors. + //let row_num = rng.gen_range(0..domain_size); + + let fixed_sel: Vec = (0..domain_size).map(|i| Fp::from(i as u64)).collect(); + witness_env.set_fixed_selector_cix(TestColumn::FixedSel1, fixed_sel); + + let constant: Fp = ::rand(rng); + for row_i in 0..domain_size { + let a: Fp = ::rand(rng); + let b: Fp = constant / a; + assert!(a * b == constant); + test_interpreter::test_const(&mut witness_env, a, b, constant); + + if row_i < domain_size - 1 { + witness_env.next_row(); + } + } + + (witness_env, constant) + } + + #[test] + pub fn test_build_test_constant_circuit() { + let mut rng = o1_utils::tests::make_test_rng(None); + build_test_const_circuit::<_, DummyLookupTable>(&mut rng, 1 << 4); + } + + #[test] + fn test_completeness_constant() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Include tests for completeness for Logup as the random witness + // includes all arguments + let domain_size = 1 << 8; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let (witness_env, constant) = + build_test_const_circuit::<_, DummyLookupTable>(&mut rng, domain_size); + let relation_witness = witness_env.get_relation_witness(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_test_const::(&mut constraint_env, constant); + let constraints = constraint_env.get_relation_constraints(); + + crate::test::test_completeness_generic_no_lookups::< + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } + + fn build_test_mul_circuit( + rng: &mut RNG, + domain_size: usize, + ) -> TestWitnessBuilderEnv { + let mut witness_env = WitnessBuilderEnv::create(); + + // To support less rows than domain_size we need to have selectors. + //let row_num = rng.gen_range(0..domain_size); + + let fixed_sel: Vec = (0..domain_size).map(|i| Fp::from(i as u64)).collect(); + witness_env.set_fixed_selector_cix(TestColumn::FixedSel1, fixed_sel); + + let row_num = 10; + for row_i in 0..row_num { + let a: Ff1 = From::from(rng.gen_range(0..(1 << 16))); + let b: Ff1 = From::from(rng.gen_range(0..(1 << 16))); + test_interpreter::test_multiplication(&mut witness_env, a, b); + if row_i < domain_size - 1 { + witness_env.next_row(); + } + } + + witness_env + } + + #[test] + /// Tests if the "test" circuit is valid without running the proof. 
+ pub fn test_build_test_mul_circuit() { + let mut rng = o1_utils::tests::make_test_rng(None); + build_test_mul_circuit::<_, DummyLookupTable>(&mut rng, 1 << 4); + } + + #[test] + fn test_completeness_lookups() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Include tests for completeness for Logup as the random witness + // includes all arguments + let domain_size = 1 << 15; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_lookups::(&mut constraint_env); + let constraints = constraint_env.get_relation_constraints(); + + let mut witness_env: TestWitnessBuilderEnv = WitnessBuilderEnv::create(); + witness_env.set_fixed_selectors(fixed_selectors.to_vec()); + test_interpreter::lookups_circuit(&mut witness_env, domain_size); + let runtime_tables: BTreeMap<_, Vec>>> = + witness_env.get_runtime_tables(domain_size); + + let mut lookup_tables_data: BTreeMap>>> = BTreeMap::new(); + for table_id in TestLookupTable::all_variants().into_iter() { + if table_id.is_fixed() { + lookup_tables_data.insert( + table_id, + vec![table_id + .entries(domain_size as u64) + .unwrap() + .into_iter() + .map(|x| vec![x]) + .collect()], + ); + } + } + for (table_id, runtime_table) in runtime_tables.into_iter() { + lookup_tables_data.insert(table_id, runtime_table); + } + + let proof_inputs = witness_env.get_proof_inputs(domain_size, lookup_tables_data); + + crate::test::test_completeness_generic::< + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + TestLookupTable, + _, + >( + constraints, + fixed_selectors, + proof_inputs, + domain_size, + &mut rng, + ); + } + + #[test] + fn test_completeness() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Include tests for completeness for Logup as the random witness + // includes all arguments + let domain_size = 1 << 8; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_multiplication::(&mut constraint_env); + // Don't use lookups for now + let constraints = constraint_env.get_relation_constraints(); + + let witness_env = build_test_mul_circuit::<_, DummyLookupTable>(&mut rng, domain_size); + let relation_witness = witness_env.get_relation_witness(domain_size); + + crate::test::test_completeness_generic_no_lookups::< + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + _, + >( + constraints, + fixed_selectors, + relation_witness, + domain_size, + &mut rng, + ); + } + + #[test] + fn test_soundness() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // We generate two different witness and two different proofs. 
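+        // Intuition for the check below: proof data derived from one
+        // witness must not verify against the other; the generic helper
+        // crate::test::test_soundness_generic is expected to assert that
+        // such a mismatch is rejected.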
+ let domain_size: usize = 1 << 8; + + let fixed_selectors = test_interpreter::build_fixed_selectors(domain_size); + + let mut constraint_env = ConstraintBuilderEnv::::create(); + test_interpreter::constrain_multiplication::(&mut constraint_env); + let constraints = constraint_env.get_relation_constraints(); + + let lookup_tables_data = BTreeMap::new(); + let witness_env = build_test_mul_circuit::<_, DummyLookupTable>(&mut rng, domain_size); + let mut proof_inputs = + witness_env.get_proof_inputs(domain_size, lookup_tables_data.clone()); + proof_inputs.logups = Default::default(); + + let witness_env_prime = + build_test_mul_circuit::<_, DummyLookupTable>(&mut rng, domain_size); + let mut proof_inputs_prime = + witness_env_prime.get_proof_inputs(domain_size, lookup_tables_data.clone()); + proof_inputs_prime.logups = Default::default(); + + crate::test::test_soundness_generic::< + { N_COL_TEST - N_FSEL_TEST }, + { N_COL_TEST - N_FSEL_TEST }, + 0, + N_FSEL_TEST, + DummyLookupTable, + _, + >( + constraints, + fixed_selectors, + proof_inputs, + proof_inputs_prime, + domain_size, + &mut rng, + ); + } +} diff --git a/msm/src/verifier.rs b/msm/src/verifier.rs new file mode 100644 index 0000000000..74d0cd0769 --- /dev/null +++ b/msm/src/verifier.rs @@ -0,0 +1,335 @@ +#![allow(clippy::type_complexity)] +#![allow(clippy::boxed_local)] + +use crate::logup::LookupTableID; +use ark_ff::{Field, Zero}; +use ark_poly::{ + univariate::DensePolynomial, EvaluationDomain, Evaluations, Polynomial, + Radix2EvaluationDomain as R2D, +}; +use rand::thread_rng; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; + +use kimchi::{ + circuits::{ + berkeley_columns::BerkeleyChallenges, + domains::EvaluationDomains, + expr::{Constants, Expr, PolishToken}, + }, + curve::KimchiCurve, + groupmap::GroupMap, + plonk_sponge::FrSponge, + proof::PointEvaluations, +}; +use mina_poseidon::{sponge::ScalarChallenge, FqSponge}; +use poly_commitment::{ + commitment::{ + absorb_commitment, combined_inner_product, BatchEvaluationProof, Evaluation, PolyComm, + }, + OpenProof, SRS, +}; + +use crate::{expr::E, proof::Proof, witness::Witness}; + +pub fn verify< + G: KimchiCurve, + OpeningProof: OpenProof, + EFqSponge: Clone + FqSponge, + EFrSponge: FrSponge, + const N_WIT: usize, + const N_REL: usize, + const N_DSEL: usize, + const N_FSEL: usize, + const NPUB: usize, + ID: LookupTableID, +>( + domain: EvaluationDomains, + srs: &OpeningProof::SRS, + constraints: &Vec>, + fixed_selectors: Box<[Vec; N_FSEL]>, + proof: &Proof, + public_inputs: Witness>, +) -> bool +where + OpeningProof::SRS: Sync, +{ + let Proof { + proof_comms, + proof_evals, + opening_proof, + } = proof; + + //////////////////////////////////////////////////////////////////////////// + // Re-evaluating public inputs + //////////////////////////////////////////////////////////////////////////// + + let fixed_selectors_evals_d1: Box<[Evaluations>; N_FSEL]> = { + o1_utils::array::vec_to_boxed_array( + fixed_selectors + .into_par_iter() + .map(|evals| Evaluations::from_vec_and_domain(evals, domain.d1)) + .collect(), + ) + }; + + let fixed_selectors_polys: Box<[DensePolynomial; N_FSEL]> = { + o1_utils::array::vec_to_boxed_array( + fixed_selectors_evals_d1 + .into_par_iter() + .map(|evals| evals.interpolate()) + .collect(), + ) + }; + + let fixed_selectors_comms: Box<[PolyComm; N_FSEL]> = { + let comm = |poly: &DensePolynomial| srs.commit_non_hiding(poly, 1); + o1_utils::array::vec_to_boxed_array( + fixed_selectors_polys + .as_ref() + .into_par_iter() + .map(comm) + 
.collect(), + ) + }; + + // Interpolate public input columns on d1, using trait Into. + let public_input_evals_d1: Witness>> = + public_inputs + .into_par_iter() + .map(|evals| { + Evaluations::>::from_vec_and_domain( + evals, domain.d1, + ) + }) + .collect::>>>(); + + let public_input_polys: Witness> = { + let interpolate = + |evals: Evaluations>| evals.interpolate(); + public_input_evals_d1 + .into_par_iter() + .map(interpolate) + .collect::>>() + }; + + let public_input_comms: Witness> = { + let comm = |poly: &DensePolynomial| srs.commit_non_hiding(poly, 1); + (&public_input_polys) + .into_par_iter() + .map(comm) + .collect::>>() + }; + + assert!( + NPUB <= N_WIT, + "Number of public inputs exceeds number of witness columns" + ); + for i in 0..NPUB { + assert!(public_input_comms.cols[i] == proof_comms.witness_comms.cols[i]); + } + + //////////////////////////////////////////////////////////////////////////// + // Absorbing all the commitments to the columns + //////////////////////////////////////////////////////////////////////////// + + let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params()); + + fixed_selectors_comms + .as_ref() + .iter() + .chain(&proof_comms.witness_comms) + .for_each(|comm| absorb_commitment(&mut fq_sponge, comm)); + + //////////////////////////////////////////////////////////////////////////// + // Logup + //////////////////////////////////////////////////////////////////////////// + + let (joint_combiner, beta) = { + if let Some(logup_comms) = &proof_comms.logup_comms { + // First, we absorb the multiplicity polynomials + logup_comms.m.values().for_each(|comms| { + comms + .iter() + .for_each(|comm| absorb_commitment(&mut fq_sponge, comm)) + }); + + // FIXME @volhovm it seems that the verifier does not + // actually check that the fixed tables used in the proof + // are the fixed tables defined in the code. In other + // words, all the currently used "fixed" tables are + // runtime and can be chosen freely by the prover. + + // To generate the challenges + let joint_combiner = fq_sponge.challenge(); + let beta = fq_sponge.challenge(); + + // And now, we absorb the commitments to the other polynomials + logup_comms.h.values().for_each(|comms| { + comms + .iter() + .for_each(|comm| absorb_commitment(&mut fq_sponge, comm)) + }); + + logup_comms + .fixed_tables + .values() + .for_each(|comm| absorb_commitment(&mut fq_sponge, comm)); + + // And at the end, the aggregation + absorb_commitment(&mut fq_sponge, &logup_comms.sum); + (Some(joint_combiner), beta) + } else { + (None, G::ScalarField::zero()) + } + }; + + // Sample α with the Fq-Sponge. 
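+    // α is the constraint-combination challenge: all relation constraints
+    // are folded into a single expression Σ_i α^i · constraint_i (see
+    // Expr::combine_constraints further below), whose quotient by the
+    // vanishing polynomial of d1 is the committed t(X).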
+ let alpha = fq_sponge.challenge(); + + //////////////////////////////////////////////////////////////////////////// + // Quotient polynomial + //////////////////////////////////////////////////////////////////////////// + + absorb_commitment(&mut fq_sponge, &proof_comms.t_comm); + + // -- Preparing for opening proof verification + let zeta_chal = ScalarChallenge(fq_sponge.challenge()); + let (_, endo_r) = G::endos(); + let zeta: G::ScalarField = zeta_chal.to_field(endo_r); + let omega = domain.d1.group_gen; + let zeta_omega = zeta * omega; + + let mut coms_and_evaluations: Vec> = vec![]; + + coms_and_evaluations.extend( + (&proof_comms.witness_comms) + .into_iter() + .zip(&proof_evals.witness_evals) + .map(|(commitment, point_eval)| Evaluation { + commitment: commitment.clone(), + evaluations: vec![vec![point_eval.zeta], vec![point_eval.zeta_omega]], + }), + ); + + coms_and_evaluations.extend( + (fixed_selectors_comms) + .into_iter() + .zip(proof_evals.fixed_selectors_evals.iter()) + .map(|(commitment, point_eval)| Evaluation { + commitment: commitment.clone(), + evaluations: vec![vec![point_eval.zeta], vec![point_eval.zeta_omega]], + }), + ); + + if let Some(logup_comms) = &proof_comms.logup_comms { + coms_and_evaluations.extend( + logup_comms + .into_iter() + .zip(proof_evals.logup_evals.as_ref().unwrap()) + .map(|(commitment, point_eval)| Evaluation { + commitment: commitment.clone(), + evaluations: vec![vec![point_eval.zeta], vec![point_eval.zeta_omega]], + }) + .collect::>(), + ); + } + + // -- Absorb all coms_and_evaluations + let fq_sponge_before_coms_and_evaluations = fq_sponge.clone(); + let mut fr_sponge = EFrSponge::new(G::sponge_params()); + fr_sponge.absorb(&fq_sponge.digest()); + + for PointEvaluations { zeta, zeta_omega } in (&proof_evals.witness_evals).into_iter() { + fr_sponge.absorb(zeta); + fr_sponge.absorb(zeta_omega); + } + + for PointEvaluations { zeta, zeta_omega } in proof_evals.fixed_selectors_evals.as_ref().iter() { + fr_sponge.absorb(zeta); + fr_sponge.absorb(zeta_omega); + } + + if proof_comms.logup_comms.is_some() { + // Logup FS + for PointEvaluations { zeta, zeta_omega } in + proof_evals.logup_evals.as_ref().unwrap().into_iter() + { + fr_sponge.absorb(zeta); + fr_sponge.absorb(zeta_omega); + } + }; + + // Compute [ft(X)] = \ + // (1 - ζ^n) \ + // ([t_0(X)] + ζ^n [t_1(X)] + ... + ζ^{kn} [t_{k}(X)]) + let ft_comm = { + let evaluation_point_to_domain_size = zeta.pow([domain.d1.size]); + let chunked_t_comm = proof_comms + .t_comm + .chunk_commitment(evaluation_point_to_domain_size); + // (1 - ζ^n) + let minus_vanishing_poly_at_zeta = -domain.d1.vanishing_polynomial().evaluate(&zeta); + chunked_t_comm.scale(minus_vanishing_poly_at_zeta) + }; + + let challenges = BerkeleyChallenges:: { + alpha, + beta, + gamma: G::ScalarField::zero(), + joint_combiner: joint_combiner.unwrap_or(G::ScalarField::zero()), + }; + + let constants = Constants { + endo_coefficient: *endo_r, + mds: &G::sponge_params().mds, + zk_rows: 0, + }; + + let combined_expr = + Expr::combine_constraints(0..(constraints.len() as u32), constraints.clone()); + // Note the minus! ft polynomial at zeta (ft_eval0) is minus evaluation of the expression. 
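+    // Together with ft_comm = -(ζ^n - 1) · [t] computed above, opening
+    // [ft] at ζ to the value ft_eval0 = -combined_expr(ζ) enforces the
+    // quotient identity combined_expr(ζ) = t(ζ) · (ζ^n - 1).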
+ let ft_eval0 = -PolishToken::evaluate( + combined_expr.to_polish().as_slice(), + domain.d1, + zeta, + proof_evals, + &constants, + &challenges, + ) + .unwrap(); + + coms_and_evaluations.push(Evaluation { + commitment: ft_comm, + evaluations: vec![vec![ft_eval0], vec![proof_evals.ft_eval1]], + }); + + fr_sponge.absorb(&proof_evals.ft_eval1); + // -- End absorb all coms_and_evaluations + + let v_chal = fr_sponge.challenge(); + let v = v_chal.to_field(endo_r); + let u_chal = fr_sponge.challenge(); + let u = u_chal.to_field(endo_r); + + let combined_inner_product = { + let es: Vec<_> = coms_and_evaluations + .iter() + .map(|Evaluation { evaluations, .. }| evaluations.clone()) + .collect(); + + combined_inner_product(&v, &u, es.as_slice()) + }; + + let batch = BatchEvaluationProof { + sponge: fq_sponge_before_coms_and_evaluations, + evaluations: coms_and_evaluations, + evaluation_points: vec![zeta, zeta_omega], + polyscale: v, + evalscale: u, + opening: opening_proof, + combined_inner_product, + }; + + let group_map = G::Map::setup(); + OpeningProof::verify(srs, &group_map, &mut [batch], &mut thread_rng()) +} diff --git a/msm/src/witness.rs b/msm/src/witness.rs new file mode 100644 index 0000000000..42a1b9120b --- /dev/null +++ b/msm/src/witness.rs @@ -0,0 +1,178 @@ +use ark_ff::{FftField, Zero}; +use ark_poly::{Evaluations, Radix2EvaluationDomain}; +use folding::{instance_witness::Foldable, Witness as FoldingWitnessT}; +use poly_commitment::commitment::CommitmentCurve; +use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator}; +use std::ops::Index; + +/// The witness columns used by a gate of the MSM circuits. +/// It is generic over the number of columns, `N_WIT`, and the type of the witness, `T`. +/// It is parametrized by a type `T` which can be either: +/// - `Vec` for the evaluations +/// - `PolyComm` for the commitments +/// It can be used to represent the different subcircuits used by the project. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Witness { + /// A witness row is represented by an array of N witness columns + /// When T is a vector, then the witness describes the rows of the circuit. + pub cols: Box<[T; N_WIT]>, +} + +impl Default for Witness { + fn default() -> Self { + Witness { + cols: Box::new(std::array::from_fn(|_| T::zero())), + } + } +} + +impl TryFrom> for Witness { + type Error = String; + + fn try_from(value: Vec) -> Result { + let len = value.len(); + let cols: Box<[T; N_WIT]> = value + .try_into() + .map_err(|_| format!("Size mismatch: Expected {N_WIT:?} got {len:?}"))?; + Ok(Witness { cols }) + } +} + +impl Index for Witness { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + &self.cols[index] + } +} + +impl Witness { + pub fn len(&self) -> usize { + self.cols.len() + } + + pub fn is_empty(&self) -> bool { + self.cols.is_empty() + } +} + +impl Witness> { + pub fn zero_vec(domain_size: usize) -> Self { + Witness { + // Ideally the vector should be of domain size, but + // one-element vector should be a reasonable default too. 
+ cols: Box::new(std::array::from_fn(|_| vec![T::zero(); domain_size])), + } + } + + pub fn to_pub_columns(&self) -> Witness> { + let mut newcols: [Vec; NPUB] = std::array::from_fn(|_| vec![]); + for (i, vec) in self.cols[0..NPUB].iter().enumerate() { + newcols[i] = vec.clone(); + } + Witness { + cols: Box::new(newcols), + } + } +} + +// IMPLEMENTATION OF ITERATORS FOR THE WITNESS STRUCTURE + +impl<'lt, const N_WIT: usize, G> IntoIterator for &'lt Witness { + type Item = &'lt G; + type IntoIter = std::vec::IntoIter<&'lt G>; + + fn into_iter(self) -> Self::IntoIter { + let mut iter_contents = Vec::with_capacity(N_WIT); + iter_contents.extend(&*self.cols); + iter_contents.into_iter() + } +} + +impl IntoIterator for Witness { + type Item = F; + type IntoIter = std::vec::IntoIter; + + /// Iterate over the columns in the circuit. + fn into_iter(self) -> Self::IntoIter { + let mut iter_contents = Vec::with_capacity(N_WIT); + iter_contents.extend(*self.cols); + iter_contents.into_iter() + } +} + +impl IntoParallelIterator for Witness +where + Vec: IntoParallelIterator, +{ + type Iter = as IntoParallelIterator>::Iter; + type Item = as IntoParallelIterator>::Item; + + /// Iterate over the columns in the circuit, in parallel. + fn into_par_iter(self) -> Self::Iter { + let mut iter_contents = Vec::with_capacity(N_WIT); + iter_contents.extend(*self.cols); + iter_contents.into_par_iter() + } +} + +impl FromParallelIterator for Witness { + fn from_par_iter(par_iter: I) -> Self + where + I: IntoParallelIterator, + { + let mut iter_contents = par_iter.into_par_iter().collect::>(); + let cols = iter_contents + .drain(..N_WIT) + .collect::>() + .try_into() + .unwrap(); + Witness { cols } + } +} + +impl<'data, const N_WIT: usize, G> IntoParallelIterator for &'data Witness +where + Vec<&'data G>: IntoParallelIterator, +{ + type Iter = as IntoParallelIterator>::Iter; + type Item = as IntoParallelIterator>::Item; + + fn into_par_iter(self) -> Self::Iter { + let mut iter_contents = Vec::with_capacity(N_WIT); + iter_contents.extend(&*self.cols); + iter_contents.into_par_iter() + } +} + +impl<'data, const N_WIT: usize, G> IntoParallelIterator for &'data mut Witness +where + Vec<&'data mut G>: IntoParallelIterator, +{ + type Iter = as IntoParallelIterator>::Iter; + type Item = as IntoParallelIterator>::Item; + + fn into_par_iter(self) -> Self::Iter { + let mut iter_contents = Vec::with_capacity(N_WIT); + iter_contents.extend(&mut *self.cols); + iter_contents.into_par_iter() + } +} + +impl Foldable + for Witness>> +{ + fn combine(mut a: Self, b: Self, challenge: F) -> Self { + for (a, b) in (*a.cols).iter_mut().zip(*(b.cols)) { + for (a, b) in a.evals.iter_mut().zip(b.evals) { + *a += challenge * b; + } + } + a + } +} + +impl FoldingWitnessT + for Witness>> +{ +} diff --git a/mvpoly/Cargo.toml b/mvpoly/Cargo.toml new file mode 100644 index 0000000000..fb88a756e4 --- /dev/null +++ b/mvpoly/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "mvpoly" +version = "0.1.0" +description = "Symbolic interpreter of multivariate polyomials" +repository = "https://github.com/o1-labs/proof-systems" +homepage = "https://o1-labs.github.io/proof-systems/" +documentation = "https://o1-labs.github.io/proof-systems/rustdoc/" +readme = "README.md" +edition = "2021" +license = "Apache-2.0" + +[lib] +path = "src/lib.rs" + +[dependencies] +ark-ff.workspace = true +kimchi.workspace = true +log.workspace = true +num-integer.workspace = true +o1-utils.workspace = true +rand.workspace = true + +[dev-dependencies] +mina-curves.workspace = true 
+rand.workspace = true +criterion = { version = "0.5.1", features = ["html_reports"] } + +[[bench]] +name = "prime" +harness = false + +[[bench]] +name = "monomials" +harness = false \ No newline at end of file diff --git a/mvpoly/README.md b/mvpoly/README.md new file mode 100644 index 0000000000..447e56811f --- /dev/null +++ b/mvpoly/README.md @@ -0,0 +1,36 @@ +# MVPoly: play with multi-variate poynomials + +This library aims to provide algorithms and different representations to +manipulate multivariate polynomials. + +## Notation + +$\mathbb{F}^{\le d}[X_1, \cdots, X_{N}]$ is the vector space over the finite +field $\mathbb{F}$ of multivariate polynomials with $N$ variables and of maximum +degree $d$. + +For instance, for multivariate polynomials with two variables, and maximum +degree 2, the set is +```math +\mathbb{F}^{\le 2}[X_{1}, X_{2}] = \left\{ a_{0} + a_{1} X_{1} + a_{2} X_{2} + a_{3} X_{1} X_{2} + a_{4} X_{1}^2 + a_{5} X_{2}^2 \, | \, a_{i} \in \mathbb{F} \right\} +``` + +The normal form of a multivariate polynomials is the multi-variate polynomials +whose monomials are all different, i.e. the coefficients are in the "canonical" +(sic) base $\{ \prod_{k = 1}^{N} X_{i_{1}}^{n_{i_{1}}} \cdots X_{i_{k}}^{n_{k}} +\}$ where $\sum_{k = 1}^{N} n_{i_{k}} \le d$. + +For instance, the canonical base of $\mathbb{F}^{\le 2}[X_{1}, X_{2}]$ is the set $\{ 1, X_{1}, X_{2}, X_{1} X_{2}, X_{1}^2, X_{2}^2 \}$ + +Examples: + +- $X_{1} + X_{2}$ is in normal form +- $X_{1} + 42 X_{1}$ is not in normal form +- $X_{1} (X_{1} + X_{2})$ is not in normal form +- $X_{1}^2 + X_{1} X_{2}$ is in normal form + +## Resources + +This README is partially or fully imported from the document [Sympolyc - +symbolic computation on multi-variate polynomials using prime +numbers](https://hackmd.io/@dannywillems/SyHar7p5A) diff --git a/mvpoly/benches/monomials.rs b/mvpoly/benches/monomials.rs new file mode 100644 index 0000000000..1b5362dd27 --- /dev/null +++ b/mvpoly/benches/monomials.rs @@ -0,0 +1,72 @@ +use ark_ff::UniformRand; +use criterion::{black_box, criterion_group, criterion_main, Bencher, Criterion}; +use mina_curves::pasta::Fp; +use mvpoly::{monomials::Sparse, MVPoly}; + +// Using 10 variables, with max degree 3 +// Should roughly cover the cases we care about +fn bench_sparse_add(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Sparse = unsafe { Sparse::random(&mut rng, None) }; + let p2: Sparse = unsafe { Sparse::random(&mut rng, None) }; + c.bench_function("sparse_add", |b: &mut Bencher| { + b.iter(|| { + let _ = black_box(&p1) + black_box(&p2); + }) + }); +} + +fn bench_sparse_mul(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + c.bench_function("sparse_mul", |b: &mut Bencher| { + b.iter(|| { + // IMPROVEME: implement mul on references and define the random + // values before the benchmark + let p1: Sparse = unsafe { Sparse::random(&mut rng, None) }; + let p2: Sparse = unsafe { Sparse::random(&mut rng, None) }; + let _ = black_box(p1) * black_box(p2); + }) + }); +} + +fn bench_sparse_neg(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Sparse = unsafe { Sparse::random(&mut rng, None) }; + c.bench_function("sparse_neg", |b: &mut Bencher| { + b.iter(|| { + let _ = -black_box(&p1); + }) + }); +} + +fn bench_sparse_sub(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Sparse = unsafe { Sparse::random(&mut rng, None) }; + let p2: Sparse = unsafe { Sparse::random(&mut rng, 
None) }; + c.bench_function("sparse_sub", |b: &mut Bencher| { + b.iter(|| { + let _ = black_box(&p1) - black_box(&p2); + }) + }); +} + +fn bench_sparse_eval(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Sparse = unsafe { Sparse::random(&mut rng, None) }; + let x: [Fp; 10] = std::array::from_fn(|_| Fp::rand(&mut rng)); + c.bench_function("sparse_eval", |b: &mut Bencher| { + b.iter(|| { + let _ = black_box(&p1).eval(black_box(&x)); + }) + }); +} + +criterion_group!( + benches, + bench_sparse_add, + bench_sparse_mul, + bench_sparse_neg, + bench_sparse_sub, + bench_sparse_eval +); +criterion_main!(benches); diff --git a/mvpoly/benches/prime.rs b/mvpoly/benches/prime.rs new file mode 100644 index 0000000000..74ec185a34 --- /dev/null +++ b/mvpoly/benches/prime.rs @@ -0,0 +1,72 @@ +use ark_ff::UniformRand; +use criterion::{black_box, criterion_group, criterion_main, Bencher, Criterion}; +use mina_curves::pasta::Fp; +use mvpoly::{prime::Dense, MVPoly}; + +// Using 10 variables, with max degree 3 +// Should roughly cover the cases we care about +fn bench_dense_add(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Dense = unsafe { Dense::random(&mut rng, None) }; + let p2: Dense = unsafe { Dense::random(&mut rng, None) }; + c.bench_function("dense_add", |b: &mut Bencher| { + b.iter(|| { + let _ = black_box(&p1) + black_box(&p2); + }) + }); +} + +fn bench_dense_mul(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + c.bench_function("dense_mul", |b: &mut Bencher| { + b.iter(|| { + // IMPROVEME: implement mul on references and define the random + // values before the benchmark + let p1: Dense = unsafe { Dense::random(&mut rng, None) }; + let p2: Dense = unsafe { Dense::random(&mut rng, None) }; + let _ = black_box(p1) * black_box(p2); + }) + }); +} + +fn bench_dense_neg(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Dense = unsafe { Dense::random(&mut rng, None) }; + c.bench_function("dense_neg", |b: &mut Bencher| { + b.iter(|| { + let _ = -black_box(&p1); + }) + }); +} + +fn bench_dense_sub(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Dense = unsafe { Dense::random(&mut rng, None) }; + let p2: Dense = unsafe { Dense::random(&mut rng, None) }; + c.bench_function("dense_sub", |b: &mut Bencher| { + b.iter(|| { + let _ = black_box(&p1) - black_box(&p2); + }) + }); +} + +fn bench_dense_eval(c: &mut Criterion) { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Dense = unsafe { Dense::random(&mut rng, None) }; + let x: [Fp; 10] = std::array::from_fn(|_| Fp::rand(&mut rng)); + c.bench_function("dense_eval", |b: &mut Bencher| { + b.iter(|| { + let _ = black_box(&p1).eval(black_box(&x)); + }) + }); +} + +criterion_group!( + benches, + bench_dense_add, + bench_dense_mul, + bench_dense_neg, + bench_dense_sub, + bench_dense_eval +); +criterion_main!(benches); diff --git a/mvpoly/src/lib.rs b/mvpoly/src/lib.rs new file mode 100644 index 0000000000..6e8d35b7bf --- /dev/null +++ b/mvpoly/src/lib.rs @@ -0,0 +1,263 @@ +//! This module contains the definition of the `MVPoly` trait, which is used to +//! represent multi-variate polynomials. +//! +//! Different representations are provided in the sub-modules: +//! - `monomials`: a representation based on monomials +//! - `prime`: a representation based on a mapping from variables to prime +//! numbers. This representation is unmaintained for now. We leave it +//! 
diff --git a/mvpoly/src/lib.rs b/mvpoly/src/lib.rs
new file mode 100644
index 0000000000..6e8d35b7bf
--- /dev/null
+++ b/mvpoly/src/lib.rs
@@ -0,0 +1,263 @@
+//! This module contains the definition of the `MVPoly` trait, which is used to
+//! represent multi-variate polynomials.
+//!
+//! Different representations are provided in the sub-modules:
+//! - `monomials`: a representation based on monomials
+//! - `prime`: a representation based on a mapping from variables to prime
+//!   numbers. This representation is unmaintained for now. We leave it
+//!   for interested users.
+//!
+//! "Expressions", as defined in the [kimchi] crate, can be converted into a
+//! multi-variate polynomial using the `from_expr` method.
+
+use ark_ff::PrimeField;
+use kimchi::circuits::expr::{
+    ConstantExpr, ConstantExprInner, ConstantTerm, Expr, ExprInner, Operations, Variable,
+};
+use rand::RngCore;
+use std::collections::HashMap;
+
+pub mod monomials;
+pub mod pbt;
+pub mod prime;
+pub mod utils;
+
+/// Generic trait to represent a multi-variate polynomial
+pub trait MVPoly<F: PrimeField, const N: usize, const D: usize>:
+    // Addition
+    std::ops::Add<Self, Output = Self>
+    + for<'a> std::ops::Add<&'a Self, Output = Self>
+    // Mul
+    + std::ops::Mul<Self, Output = Self>
+    // Negation
+    + std::ops::Neg<Output = Self>
+    // Sub
+    + std::ops::Sub<Self, Output = Self>
+    + for<'a> std::ops::Sub<&'a Self, Output = Self>
+    + ark_ff::One
+    + ark_ff::Zero
+    + std::fmt::Debug
+    + Clone
+    // Comparison operators
+    + PartialEq
+    + Eq
+    // Useful conversions
+    + From<F>
+    + Sized
+{
+    /// Generate a random polynomial of maximum degree `max_degree`.
+    ///
+    /// If `None` is provided as the maximum degree, the polynomial will be
+    /// generated with a maximum degree of `D`.
+    ///
+    /// # Safety
+    ///
+    /// Marked as unsafe to warn the user to use it with caution and to not
+    /// necessarily rely on it for security/randomness in cryptographic
+    /// protocols. The user is responsible for providing their own secure
+    /// polynomial random generator, if needed.
+    ///
+    /// For now, the function is only used for testing.
+    unsafe fn random<RNG: RngCore>(rng: &mut RNG, max_degree: Option<usize>) -> Self;
+
+    fn double(&self) -> Self;
+
+    fn is_constant(&self) -> bool;
+
+    fn mul_by_scalar(&self, scalar: F) -> Self;
+
+    /// Returns the degree of the polynomial.
+    ///
+    /// The degree of the polynomial is the maximum degree of the monomials
+    /// that have a non-zero coefficient.
+    ///
+    /// # Safety
+    ///
+    /// The zero polynomial has degree 0, like the constant polynomials. We
+    /// use the `unsafe` keyword to warn the user about this specific case.
+    unsafe fn degree(&self) -> usize;
+
+    /// Evaluate the polynomial at the vector point `x`.
+    ///
+    /// This is a dummy implementation. A cache can be used for the monomials
+    /// to speed up the computation.
+    fn eval(&self, x: &[F; N]) -> F;
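+
+    // As an illustration of the two methods above, a minimal sketch (not
+    // normative; `Sparse` is the implementation from [crate::monomials] and
+    // the randomness helper is testing-only):
+    //
+    //     let mut rng = o1_utils::tests::make_test_rng(None);
+    //     let p: Sparse<Fp, 4, 2> = unsafe { Sparse::random(&mut rng, None) };
+    //     let x: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng));
+    //     let y = p.eval(&x);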
+
+    /// Build the polynomial `x_i` from the variable `i`.
+    /// The conversion into the type `usize` is unspecified by this trait. It
+    /// is left to the trait implementation.
+    /// For instance, in the case of [crate::prime], the output must be a prime
+    /// number, starting at `2`. [crate::utils::PrimeNumberGenerator] can be
+    /// used.
+    /// For [crate::monomials], the output must be the index of the variable,
+    /// starting from `0`.
+    ///
+    /// The parameter `offset_next_row` is an optional argument that is used to
+    /// support the case where the "next row" is used. In this case, the type
+    /// parameter `N` must include this offset (i.e. if 4 variables are in use,
+    /// `N` should be at least `8 = 2 * 4`).
+    fn from_variable<Column: Into<usize>>(
+        var: Variable<Column>,
+        offset_next_row: Option<usize>,
+    ) -> Self;
+
+    fn from_constant<ChallengeTerm: Clone>(
+        op: Operations<ConstantExprInner<F, ChallengeTerm>>,
+    ) -> Self {
+        use kimchi::circuits::expr::Operations::*;
+        match op {
+            Atom(op_const) => {
+                match op_const {
+                    ConstantExprInner::Challenge(_) => {
+                        unimplemented!("Challenges are not supposed to be used in this context for now")
+                    }
+                    ConstantExprInner::Constant(ConstantTerm::EndoCoefficient) => {
+                        unimplemented!(
+                            "The constant EndoCoefficient is not supposed to be used in this context"
+                        )
+                    }
+                    ConstantExprInner::Constant(ConstantTerm::Mds {
+                        row: _row,
+                        col: _col,
+                    }) => {
+                        unimplemented!("The constant Mds is not supposed to be used in this context")
+                    }
+                    ConstantExprInner::Constant(ConstantTerm::Literal(c)) => Self::from(c),
+                }
+            }
+            Add(c1, c2) => Self::from_constant(*c1) + Self::from_constant(*c2),
+            Sub(c1, c2) => Self::from_constant(*c1) - Self::from_constant(*c2),
+            Mul(c1, c2) => Self::from_constant(*c1) * Self::from_constant(*c2),
+            Square(c) => Self::from_constant(*c.clone()) * Self::from_constant(*c),
+            Double(c1) => Self::from_constant(*c1).double(),
+            Pow(c, e) => {
+                // FIXME: dummy implementation, in O(e) multiplications.
+                // We start from one so that `e = 0` yields the constant
+                // polynomial 1, and `e = 1` yields `p` itself.
+                let p = Self::from_constant(*c);
+                let mut result = Self::one();
+                for _ in 0..e {
+                    result = result * p.clone();
+                }
+                result
+            }
+            Cache(_c, _) => {
+                unimplemented!("The method is supposed to be used for generic multivariate expressions, not tied to a specific use case like Kimchi with this constructor")
+            }
+            IfFeature(_c, _t, _f) => {
+                unimplemented!("The method is supposed to be used for generic multivariate expressions, not tied to a specific use case like Kimchi with this constructor")
+            }
+        }
+    }
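+
+    // A sketch of the intended semantics of `from_constant`, folding the
+    // constant expression `2 + 3` into the constant polynomial `5`
+    // (`Operations`, `ConstantExprInner` and `ConstantTerm` are the kimchi
+    // constructors imported above; `Sparse` and `Fp` are only examples):
+    //
+    //     let two = Operations::Atom(ConstantExprInner::Constant(
+    //         ConstantTerm::Literal(Fp::from(2u64)),
+    //     ));
+    //     let three = Operations::Atom(ConstantExprInner::Constant(
+    //         ConstantTerm::Literal(Fp::from(3u64)),
+    //     ));
+    //     let e = Operations::Add(Box::new(two), Box::new(three));
+    //     let p: Sparse<Fp, 4, 2> = Sparse::from_constant(e);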
+
+    /// Build a value from an expression.
+    /// This method aims to stay retro-compatible with what we call
+    /// "the expression framework".
+    /// In the near future, the "expression framework" should also be moved
+    /// into this library.
+    ///
+    /// The mapping from variables to indices is left unspecified by this
+    /// trait and is left to the implementation. The conversion of a variable
+    /// into an index is done by the trait requirement `Into<usize>` on the
+    /// column type.
+    ///
+    /// The parameter `offset_next_row` is an optional argument that is used to
+    /// support the case where the "next row" is used. In this case, the type
+    /// parameter `N` must include this offset (i.e. if 4 variables are in use,
+    /// `N` should be at least `8 = 2 * 4`).
+    fn from_expr<Column: Into<usize>, ChallengeTerm: Clone>(
+        expr: Expr<ConstantExpr<F, ChallengeTerm>, Column>,
+        offset_next_row: Option<usize>,
+    ) -> Self {
+        use kimchi::circuits::expr::Operations::*;
+
+        match expr {
+            Atom(op_const) => {
+                match op_const {
+                    ExprInner::UnnormalizedLagrangeBasis(_) => {
+                        unimplemented!("Not used in this context")
+                    }
+                    ExprInner::VanishesOnZeroKnowledgeAndPreviousRows => {
+                        unimplemented!("Not used in this context")
+                    }
+                    ExprInner::Constant(c) => Self::from_constant(c),
+                    ExprInner::Cell(var) => {
+                        Self::from_variable::<Column>(var, offset_next_row)
+                    }
+                }
+            }
+            Add(e1, e2) => {
+                let p1 = Self::from_expr::<Column, ChallengeTerm>(*e1, offset_next_row);
+                let p2 = Self::from_expr::<Column, ChallengeTerm>(*e2, offset_next_row);
+                p1 + p2
+            }
+            Sub(e1, e2) => {
+                let p1 = Self::from_expr::<Column, ChallengeTerm>(*e1, offset_next_row);
+                let p2 = Self::from_expr::<Column, ChallengeTerm>(*e2, offset_next_row);
+                p1 - p2
+            }
+            Mul(e1, e2) => {
+                let p1 = Self::from_expr::<Column, ChallengeTerm>(*e1, offset_next_row);
+                let p2 = Self::from_expr::<Column, ChallengeTerm>(*e2, offset_next_row);
+                p1 * p2
+            }
+            Double(p) => {
+                let p = Self::from_expr::<Column, ChallengeTerm>(*p, offset_next_row);
+                p.double()
+            }
+            Square(p) => {
+                let p = Self::from_expr::<Column, ChallengeTerm>(*p, offset_next_row);
+                p.clone() * p.clone()
+            }
+            Pow(c, e) => {
+                // FIXME: dummy implementation, in O(e) multiplications.
+                // Starting from one so that `e = 0` yields the polynomial 1.
+                let p = Self::from_expr::<Column, ChallengeTerm>(*c, offset_next_row);
+                let mut result = Self::one();
+                for _ in 0..e {
+                    result = result * p.clone();
+                }
+                result
+            }
+            Cache(_c, _) => {
+                unimplemented!("The method is supposed to be used for generic multivariate expressions, not tied to a specific use case like Kimchi with this constructor")
+            }
+            IfFeature(_c, _t, _f) => {
+                unimplemented!("The method is supposed to be used for generic multivariate expressions, not tied to a specific use case like Kimchi with this constructor")
+            }
+        }
+    }
+
+    /// Returns true if the polynomial is homogeneous (of degree `D`).
+    /// As a reminder, a polynomial is homogeneous if all its monomials have
+    /// the same degree.
+    fn is_homogeneous(&self) -> bool;
+
+    /// Evaluate the polynomial at the vector point `x` and the extra variable
+    /// `u` using its homogeneous form of degree D.
+    fn homogeneous_eval(&self, x: &[F; N], u: F) -> F;
+
+    /// Add the monomial `coeff * x_1^{e_1} * ... * x_N^{e_N}` to the
+    /// polynomial, where `e_i` are the values given by the array `exponents`.
+    ///
+    /// For instance, to add the monomial `3 * x_1^2 * x_2^3` to the polynomial,
+    /// one would call `add_monomial([2, 3], 3)`.
+    fn add_monomial(&mut self, exponents: [usize; N], coeff: F);
+
+    /// Compute the cross-terms as described in [Behind Nova: cross-terms
+    /// computation for high degree
+    /// gates](https://hackmd.io/@dannywillems/Syo5MBq90)
+    ///
+    /// The polynomial need not be homogeneous. For this reason, the
+    /// values `u1` and `u2` represent the extra variable that is used to make
+    /// the polynomial homogeneous.
+    ///
+    /// The homogeneous degree is supposed to be the one defined by the type of
+    /// the polynomial, i.e. `D`.
+    ///
+    /// The output is a map of `D - 1` values that represent the cross-terms
+    /// for each power of `r`.
+    fn compute_cross_terms(
+        &self,
+        eval1: &[F; N],
+        eval2: &[F; N],
+        u1: F,
+        u2: F,
+    ) -> HashMap<usize, F>;
+
+    /// Modify the coefficient of the monomial defined by `exponents` to the
+    /// new value `coeff`.
+    fn modify_monomial(&mut self, exponents: [usize; N], coeff: F);
+
+    /// Return true if the multi-variate polynomial is multilinear, i.e. if each
+    /// variable in each monomial is of maximum degree 1.
+ fn is_multilinear(&self) -> bool; +} diff --git a/mvpoly/src/monomials.rs b/mvpoly/src/monomials.rs new file mode 100644 index 0000000000..8fb495ef02 --- /dev/null +++ b/mvpoly/src/monomials.rs @@ -0,0 +1,513 @@ +use ark_ff::{One, PrimeField, Zero}; +use kimchi::circuits::{expr::Variable, gate::CurrOrNext}; +use num_integer::binomial; +use rand::RngCore; +use std::{ + collections::HashMap, + fmt::Debug, + ops::{Add, Mul, Neg, Sub}, +}; + +use crate::{ + prime, + utils::{compute_indices_nested_loop, naive_prime_factors, PrimeNumberGenerator}, + MVPoly, +}; + +/// Represents a multivariate polynomial in `N` variables with coefficients in +/// `F`. The polynomial is represented as a sparse polynomial, where each +/// monomial is represented by a vector of `N` exponents. +// We could use u8 instead of usize for the exponents +// FIXME: the maximum degree D is encoded in the type to match the type +// prime::Dense +#[derive(Clone)] +pub struct Sparse { + pub monomials: HashMap<[usize; N], F>, +} + +impl Add for Sparse { + type Output = Self; + + fn add(self, other: Self) -> Self { + &self + &other + } +} + +impl Add<&Sparse> for Sparse { + type Output = Sparse; + + fn add(self, other: &Sparse) -> Self::Output { + &self + other + } +} + +impl Add> for &Sparse { + type Output = Sparse; + + fn add(self, other: Sparse) -> Self::Output { + self + &other + } +} +impl Add<&Sparse> for &Sparse { + type Output = Sparse; + + fn add(self, other: &Sparse) -> Self::Output { + let mut monomials = self.monomials.clone(); + for (exponents, coeff) in &other.monomials { + monomials + .entry(*exponents) + .and_modify(|c| *c += *coeff) + .or_insert(*coeff); + } + // Remove monomials with zero coefficients + let monomials: HashMap<[usize; N], F> = monomials + .into_iter() + .filter(|(_, coeff)| !coeff.is_zero()) + .collect(); + // Handle the case where the result is zero because we want a unique + // representation + if monomials.is_empty() { + Sparse::::zero() + } else { + Sparse:: { monomials } + } + } +} + +impl Debug for Sparse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut monomials: Vec = self + .monomials + .iter() + .map(|(exponents, coeff)| { + let mut monomial = format!("{}", coeff); + for (i, exp) in exponents.iter().enumerate() { + if *exp == 0 { + continue; + } else if *exp == 1 { + monomial.push_str(&format!("x_{}", i)); + } else { + monomial.push_str(&format!("x_{}^{}", i, exp)); + } + } + monomial + }) + .collect(); + monomials.sort(); + write!(f, "{}", monomials.join(" + ")) + } +} + +impl Mul for Sparse { + type Output = Self; + + fn mul(self, other: Self) -> Self { + let mut monomials = HashMap::new(); + self.monomials.iter().for_each(|(exponents1, coeff1)| { + other + .monomials + .clone() + .iter() + .for_each(|(exponents2, coeff2)| { + let mut exponents = [0; N]; + for i in 0..N { + exponents[i] = exponents1[i] + exponents2[i]; + } + monomials + .entry(exponents) + .and_modify(|c| *c += *coeff1 * *coeff2) + .or_insert(*coeff1 * *coeff2); + }) + }); + // Remove monomials with zero coefficients + let monomials: HashMap<[usize; N], F> = monomials + .into_iter() + .filter(|(_, coeff)| !coeff.is_zero()) + .collect(); + if monomials.is_empty() { + Self::zero() + } else { + Self { monomials } + } + } +} + +impl Neg for Sparse { + type Output = Sparse; + + fn neg(self) -> Self::Output { + -&self + } +} + +impl Neg for &Sparse { + type Output = Sparse; + + fn neg(self) -> Self::Output { + let monomials: HashMap<[usize; N], F> = self + .monomials + .iter() + 
.map(|(exponents, coeff)| (*exponents, -*coeff)) + .collect(); + Sparse:: { monomials } + } +} + +impl Sub for Sparse { + type Output = Sparse; + + fn sub(self, other: Sparse) -> Self::Output { + self + (-other) + } +} + +impl Sub<&Sparse> for Sparse { + type Output = Sparse; + + fn sub(self, other: &Sparse) -> Self::Output { + self + (-other) + } +} + +impl Sub> for &Sparse { + type Output = Sparse; + + fn sub(self, other: Sparse) -> Self::Output { + self + (-other) + } +} +impl Sub<&Sparse> for &Sparse { + type Output = Sparse; + + fn sub(self, other: &Sparse) -> Self::Output { + self + (-other) + } +} + +/// Equality is defined as equality of the monomials. +impl PartialEq for Sparse { + fn eq(&self, other: &Self) -> bool { + self.monomials == other.monomials + } +} + +impl Eq for Sparse {} + +impl One for Sparse { + fn one() -> Self { + let mut monomials = HashMap::new(); + monomials.insert([0; N], F::one()); + Self { monomials } + } +} + +impl Zero for Sparse { + fn is_zero(&self) -> bool { + self.monomials.len() == 1 + && self.monomials.contains_key(&[0; N]) + && self.monomials[&[0; N]].is_zero() + } + + fn zero() -> Self { + let mut monomials = HashMap::new(); + monomials.insert([0; N], F::zero()); + Self { monomials } + } +} + +impl MVPoly for Sparse { + /// Returns the degree of the polynomial. + /// + /// The degree of the polynomial is the maximum degree of the monomials + /// that have a non-zero coefficient. + /// + /// # Safety + /// + /// The zero polynomial as a degree equals to 0, as the degree of the + /// constant polynomials. We do use the `unsafe` keyword to warn the user + /// for this specific case. + unsafe fn degree(&self) -> usize { + self.monomials + .keys() + .map(|exponents| exponents.iter().sum()) + .max() + .unwrap_or(0) + } + + /// Evaluate the polynomial at the vector point `x`. + /// + /// This is a dummy implementation. A cache can be used for the monomials to + /// speed up the computation. + fn eval(&self, x: &[F; N]) -> F { + self.monomials + .iter() + .map(|(exponents, coeff)| { + let mut term = F::one(); + for (exp, point) in exponents.iter().zip(x.iter()) { + term *= point.pow([*exp as u64]); + } + term * coeff + }) + .sum() + } + + fn is_constant(&self) -> bool { + self.monomials.len() == 1 && self.monomials.contains_key(&[0; N]) + } + + fn double(&self) -> Self { + let monomials: HashMap<[usize; N], F> = self + .monomials + .iter() + .map(|(exponents, coeff)| (*exponents, coeff.double())) + .collect(); + Self { monomials } + } + + fn mul_by_scalar(&self, scalar: F) -> Self { + if scalar.is_zero() { + Self::zero() + } else { + let monomials: HashMap<[usize; N], F> = self + .monomials + .iter() + .map(|(exponents, coeff)| (*exponents, *coeff * scalar)) + .collect(); + Self { monomials } + } + } + + /// Generate a random polynomial of maximum degree `max_degree`. + /// + /// If `None` is provided as the maximum degree, the polynomial will be + /// generated with a maximum degree of `D`. + /// + /// # Safety + /// + /// Marked as unsafe to warn the user to use it with caution and to not + /// necessarily rely on it for security/randomness in cryptographic + /// protocols. The user is responsible for providing its own secure + /// polynomial random generator, if needed. + /// + /// For now, the function is only used for testing. + unsafe fn random(rng: &mut RNG, max_degree: Option) -> Self { + // IMPROVEME: using prime::Dense::random to ease the implementaiton. 
+ // Feel free to change + prime::Dense::random(rng, max_degree).into() + } + + fn from_variable>( + var: Variable, + offset_next_row: Option, + ) -> Self { + let Variable { col, row } = var; + // Manage offset + if row == CurrOrNext::Next { + assert!( + offset_next_row.is_some(), + "The offset must be provided for the next row" + ); + } + let offset = if row == CurrOrNext::Curr { + 0 + } else { + offset_next_row.unwrap() + }; + + // Build the corresponding monomial + let var_usize: usize = col.into(); + let idx = offset + var_usize; + + let mut monomials = HashMap::new(); + let exponents: [usize; N] = std::array::from_fn(|i| if i == idx { 1 } else { 0 }); + monomials.insert(exponents, F::one()); + Self { monomials } + } + + fn is_homogeneous(&self) -> bool { + self.monomials + .iter() + .all(|(exponents, _)| exponents.iter().sum::() == D) + } + + // IMPROVEME: powers can be cached + fn homogeneous_eval(&self, x: &[F; N], u: F) -> F { + self.monomials + .iter() + .map(|(exponents, coeff)| { + let mut term = F::one(); + for (exp, point) in exponents.iter().zip(x.iter()) { + term *= point.pow([*exp as u64]); + } + term *= u.pow([D as u64 - exponents.iter().sum::() as u64]); + term * coeff + }) + .sum() + } + + fn add_monomial(&mut self, exponents: [usize; N], coeff: F) { + self.monomials + .entry(exponents) + .and_modify(|c| *c += coeff) + .or_insert(coeff); + } + + fn compute_cross_terms( + &self, + eval1: &[F; N], + eval2: &[F; N], + u1: F, + u2: F, + ) -> HashMap { + assert!( + D >= 2, + "The degree of the polynomial must be greater than 2" + ); + let mut cross_terms_by_powers_of_r: HashMap = HashMap::new(); + // We iterate over each monomial with their respective coefficient + // i.e. we do have something like coeff * x_1^d_1 * x_2^d_2 * ... * x_N^d_N + self.monomials.iter().for_each(|(exponents, coeff)| { + // "Exponents" contains all powers, even the ones that are 0. We must + // get rid of them and keep the index to fetch the correct + // evaluation later + let non_zero_exponents_with_index: Vec<(usize, &usize)> = exponents + .iter() + .enumerate() + .filter(|(_, &d)| d != 0) + .collect(); + // coeff = 0 should not happen as we suppose we have a sparse polynomial + // Therefore, skipping a check + let non_zero_exponents: Vec = non_zero_exponents_with_index + .iter() + .map(|(_, d)| *d) + .copied() + .collect::>(); + let monomial_degree = non_zero_exponents.iter().sum::(); + let u_degree: usize = D - monomial_degree; + // Will be used to compute the nested sums + // It returns all the indices i_1, ..., i_k for the sums: + // Σ_{i_1 = 0}^{n_1} Σ_{i_2 = 0}^{n_2} ... Σ_{i_k = 0}^{n_k} + let indices = + compute_indices_nested_loop(non_zero_exponents.iter().map(|d| *d + 1).collect()); + for i in 0..=u_degree { + // Add the binomial from the homogeneisation + // i.e (u_degree choose i) + let u_binomial_term = binomial(u_degree, i); + // Now, we iterate over all the indices i_1, ..., i_k, i.e. we + // do over the whole sum, and we populate the map depending on + // the power of r + indices.iter().for_each(|indices| { + let sum_indices = indices.iter().sum::() + i; + // power of r is Σ (n_k - i_k) + let power_r: usize = D - sum_indices; + + // If the sum of the indices is 0 or D, we skip the + // computation as the contribution would go in the + // evaluation of the polynomial at each evaluation + // vectors eval1 and eval2 + if sum_indices == 0 || sum_indices == D { + return; + } + // Compute + // (n_1 choose i_1) * (n_2 choose i_2) * ... 
* (n_k choose i_k) + let binomial_term = indices + .iter() + .zip(non_zero_exponents.iter()) + .fold(u_binomial_term, |acc, (i, &d)| acc * binomial(d, *i)); + let binomial_term = F::from(binomial_term as u64); + // Compute the product x_k^i_k + // We ignore the power as it comes into account for the + // right evaluation. + // NB: we could merge both loops, but we keep them separate + // for readability + let eval_left = indices + .iter() + .zip(non_zero_exponents_with_index.iter()) + .fold(F::one(), |acc, (i, (idx, _d))| { + acc * eval1[*idx].pow([*i as u64]) + }); + // Compute the product x'_k^(n_k - i_k) + let eval_right = indices + .iter() + .zip(non_zero_exponents_with_index.iter()) + .fold(F::one(), |acc, (i, (idx, d))| { + acc * eval2[*idx].pow([(*d - *i) as u64]) + }); + // u1^i * u2^(u_degree - i) + let u = u1.pow([i as u64]) * u2.pow([(u_degree - i) as u64]); + let res = binomial_term * eval_left * eval_right * u; + let res = *coeff * res; + cross_terms_by_powers_of_r + .entry(power_r) + .and_modify(|e| *e += res) + .or_insert(res); + }) + } + }); + cross_terms_by_powers_of_r + } + + fn modify_monomial(&mut self, exponents: [usize; N], coeff: F) { + self.monomials + .entry(exponents) + .and_modify(|c| *c = coeff) + .or_insert(coeff); + } + + fn is_multilinear(&self) -> bool { + self.monomials + .iter() + .all(|(exponents, _)| exponents.iter().all(|&d| d <= 1)) + } +} + +impl From> + for Sparse +{ + fn from(dense: prime::Dense) -> Self { + let mut prime_gen = PrimeNumberGenerator::new(); + let primes = prime_gen.get_first_nth_primes(N); + let mut monomials = HashMap::new(); + let normalized_indices = prime::Dense::::compute_normalized_indices(); + dense.iter().enumerate().for_each(|(i, coeff)| { + if *coeff != F::zero() { + let mut exponents = [0; N]; + let inv_idx = normalized_indices[i]; + let prime_decomposition_of_index = naive_prime_factors(inv_idx, &mut prime_gen); + prime_decomposition_of_index + .into_iter() + .for_each(|(prime, exp)| { + let inv_prime_idx = primes.iter().position(|&p| p == prime).unwrap(); + exponents[inv_prime_idx] = exp; + }); + monomials.insert(exponents, *coeff); + } + }); + Self { monomials } + } +} + +impl From for Sparse { + fn from(value: F) -> Self { + let mut result = Self::zero(); + result.modify_monomial([0; N], value); + result + } +} + +impl From> + for Result, String> +{ + fn from(poly: Sparse) -> Result, String> { + if M < N { + return Err("The number of variables must be greater than N".to_string()); + } + let mut monomials = HashMap::new(); + poly.monomials.iter().for_each(|(exponents, coeff)| { + let mut new_exponents = [0; M]; + new_exponents[0..N].copy_from_slice(&exponents[0..N]); + monomials.insert(new_exponents, *coeff); + }); + Ok(Sparse { monomials }) + } +} diff --git a/mvpoly/src/pbt.rs b/mvpoly/src/pbt.rs new file mode 100644 index 0000000000..b517f0107f --- /dev/null +++ b/mvpoly/src/pbt.rs @@ -0,0 +1,595 @@ +//! This module contains a list of property tests for the `MVPoly` trait. +//! +//! Any type that implements the `MVPoly` trait should pass these tests. +//! +//! For instance, one can call the `test_mul_by_one` as follows: +//! +//! ```rust +//! use mvpoly::MVPoly; +//! use mvpoly::prime::Dense; +//! use mina_curves::pasta::Fp; +//! +//! #[test] +//! fn test_mul_by_one() { +//! mvpoly::pbt::test_mul_by_one::>(); +//! mvpoly::pbt::test_mul_by_one::>(); +//! } +//! 
``` + +use crate::MVPoly; +use ark_ff::PrimeField; +use rand::{seq::SliceRandom, Rng}; +use std::ops::Neg; + +pub fn test_mul_by_one>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let one = T::one(); + let p2 = p1.clone() * one.clone(); + assert_eq!(p1.clone(), p2); + let p3 = one * p1.clone(); + assert_eq!(p1.clone(), p3); +} + +pub fn test_mul_by_zero>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let zero = T::zero(); + let p2 = p1.clone() * zero.clone(); + assert_eq!(zero, p2); + let p3 = zero.clone() * p1.clone(); + assert_eq!(zero.clone(), p3); +} + +pub fn test_add_zero>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let zero = T::zero(); + let p2 = p1.clone() + zero.clone(); + assert_eq!(p1.clone(), p2); + let p3 = zero.clone() + p1.clone(); + assert_eq!(p1.clone(), p3); +} + +pub fn test_double_is_add_twice< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let p2 = p1.clone() + p1.clone(); + let p3 = p1.clone().double(); + assert_eq!(p2, p3); +} + +pub fn test_sub_zero>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let zero = T::zero(); + let p2 = p1.clone() - zero.clone(); + assert_eq!(p1.clone(), p2); +} + +pub fn test_neg>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let p2 = -p1.clone(); + // Test that p1 + (-p1) = 0 + let sum = p1.clone() + p2.clone(); + assert_eq!(sum, T::zero()); + // Test that -(-p1) = p1 + let p3 = -p2; + assert_eq!(p1, p3); + // Test negation of zero + let zero = T::zero(); + let neg_zero = -zero.clone(); + assert_eq!(zero, neg_zero); +} + +pub fn test_eval_pbt_add>() +where + for<'a> &'a T: std::ops::Add + std::ops::Add<&'a T, Output = T>, +{ + let mut rng = o1_utils::tests::make_test_rng(None); + let random_evaluation: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let p1 = unsafe { T::random(&mut rng, None) }; + let p2 = unsafe { T::random(&mut rng, None) }; + let eval_p1 = p1.eval(&random_evaluation); + let eval_p2 = p2.eval(&random_evaluation); + { + let p3 = p1.clone() + p2.clone(); + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 + eval_p2); + } + { + let p3 = p1.clone() + &p2; + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 + eval_p2); + } + { + let p3 = &p1 + p2.clone(); + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 + eval_p2); + } + { + let p3 = &p1 + &p2; + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 + eval_p2); + } +} + +pub fn test_eval_pbt_sub>() +where + for<'a> &'a T: std::ops::Sub + std::ops::Sub<&'a T, Output = T>, +{ + let mut rng = o1_utils::tests::make_test_rng(None); + let random_evaluation: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let p1 = unsafe { T::random(&mut rng, None) }; + let p2 = unsafe { T::random(&mut rng, None) }; + let eval_p1 = p1.eval(&random_evaluation); + let eval_p2 = p2.eval(&random_evaluation); + { + let p3 = p1.clone() - p2.clone(); + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 - eval_p2); + } + { + let p3 = p1.clone() - &p2; + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 - eval_p2); + } + 
{ + let p3 = &p1 - p2.clone(); + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 - eval_p2); + } + { + let p3 = &p1 - &p2; + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 - eval_p2); + } +} + +pub fn test_eval_pbt_mul_by_scalar< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let random_evaluation: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let p1 = unsafe { T::random(&mut rng, None) }; + let c = F::rand(&mut rng); + let p2 = p1.clone() * T::from(c); + let eval_p1 = p1.eval(&random_evaluation); + let eval_p2 = p2.eval(&random_evaluation); + assert_eq!(eval_p2, eval_p1 * c); +} + +pub fn test_eval_pbt_neg>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let random_evaluation: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let p1 = unsafe { T::random(&mut rng, None) }; + let p2 = -p1.clone(); + let eval_p1 = p1.eval(&random_evaluation); + let eval_p2 = p2.eval(&random_evaluation); + assert_eq!(eval_p2, -eval_p1); +} + +pub fn test_neg_ref>() +where + for<'a> &'a T: Neg, +{ + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let p2 = -&p1; + // Test that p1 + (-&p1) = 0 + let sum = p1.clone() + p2.clone(); + assert_eq!(sum, T::zero()); + // Test that -(-&p1) = p1 + let p3 = -&p2; + assert_eq!(p1, p3); +} + +pub fn test_mul_by_scalar>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let mut p2 = T::zero(); + let c = F::rand(&mut rng); + p2.modify_monomial([0; N], c); + assert_eq!(p2 * p1.clone(), p1.clone().mul_by_scalar(c)); +} + +pub fn test_mul_by_scalar_with_zero< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let c = F::zero(); + assert_eq!(p1.mul_by_scalar(c), T::zero()); +} + +pub fn test_mul_by_scalar_with_one< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + let c = F::one(); + assert_eq!(p1.mul_by_scalar(c), p1); +} + +pub fn test_evaluation_zero_polynomial< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let random_evaluation: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let zero = T::zero(); + let evaluation = zero.eval(&random_evaluation); + assert_eq!(evaluation, F::zero()); +} + +pub fn test_evaluation_constant_polynomial< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let random_evaluation: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let cst = F::rand(&mut rng); + let poly = T::from(cst); + let evaluation = poly.eval(&random_evaluation); + assert_eq!(evaluation, cst); +} + +pub fn test_degree_constant>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let c = F::rand(&mut rng); + let p = T::from(c); + let degree = unsafe { p.degree() }; + assert_eq!(degree, 0); + let p = T::zero(); + let degree = unsafe { p.degree() }; + assert_eq!(degree, 0); +} + +pub fn test_degree_random_degree< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let max_degree: usize = rng.gen_range(1..D); + let 
p = unsafe { T::random(&mut rng, Some(max_degree)) }; + let degree = unsafe { p.degree() }; + assert!(degree <= max_degree); +} + +pub fn test_mvpoly_add_degree_pbt< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let degree = rng.gen_range(1..D); + let p1 = unsafe { T::random(&mut rng, Some(degree)) }; + let p2 = unsafe { T::random(&mut rng, Some(degree)) }; + let p3 = p1.clone() + p2.clone(); + let degree_p1 = unsafe { p1.degree() }; + let degree_p2 = unsafe { p2.degree() }; + let degree_p3 = unsafe { p3.degree() }; + assert!(degree_p3 <= std::cmp::max(degree_p1, degree_p2)); +} + +pub fn test_mvpoly_sub_degree_pbt< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let degree = rng.gen_range(1..D); + let p1 = unsafe { T::random(&mut rng, Some(degree)) }; + let p2 = unsafe { T::random(&mut rng, Some(degree)) }; + let p3 = p1.clone() - p2.clone(); + let degree_p1 = unsafe { p1.degree() }; + let degree_p2 = unsafe { p2.degree() }; + let degree_p3 = unsafe { p3.degree() }; + assert!(degree_p3 <= std::cmp::max(degree_p1, degree_p2)); +} + +pub fn test_mvpoly_neg_degree_pbt< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let degree = rng.gen_range(1..D); + let p1 = unsafe { T::random(&mut rng, Some(degree)) }; + let p2 = -p1.clone(); + let degree_p1 = unsafe { p1.degree() }; + let degree_p2 = unsafe { p2.degree() }; + assert_eq!(degree_p1, degree_p2); +} + +pub fn test_mvpoly_mul_by_scalar_degree_pbt< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let max_degree = rng.gen_range(1..D / 2); + let p1 = unsafe { T::random(&mut rng, Some(max_degree)) }; + let c = F::rand(&mut rng); + let p2 = p1.clone() * T::from(c); + let degree_p1 = unsafe { p1.degree() }; + let degree_p2 = unsafe { p2.degree() }; + assert!(degree_p2 <= degree_p1); +} + +pub fn test_mvpoly_mul_degree_pbt< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let max_degree = rng.gen_range(1..D / 2); + let p1 = unsafe { T::random(&mut rng, Some(max_degree)) }; + let p2 = unsafe { T::random(&mut rng, Some(max_degree)) }; + let p3 = p1.clone() * p2.clone(); + let degree_p1 = unsafe { p1.degree() }; + let degree_p2 = unsafe { p2.degree() }; + let degree_p3 = unsafe { p3.degree() }; + assert!(degree_p3 <= degree_p1 + degree_p2); +} + +pub fn test_mvpoly_mul_eval_pbt< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let max_degree = rng.gen_range(1..D / 2); + let p1 = unsafe { T::random(&mut rng, Some(max_degree)) }; + let p2 = unsafe { T::random(&mut rng, Some(max_degree)) }; + let p3 = p1.clone() * p2.clone(); + let random_evaluation: [F; N] = std::array::from_fn(|_| F::rand(&mut rng)); + let eval_p1 = p1.eval(&random_evaluation); + let eval_p2 = p2.eval(&random_evaluation); + let eval_p3 = p3.eval(&random_evaluation); + assert_eq!(eval_p3, eval_p1 * eval_p2); +} + +pub fn test_mvpoly_mul_pbt>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let max_degree = rng.gen_range(1..D / 2); + let p1 = unsafe { T::random(&mut rng, Some(max_degree)) }; + let p2 = unsafe { T::random(&mut rng, Some(max_degree)) }; + assert_eq!(p1.clone() * p2.clone(), p2.clone() * 
p1.clone()); +} + +pub fn test_can_be_printed_with_debug< + F: PrimeField, + const N: usize, + const D: usize, + T: MVPoly, +>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { T::random(&mut rng, None) }; + println!("{:?}", p1); +} + +pub fn test_is_zero>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = T::zero(); + assert!(p1.is_zero()); + let p2 = unsafe { T::random(&mut rng, None) }; + assert!(!p2.is_zero()); +} + +pub fn test_homogeneous_eval>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let random_eval = std::array::from_fn(|_| F::rand(&mut rng)); + let u = F::rand(&mut rng); + + // Homogeneous form is u^2 + let p1 = T::one(); + let homogenous_eval = p1.homogeneous_eval(&random_eval, u); + assert_eq!(homogenous_eval, u * u); + + let mut p2 = T::zero(); + let mut exp1 = [0; N]; + exp1[0] = 1; + p2.add_monomial(exp1, F::one()); + let homogenous_eval = p2.homogeneous_eval(&random_eval, u); + assert_eq!(homogenous_eval, random_eval[0] * u); + + let mut p3 = T::zero(); + let mut exp2 = [0; N]; + exp2[1] = 1; + p3.add_monomial(exp2, F::one()); + let homogenous_eval = p3.homogeneous_eval(&random_eval, u); + assert_eq!(homogenous_eval, random_eval[1] * u); + + let mut p4 = T::zero(); + let mut exp3 = [0; N]; + exp3[0] = 1; + exp3[1] = 1; + p4.add_monomial(exp3, F::one()); + let homogenous_eval = p4.homogeneous_eval(&random_eval, u); + assert_eq!(homogenous_eval, random_eval[0] * random_eval[1]); + + let mut p5 = T::zero(); + let mut exp4 = [0; N]; + exp4[0] = 2; + p5.add_monomial(exp4, F::one()); + let homogenous_eval = p5.homogeneous_eval(&random_eval, u); + assert_eq!(homogenous_eval, random_eval[0] * random_eval[0]); + + let mut p6 = T::zero(); + let mut exp5a = [0; N]; + let mut exp5b = [0; N]; + exp5a[1] = 2; + exp5b[0] = 2; + p6.add_monomial(exp5a, F::one()); + p6.add_monomial(exp5b, F::one()); + let homogenous_eval = p6.homogeneous_eval(&random_eval, u); + assert_eq!( + homogenous_eval, + random_eval[1] * random_eval[1] + random_eval[0] * random_eval[0] + ); + + let mut p7 = T::zero(); + let mut exp6a = [0; N]; + let mut exp6b = [0; N]; + let mut exp6c = [0; N]; + exp6a[1] = 2; + exp6b[0] = 2; + exp6c[0] = 1; + p7.add_monomial(exp6a, F::one()); + p7.add_monomial(exp6b, F::one()); + p7.add_monomial(exp6c, F::one()); + p7.add_monomial([0; N], F::from(42u32)); + let homogenous_eval = p7.homogeneous_eval(&random_eval, u); + assert_eq!( + homogenous_eval, + random_eval[1] * random_eval[1] + + random_eval[0] * random_eval[0] + + u * random_eval[0] + + u * u * F::from(42u32) + ); +} + +pub fn test_add_monomial>() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Adding constant monomial one to zero + let mut p1 = T::zero(); + p1.add_monomial([0; N], F::one()); + assert_eq!(p1, T::one()); + + // Adding random constant monomial one to zero + let mut p2 = T::zero(); + let random_c = F::rand(&mut rng); + p2.add_monomial([0; N], random_c); + assert_eq!(p2, T::from(random_c)); + + let mut p3 = T::zero(); + let random_c1 = F::rand(&mut rng); + let random_c2 = F::rand(&mut rng); + // X1 + X2 + let mut exp1 = [0; N]; + let mut exp2 = [0; N]; + exp1[0] = 1; + exp2[1] = 1; + p3.add_monomial(exp1, random_c1); + p3.add_monomial(exp2, random_c2); + + let random_eval = std::array::from_fn(|_| F::rand(&mut rng)); + let eval_p3 = p3.eval(&random_eval); + let exp_eval_p3 = random_c1 * random_eval[0] + random_c2 * random_eval[1]; + assert_eq!(eval_p3, exp_eval_p3); + + let mut p4 = T::zero(); + let random_c1 = F::rand(&mut rng); + let 
random_c2 = F::rand(&mut rng); + // X1^2 + X2^2 + let mut exp1 = [0; N]; + let mut exp2 = [0; N]; + exp1[0] = 2; + exp2[1] = 2; + p4.add_monomial(exp1, random_c1); + p4.add_monomial(exp2, random_c2); + let eval_p4 = p4.eval(&random_eval); + let exp_eval_p4 = + random_c1 * random_eval[0] * random_eval[0] + random_c2 * random_eval[1] * random_eval[1]; + assert_eq!(eval_p4, exp_eval_p4); +} + +pub fn test_is_multilinear>() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Test with zero polynomial + let p1 = T::zero(); + assert!(p1.is_multilinear()); + + // Test with a constant polynomial + let c = F::rand(&mut rng); + let p2 = T::from(c); + assert!(p2.is_multilinear()); + + // Test with a polynomial with one variable having a linear monomial + { + let mut p = T::zero(); + let c = F::rand(&mut rng); + let idx = rng.gen_range(0..N); + let monomials_exponents = std::array::from_fn(|i| if i == idx { 1 } else { 0 }); + p.add_monomial(monomials_exponents, c); + assert!(p.is_multilinear()); + } + + // Test with a multilinear polynomial with random variables + { + let mut p = T::zero(); + let c = F::rand(&mut rng); + let nb_var = rng.gen_range(0..D); + let mut monomials_exponents: [usize; N] = + std::array::from_fn(|i| if i <= nb_var { 1 } else { 0 }); + monomials_exponents.shuffle(&mut rng); + p.add_monomial(monomials_exponents, c); + assert!(p.is_multilinear()); + } + + // Test with a random polynomial (very unlikely to be multilinear) + { + let p = unsafe { T::random(&mut rng, None) }; + assert!(!p.is_multilinear()); + } +} + +pub fn test_is_constant>() { + let mut rng = o1_utils::tests::make_test_rng(None); + let c = F::rand(&mut rng); + let p = T::from(c); + assert!(p.is_constant()); + + let p = T::zero(); + assert!(p.is_constant()); + + let p = { + let mut res = T::zero(); + let monomial: [usize; N] = std::array::from_fn(|i| if i == 0 { 1 } else { 0 }); + res.add_monomial(monomial, F::one()); + res + }; + assert!(!p.is_constant()); + + let p = { + let mut res = T::zero(); + let monomial: [usize; N] = std::array::from_fn(|i| if i == 1 { 1 } else { 0 }); + res.add_monomial(monomial, F::one()); + res + }; + assert!(!p.is_constant()); + + // This might be flaky + let p = unsafe { T::random(&mut rng, None) }; + assert!(!p.is_constant()); +} diff --git a/mvpoly/src/prime.rs b/mvpoly/src/prime.rs new file mode 100644 index 0000000000..498cc1c7ee --- /dev/null +++ b/mvpoly/src/prime.rs @@ -0,0 +1,747 @@ +//! Multivariate polynomial dense representation using prime numbers +//! +//! First, we start by attributing a different prime number for each variable. +//! For instance, for `F^{<=2}[X_{1}, X_{2}]`, we assign `X_{1}` to `2` +//! and $X_{2}$ to $3$. +//! From there, we note `X_{1} X_{2}` as the value `6`, `X_{1}^2` as `4`, `X_{2}^2` +//! as 9. The constant is `1`. +//! +//! From there, we represent our polynomial coefficients in a sparse list. Some +//! cells, noted `NA`, won't be used for certain vector spaces. +//! +//! For instance, `X_{1} + X_{2}` will be represented as: +//! ```text +//! [0, 1, 1, 0, 0, 0, 0, 0, 0] +//! | | | | | | | | | +//! 1 2 3 4 5 6 7 8 9 +//! | | | | | | | | | +//! cst X1 X2 X1^2 NA X1*X2 NA NA X2^2 +//! ``` +//! +//! and the polynomial `42 X_{1} + 3 X_{1} X_{2} + 14 X_{2}^2` will be represented +//! as +//! +//! ```text +//! [0, 42, 1, 0, 0, 3, 0, 0, 14] +//! | | | | | | | | | +//! 1 2 3 4 5 6 7 8 9 +//! | | | | | | | | | +//! cst X1 X2 X1^2 NA X1*X2 NA NA X2^2 +//! ``` +//! +//! 
Adding two polynomials in this basis is pretty straightforward: we simply
+//! add the coefficients of the two lists.
+//!
+//! Multiplication is not more complicated.
+//! To compute the coefficient at index `i` of `P_{1} * P_{2}`, we sum the
+//! products of coefficients over all the two-factor decompositions of `i`.
+//!
+//! For instance, if we take `P_{1}(X_{1}, X_{2}) = 2 X_{1} + X_{2}` and
+//! `P_{2}(X_{1}, X_{2}) = X_{2} + 3`, the expected product is
+//! `P_{3}(X_{1}, X_{2}) = (2 X_{1} + X_{2}) * (X_{2} + 3) = 2 X_{1} X_{2} + 6
+//! X_{1} + 3 X_{2} + X_{2}^2`
+//!
+//! Given in the representation above, we have:
+//!
+//! ```text
+//! For P_{1}:
+//!
+//! [0, 2, 1, 0, 0, 0, 0, 0, 0]
+//!  |  |  |  |  |  |  |  |  |
+//!  1  2  3  4  5  6  7  8  9
+//!  |  |  |  |  |  |  |  |  |
+//! cst X1 X2 X1^2 NA X1*X2 NA NA X2^2
+//!
+//! ```
+//!
+//! ```text
+//! For P_{2}:
+//!
+//! [3, 0, 1, 0, 0, 0, 0, 0, 0]
+//!  |  |  |  |  |  |  |  |  |
+//!  1  2  3  4  5  6  7  8  9
+//!  |  |  |  |  |  |  |  |  |
+//! cst X1 X2 X1^2 NA X1*X2 NA NA X2^2
+//!
+//! ```
+//!
+//! ```text
+//! For P_{3}:
+//!
+//! [0, 6, 3, 0, 0, 2, 0, 0, 1]
+//!  |  |  |  |  |  |  |  |  |
+//!  1  2  3  4  5  6  7  8  9
+//!  |  |  |  |  |  |  |  |  |
+//! cst X1 X2 X1^2 NA X1*X2 NA NA X2^2
+//!
+//! ```
+//!
+//! To compute `P_{3}`, we iterate over a list of 9 elements, initially all
+//! zero, which will define `P_{3}`.
+//!
+//! For index `1`, we multiply `P_{1}[1]` and `P_{2}[1]`.
+//!
+//! For index `2`, the only decompositions are `2 = 2 * 1 = 1 * 2`. Therefore,
+//! we compute `P_{1}[2] * P_{2}[1] + P_{2}[2] * P_{1}[1] = 2 * 3 + 0 * 0 = 6`.
+//!
+//! For index `3`, same as for `2`.
+//!
+//! For index `4`, we have `4 = 2 * 2`, therefore we multiply `P_{1}[2]` and
+//! `P_{2}[2]`.
+//!
+//! For index `6`, we have `6 = 2 * 3` and `6 = 3 * 2`, which are the
+//! two-factor decompositions of `6`. Therefore we sum `P_{1}[2] * P_{2}[3]`
+//! and `P_{2}[2] * P_{1}[3]`.
+//!
+//! For index `9`, we have `9 = 3 * 3`, therefore we do the same as for `4`.
+//!
+//! This can be generalized.
+//!
+//! The algorithm is as follows:
+//! - for each cell `j`:
+//!   - if `j` is prime, compute `P_{1}[j] * P_{2}[1] + P_{2}[j] * P_{1}[1]`
+//!   - else:
+//!     - take the two-factor decompositions of `j` (and their permutations).
+//!     - for each decomposition, compute the product of the corresponding
+//!       coefficients
+//!     - sum the results
+//!
+//!
+//! #### Another example: degree 2 with 3 variables
+//!
+//! ```math
+//! \begin{align}
+//! \mathbb{F}^{\le 2}[X_{1}, X_{2}, X_{3}] = \{
+//! & \, a_{0} + \\
+//! & \, a_{1} X_{1} + \\
+//! & \, a_{2} X_{2} + \\
+//! & \, a_{3} X_{3} + \\
+//! & \, a_{4} X_{1} X_{2} + \\
+//! & \, a_{5} X_{2} X_{3} + \\
+//! & \, a_{6} X_{1} X_{3} + \\
+//! & \, a_{7} X_{1}^2 + \\
+//! & \, a_{8} X_{2}^2 + \\
+//! & \, a_{9} X_{3}^2 \, | \, a_{i} \in \mathbb{F}
+//! \}
+//! \end{align}
+//! ```
+//!
+//! We assign:
+//!
+//! - `X_{1} = 2`
+//! - `X_{2} = 3`
+//! - `X_{3} = 5`
+//!
+//! And therefore, we have:
+//! - `X_{1}^2 = 4`
+//! - `X_{1} X_{2} = 6`
+//! - `X_{1} X_{3} = 10`
+//! - `X_{2}^2 = 9`
+//! - `X_{2} X_{3} = 15`
+//! - `X_{3}^2 = 25`
+//!
+//! We have an array with 25 indices, even though we only need 10 elements.
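+//!
+//! The encoding can be checked directly with the helpers from
+//! [crate::utils]; a minimal sketch, using the variable-to-prime assignment
+//! described above:
+//!
+//! ```rust
+//! use mvpoly::utils::PrimeNumberGenerator;
+//!
+//! let mut prime_gen = PrimeNumberGenerator::new();
+//! // X_{1} -> 2, X_{2} -> 3, X_{3} -> 5
+//! let x1 = prime_gen.get_nth_prime(1);
+//! let x2 = prime_gen.get_nth_prime(2);
+//! let x3 = prime_gen.get_nth_prime(3);
+//! // Multiplying monomials amounts to multiplying their indices:
+//! assert_eq!(x1 * x2, 6); // X_{1} X_{2}
+//! assert_eq!(x3 * x3, 25); // X_{3}^2
+//! ```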
+ +use std::{ + collections::HashMap, + fmt::{Debug, Formatter, Result}, + ops::{Add, Mul, Neg, Sub}, +}; + +use ark_ff::{One, PrimeField, Zero}; +use kimchi::circuits::{expr::Variable, gate::CurrOrNext}; +use num_integer::binomial; +use o1_utils::FieldHelpers; +use rand::{Rng, RngCore}; +use std::ops::{Index, IndexMut}; + +use crate::{ + utils::{compute_all_two_factors_decomposition, naive_prime_factors, PrimeNumberGenerator}, + MVPoly, +}; + +/// Represents a multivariate polynomial of degree less than `D` in `N` variables. +/// The representation is dense, i.e., all coefficients are stored. +/// The polynomial is represented as a vector of coefficients, where the index +/// of the coefficient corresponds to the index of the monomial. +/// A mapping between the index and the prime decomposition of the monomial is +/// stored in `normalized_indices`. +#[derive(Clone)] +pub struct Dense { + coeff: Vec, + // keeping track of the indices of the monomials that are normalized + // to avoid recomputing them + // FIXME: this should be stored somewhere else; we should not have it for + // each polynomial + normalized_indices: Vec, +} + +impl Index for Dense { + type Output = F; + + fn index(&self, index: usize) -> &Self::Output { + &self.coeff[index] + } +} + +impl IndexMut for Dense { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.coeff[index] + } +} + +impl Zero for Dense { + fn is_zero(&self) -> bool { + self.coeff.iter().all(|c| c.is_zero()) + } + + fn zero() -> Self { + Dense { + coeff: vec![F::zero(); Self::dimension()], + normalized_indices: Self::compute_normalized_indices(), + } + } +} + +impl One for Dense { + fn one() -> Self { + let mut result = Dense::zero(); + result.coeff[0] = F::one(); + result + } +} + +impl MVPoly for Dense { + /// Generate a random polynomial of maximum degree `max_degree`. + /// + /// If `None` is provided as the maximum degree, the polynomial will be + /// generated with a maximum degree of `D`. + /// + /// # Safety + /// + /// Marked as unsafe to warn the user to use it with caution and to not + /// necessarily rely on it for security/randomness in cryptographic + /// protocols. The user is responsible for providing its own secure + /// polynomial random generator, if needed. + /// + /// In addition to that, zeroes coefficients are added one every 10 + /// monomials to be sure we do have some sparcity in the polynomial. + /// + /// For now, the function is only used for testing. + unsafe fn random(rng: &mut RNG, max_degree: Option) -> Self { + let mut prime_gen = PrimeNumberGenerator::new(); + let normalized_indices = Self::compute_normalized_indices(); + // Different cases to avoid complexity in the case no maximum degree is + // provided + let coeff = if let Some(max_degree) = max_degree { + normalized_indices + .iter() + .map(|idx| { + let degree = naive_prime_factors(*idx, &mut prime_gen) + .iter() + .fold(0, |acc, (_, d)| acc + d); + if degree > max_degree || rng.gen_range(0..10) == 0 { + // Adding zero coefficients one every 10 monomials + F::zero() + } else { + F::rand(rng) + } + }) + .collect::>() + } else { + normalized_indices + .iter() + .map(|_| { + if rng.gen_range(0..10) == 0 { + // Adding zero coefficients one every 10 monomials + F::zero() + } else { + F::rand(rng) + } + }) + .collect() + }; + Self { + coeff, + normalized_indices, + } + } + + fn is_constant(&self) -> bool { + self.coeff.iter().skip(1).all(|c| c.is_zero()) + } + + /// Returns the degree of the polynomial. 
+ /// + /// The degree of the polynomial is the maximum degree of the monomials + /// that have a non-zero coefficient. + /// + /// # Safety + /// + /// The zero polynomial as a degree equals to 0, as the degree of the + /// constant polynomials. We do use the `unsafe` keyword to warn the user + /// for this specific case. + unsafe fn degree(&self) -> usize { + if self.is_constant() { + return 0; + } + let mut prime_gen = PrimeNumberGenerator::new(); + self.coeff.iter().enumerate().fold(1, |acc, (i, c)| { + if *c != F::zero() { + let decomposition_of_i = + naive_prime_factors(self.normalized_indices[i], &mut prime_gen); + let monomial_degree = decomposition_of_i.iter().fold(0, |acc, (_, d)| acc + d); + acc.max(monomial_degree) + } else { + acc + } + }) + } + + fn double(&self) -> Self { + let coeffs = self.coeff.iter().map(|c| c.double()).collect(); + Self::from_coeffs(coeffs) + } + + fn mul_by_scalar(&self, c: F) -> Self { + let coeffs = self.coeff.iter().map(|coef| *coef * c).collect(); + Self::from_coeffs(coeffs) + } + + /// Evaluate the polynomial at the vector point `x`. + /// + /// This is a dummy implementation. A cache can be used for the monomials to + /// speed up the computation. + fn eval(&self, x: &[F; N]) -> F { + let mut prime_gen = PrimeNumberGenerator::new(); + let primes = prime_gen.get_first_nth_primes(N); + self.coeff + .iter() + .enumerate() + .fold(F::zero(), |acc, (i, c)| { + if i == 0 { + acc + c + } else { + let normalized_index = self.normalized_indices[i]; + // IMPROVEME: we should keep the prime decomposition somewhere. + // It can be precomputed for a few multi-variate polynomials + // vector space + let prime_decomposition = naive_prime_factors(normalized_index, &mut prime_gen); + let mut monomial = F::one(); + prime_decomposition.iter().for_each(|(p, d)| { + // IMPROVEME: we should keep the inverse indices + let inv_p = primes.iter().position(|&x| x == *p).unwrap(); + let x_p = x[inv_p].pow([*d as u64]); + monomial *= x_p; + }); + acc + *c * monomial + } + }) + } + + fn from_variable>( + var: Variable, + offset_next_row: Option, + ) -> Self { + let Variable { col, row } = var; + if row == CurrOrNext::Next { + assert!( + offset_next_row.is_some(), + "The offset for the next row must be provided" + ); + } + let offset = if row == CurrOrNext::Curr { + 0 + } else { + offset_next_row.unwrap() + }; + let var_usize: usize = col.into(); + + let mut prime_gen = PrimeNumberGenerator::new(); + let primes = prime_gen.get_first_nth_primes(N); + assert!(primes.contains(&var_usize), "The usize representation of the variable must be a prime number, and unique for each variable"); + + let prime_idx = primes.iter().position(|&x| x == var_usize).unwrap(); + let idx = prime_gen.get_nth_prime(prime_idx + offset + 1); + + let mut res = Self::zero(); + let inv_idx = res + .normalized_indices + .iter() + .position(|&x| x == idx) + .unwrap(); + res[inv_idx] = F::one(); + res + } + + fn is_homogeneous(&self) -> bool { + let normalized_indices = self.normalized_indices.clone(); + let mut prime_gen = PrimeNumberGenerator::new(); + let is_homogeneous = normalized_indices + .iter() + .zip(self.coeff.clone()) + .all(|(idx, c)| { + let decomposition_of_i = naive_prime_factors(*idx, &mut prime_gen); + let monomial_degree = decomposition_of_i.iter().fold(0, |acc, (_, d)| acc + d); + monomial_degree == D || c == F::zero() + }); + is_homogeneous + } + + fn homogeneous_eval(&self, x: &[F; N], u: F) -> F { + let normalized_indices = self.normalized_indices.clone(); + let mut prime_gen = 
PrimeNumberGenerator::new(); + let primes = prime_gen.get_first_nth_primes(N); + normalized_indices + .iter() + .zip(self.coeff.clone()) + .fold(F::zero(), |acc, (idx, c)| { + let decomposition_of_i = naive_prime_factors(*idx, &mut prime_gen); + let monomial_degree = decomposition_of_i.iter().fold(0, |acc, (_, d)| acc + d); + let u_power = D - monomial_degree; + let monomial = decomposition_of_i.iter().fold(F::one(), |acc, (p, d)| { + let inv_p = primes.iter().position(|&x| x == *p).unwrap(); + let x_p = x[inv_p].pow([*d as u64]); + acc * x_p + }); + acc + c * monomial * u.pow([u_power as u64]) + }) + } + + fn add_monomial(&mut self, exponents: [usize; N], coeff: F) { + let mut prime_gen = PrimeNumberGenerator::new(); + let primes = prime_gen.get_first_nth_primes(N); + let normalized_index = exponents + .iter() + .zip(primes.iter()) + .fold(1, |acc, (d, p)| acc * p.pow(*d as u32)); + let inv_idx = self + .normalized_indices + .iter() + .position(|&x| x == normalized_index) + .unwrap(); + self.coeff[inv_idx] += coeff; + } + + fn compute_cross_terms( + &self, + _eval1: &[F; N], + _eval2: &[F; N], + _u1: F, + _u2: F, + ) -> HashMap { + unimplemented!() + } + + fn modify_monomial(&mut self, exponents: [usize; N], coeff: F) { + let mut prime_gen = PrimeNumberGenerator::new(); + let primes = prime_gen.get_first_nth_primes(N); + let index = exponents + .iter() + .zip(primes.iter()) + .fold(1, |acc, (exp, &prime)| acc * prime.pow(*exp as u32)); + if let Some(pos) = self.normalized_indices.iter().position(|&x| x == index) { + self.coeff[pos] = coeff; + } else { + panic!("Exponent combination out of bounds for the given polynomial degree and number of variables."); + } + } + + fn is_multilinear(&self) -> bool { + if self.is_zero() { + return true; + } + let normalized_indices = self.normalized_indices.clone(); + let mut prime_gen = PrimeNumberGenerator::new(); + normalized_indices + .iter() + .zip(self.coeff.iter()) + .all(|(idx, c)| { + if c.is_zero() { + true + } else { + let decomposition_of_i = naive_prime_factors(*idx, &mut prime_gen); + // Each prime number/variable should appear at most once + decomposition_of_i.iter().all(|(_p, d)| *d <= 1) + } + }) + } +} + +impl Dense { + pub fn new() -> Self { + let normalized_indices = Self::compute_normalized_indices(); + Self { + coeff: vec![F::zero(); Self::dimension()], + normalized_indices, + } + } + pub fn iter(&self) -> impl Iterator { + self.coeff.iter() + } + + pub fn dimension() -> usize { + binomial(N + D, D) + } + + pub fn from_coeffs(coeff: Vec) -> Self { + let normalized_indices = Self::compute_normalized_indices(); + Dense { + coeff, + normalized_indices, + } + } + + pub fn number_of_variables(&self) -> usize { + N + } + + pub fn maximum_degree(&self) -> usize { + D + } + + /// Output example for N = 2 and D = 2: + /// ```text + /// - 0 -> 1 + /// - 1 -> 2 + /// - 2 -> 3 + /// - 3 -> 4 + /// - 4 -> 6 + /// - 5 -> 9 + /// ``` + pub fn compute_normalized_indices() -> Vec { + let mut normalized_indices = vec![1; Self::dimension()]; + let mut prime_gen = PrimeNumberGenerator::new(); + let primes = prime_gen.get_first_nth_primes(N); + let max_index = primes[N - 1].checked_pow(D as u32); + let max_index = max_index.expect("Overflow in computing the maximum index"); + let mut j = 0; + (1..=max_index).for_each(|i| { + let prime_decomposition_of_index = naive_prime_factors(i, &mut prime_gen); + let is_valid_decomposition = prime_decomposition_of_index + .iter() + .all(|(p, _)| primes.contains(p)); + let monomial_degree = 
prime_decomposition_of_index + .iter() + .fold(0, |acc, (_, d)| acc + d); + let is_valid_decomposition = is_valid_decomposition && monomial_degree <= D; + if is_valid_decomposition { + normalized_indices[j] = i; + j += 1; + } + }); + normalized_indices + } + + pub fn increase_degree(&self) -> Dense { + assert!(D <= D_PRIME, "The degree of the target polynomial must be greater or equal to the degree of the source polynomial"); + let mut result: Dense = Dense::zero(); + let dst_normalized_indices = Dense::::compute_normalized_indices(); + let src_normalized_indices = Dense::::compute_normalized_indices(); + src_normalized_indices + .iter() + .enumerate() + .for_each(|(i, idx)| { + // IMPROVEME: should be computed once + let inv_idx = dst_normalized_indices + .iter() + .position(|&x| x == *idx) + .unwrap(); + result[inv_idx] = self[i]; + }); + result + } +} + +impl Default for Dense { + fn default() -> Self { + Dense::new() + } +} + +// Addition +impl Add for Dense { + type Output = Self; + + fn add(self, other: Self) -> Self { + &self + &other + } +} + +impl Add<&Dense> for Dense { + type Output = Dense; + + fn add(self, other: &Dense) -> Dense { + &self + other + } +} + +impl Add> for &Dense { + type Output = Dense; + + fn add(self, other: Dense) -> Dense { + self + &other + } +} + +impl Add<&Dense> for &Dense { + type Output = Dense; + + fn add(self, other: &Dense) -> Dense { + let coeffs = self + .coeff + .iter() + .zip(other.coeff.iter()) + .map(|(a, b)| *a + *b) + .collect(); + Dense::from_coeffs(coeffs) + } +} + +// Subtraction +impl Sub for Dense { + type Output = Self; + + fn sub(self, other: Self) -> Self { + self + (-other) + } +} + +impl Sub<&Dense> for Dense { + type Output = Dense; + + fn sub(self, other: &Dense) -> Dense { + self + (-other) + } +} + +impl Sub> for &Dense { + type Output = Dense; + + fn sub(self, other: Dense) -> Dense { + self + (-other) + } +} + +impl Sub<&Dense> for &Dense { + type Output = Dense; + + fn sub(self, other: &Dense) -> Dense { + self + (-other) + } +} + +// Negation +impl Neg for Dense { + type Output = Self; + + fn neg(self) -> Self::Output { + -&self + } +} + +impl Neg for &Dense { + type Output = Dense; + + fn neg(self) -> Self::Output { + let coeffs = self.coeff.iter().map(|c| -*c).collect(); + Dense::from_coeffs(coeffs) + } +} + +// Multiplication +impl Mul> for Dense { + type Output = Self; + + fn mul(self, other: Self) -> Self { + let mut cache = HashMap::new(); + let mut prime_gen = PrimeNumberGenerator::new(); + let mut result = vec![]; + (0..self.coeff.len()).for_each(|i| { + let mut sum = F::zero(); + let normalized_index = self.normalized_indices[i]; + let two_factors_decomposition = + compute_all_two_factors_decomposition(normalized_index, &mut cache, &mut prime_gen); + two_factors_decomposition.iter().for_each(|(a, b)| { + // FIXME: we should keep the inverse normalized indices + let inv_a = self + .normalized_indices + .iter() + .position(|&x| x == *a) + .unwrap(); + let inv_b = self + .normalized_indices + .iter() + .position(|&x| x == *b) + .unwrap(); + let a_coeff = self.coeff[inv_a]; + let b_coeff = other.coeff[inv_b]; + let product = a_coeff * b_coeff; + sum += product; + }); + result.push(sum); + }); + Self::from_coeffs(result) + } +} + +impl PartialEq for Dense { + fn eq(&self, other: &Self) -> bool { + self.coeff == other.coeff + } +} + +impl Eq for Dense {} + +impl Debug for Dense { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let mut prime_gen = PrimeNumberGenerator::new(); + let primes = 
+impl<F: PrimeField, const N: usize, const D: usize> PartialEq for Dense<F, N, D> {
+    fn eq(&self, other: &Self) -> bool {
+        self.coeff == other.coeff
+    }
+}
+
+impl<F: PrimeField, const N: usize, const D: usize> Eq for Dense<F, N, D> {}
+
+impl<F: PrimeField, const N: usize, const D: usize> Debug for Dense<F, N, D> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
+        let mut prime_gen = PrimeNumberGenerator::new();
+        let primes = prime_gen.get_first_nth_primes(N);
+        let coeff: Vec<_> = self
+            .coeff
+            .iter()
+            .enumerate()
+            .filter(|(_i, c)| *c != &F::zero())
+            .collect();
+        // Print 0 if the polynomial is zero
+        if coeff.is_empty() {
+            write!(f, "0").unwrap();
+            return Ok(());
+        }
+        let l = coeff.len();
+        coeff.into_iter().for_each(|(i, c)| {
+            let normalized_idx = self.normalized_indices[i];
+            if normalized_idx == 1 && *c != F::one() {
+                write!(f, "{}", c.to_biguint()).unwrap();
+            } else {
+                let prime_decomposition = naive_prime_factors(normalized_idx, &mut prime_gen);
+                if *c != F::one() {
+                    write!(f, "{}", c.to_biguint()).unwrap();
+                }
+                prime_decomposition.iter().for_each(|(p, d)| {
+                    let inv_p = primes.iter().position(|&x| x == *p).unwrap();
+                    if *d > 1 {
+                        write!(f, "x_{}^{}", inv_p, d).unwrap();
+                    } else {
+                        write!(f, "x_{}", inv_p).unwrap();
+                    }
+                });
+            }
+            // Avoid printing the last `+` or if the polynomial is a single
+            // monomial
+            if i != l - 1 && l != 1 {
+                write!(f, " + ").unwrap();
+            }
+        });
+        Ok(())
+    }
+}
+
+impl<F: PrimeField, const N: usize, const D: usize> From<F> for Dense<F, N, D> {
+    fn from(value: F) -> Self {
+        let mut result = Self::zero();
+        result.coeff[0] = value;
+        result
+    }
+}
+
+// TODO: implement From/To Expr
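As a usage sketch (mirroring the evaluation test in `mvpoly/tests/prime.rs` further down in this patch), a `Dense` polynomial is built from its coefficients listed in normalized-index order `[1, 2, 3, 4, 6, 9]` and evaluated through the `MVPoly` trait; evaluating at `(1, 1)` simply sums the coefficients:

```rust
use ark_ff::One;
use mina_curves::pasta::Fp;
use mvpoly::{prime::Dense, MVPoly};

fn main() {
    // P(X1, X2) = 2 + 3 X1 + 4 X2 + 5 X1^2 + 6 X1 X2 + 7 X2^2
    let p = Dense::<Fp, 2, 2>::from_coeffs(vec![
        Fp::from(2_u32),
        Fp::from(3_u32),
        Fp::from(4_u32),
        Fp::from(5_u32),
        Fp::from(6_u32),
        Fp::from(7_u32),
    ]);
    // P(1, 1) = 2 + 3 + 4 + 5 + 6 + 7 = 27
    assert_eq!(p.eval(&[Fp::one(), Fp::one()]), Fp::from(27_u32));
}
```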
diff --git a/mvpoly/src/utils.rs b/mvpoly/src/utils.rs
new file mode 100644
index 0000000000..b4e4aa0839
--- /dev/null
+++ b/mvpoly/src/utils.rs
@@ -0,0 +1,274 @@
+//! This module contains functions to work with prime numbers and to compute
+//! the dimension of multivariate spaces
+
+use std::collections::HashMap;
+
+/// Naive implementation checking if n is prime.
+/// You can also use the structure PrimeNumberGenerator to check if a number is
+/// prime:
+/// ```rust
+/// use mvpoly::utils::PrimeNumberGenerator;
+/// let n = 5;
+/// let mut prime_gen = PrimeNumberGenerator::new();
+/// prime_gen.is_prime(n);
+/// ```
+pub fn is_prime(n: usize) -> bool {
+    if n == 2 {
+        return true;
+    }
+    if n < 2 || n % 2 == 0 {
+        return false;
+    }
+    let mut i = 3;
+    while i * i <= n {
+        if n % i == 0 {
+            return false;
+        }
+        i += 2;
+    }
+    true
+}
+
+/// Given a number n, return the list of prime factors of n, with their
+/// multiplicity.
+/// The first argument is the number to factorize; the second argument is a
+/// generator of prime numbers used to factorize the number.
+/// The output is a list of tuples, where the first element is the prime number
+/// and the second element is the multiplicity of the prime number in the
+/// factorization of n.
+// IMPROVEME: naive algorithm, could be optimized. Use a cache to store
+// the prime factors of the previous numbers
+pub fn naive_prime_factors(n: usize, prime_gen: &mut PrimeNumberGenerator) -> Vec<(usize, usize)> {
+    assert!(n > 0);
+    let mut hash_factors = HashMap::new();
+    let mut n = n;
+    if prime_gen.is_prime(n) {
+        vec![(n, 1)]
+    } else {
+        let mut i = 1;
+        let mut p = prime_gen.get_nth_prime(i);
+        while n != 1 {
+            if n % p == 0 {
+                hash_factors.entry(p).and_modify(|e| *e += 1).or_insert(1);
+                n /= p;
+            } else {
+                i += 1;
+                p = prime_gen.get_nth_prime(i);
+            }
+        }
+        let mut factors = vec![];
+        hash_factors.into_iter().for_each(|(k, v)| {
+            factors.push((k, v));
+        });
+        // sort by the prime number
+        factors.sort();
+        factors
+    }
+}
+
+pub struct PrimeNumberGenerator {
+    primes: Vec<usize>,
+}
+
+impl PrimeNumberGenerator {
+    pub fn new() -> Self {
+        PrimeNumberGenerator { primes: vec![] }
+    }
+
+    /// Generate the nth prime number
+    pub fn get_nth_prime(&mut self, n: usize) -> usize {
+        assert!(n > 0);
+        if n <= self.primes.len() {
+            self.primes[n - 1]
+        } else {
+            while self.primes.len() < n {
+                let mut i = {
+                    if self.primes.is_empty() {
+                        2
+                    } else if self.primes.len() == 1 {
+                        3
+                    } else {
+                        self.primes[self.primes.len() - 1] + 2
+                    }
+                };
+                while !is_prime(i) {
+                    i += 2;
+                }
+                self.primes.push(i);
+            }
+            self.primes[n - 1]
+        }
+    }
+
+    /// Check if a number is prime using the list of prime numbers.
+    /// It differs from the standalone `is_prime` function in that it tests
+    /// divisibility only by the primes in the list, instead of by all odd
+    /// numbers up to the square root of n.
+    /// This method can be more efficient if the list of prime numbers is
+    /// already computed.
+    pub fn is_prime(&mut self, n: usize) -> bool {
+        if n == 0 || n == 1 {
+            false
+        } else {
+            let mut i = 1;
+            let mut p = self.get_nth_prime(i);
+            while p * p <= n {
+                if n % p == 0 {
+                    return false;
+                }
+                i += 1;
+                p = self.get_nth_prime(i);
+            }
+            true
+        }
+    }
+
+    /// Get the next prime number
+    pub fn get_next_prime(&mut self) -> usize {
+        let n = self.primes.len();
+        self.get_nth_prime(n + 1)
+    }
+
+    pub fn get_first_nth_primes(&mut self, n: usize) -> Vec<usize> {
+        let _ = self.get_nth_prime(n);
+        self.primes.clone()
+    }
+}
+
+impl Iterator for PrimeNumberGenerator {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let n = self.primes.len();
+        Some(self.get_nth_prime(n + 1))
+    }
+}
+
+impl Default for PrimeNumberGenerator {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Build the mapping from 1..N to the first N prime numbers. It will be used
+/// to encode variables as prime numbers.
+/// For instance, if N = 3, i.e. we have the variables $x_1, x_2, x_3$, the
+/// mapping will be [1, 2, 3, 2, 3, 5].
+/// The idea is to encode products of variables as products of prime numbers
+/// and then use the factorization of the products to know which variables must
+/// be fetched while computing a product of variables
+pub fn get_mapping_with_primes<const N: usize>() -> Vec<usize> {
+    let mut primes = PrimeNumberGenerator::new();
+    let mut mapping = vec![0; 2 * N];
+    for (i, v) in mapping.iter_mut().enumerate().take(N) {
+        *v = i + 1;
+    }
+    for (i, v) in mapping.iter_mut().enumerate().skip(N) {
+        *v = primes.get_nth_prime(i - N + 1);
+    }
+    mapping
+}
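For example (consistent with the doc comment above and with `test_mapping_variables_indexes_to_primes` in the test suite), for `N = 3` the mapping juxtaposes the variable indices and their prime encodings:

```rust
use mvpoly::utils::get_mapping_with_primes;

fn main() {
    let map = get_mapping_with_primes::<3>();
    // First N entries: the variable indices 1..=N;
    // last N entries: the primes they are encoded as.
    assert_eq!(map, vec![1, 2, 3, 2, 3, 5]);
}
```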
+
+/// Compute all the possible two-factor decompositions of a number n.
+/// It uses a cache where previously computed values are stored.
+/// For instance, if n = 6, the function will return [(1, 6), (2, 3), (3, 2), (6, 1)].
+/// The cache is a hashmap where the key is the number and the value is the
+/// list of all its possible two-factor decompositions.
+/// The hashmap is updated in place.
+/// The third parameter is a generator of prime numbers. It is updated in
+/// place in case new prime numbers are generated.
+pub fn compute_all_two_factors_decomposition(
+    n: usize,
+    cache: &mut HashMap<usize, Vec<(usize, usize)>>,
+    prime_numbers: &mut PrimeNumberGenerator,
+) -> Vec<(usize, usize)> {
+    if cache.contains_key(&n) {
+        cache[&n].clone()
+    } else {
+        let mut factors = vec![];
+        if n == 1 {
+            factors.push((1, 1));
+        } else if prime_numbers.is_prime(n) {
+            factors.push((1, n));
+            factors.push((n, 1));
+        } else {
+            let mut i = 1;
+            let mut p = prime_numbers.get_nth_prime(i);
+            while p * p <= n {
+                if n % p == 0 {
+                    let res = n / p;
+                    let res_factors =
+                        compute_all_two_factors_decomposition(res, cache, prime_numbers);
+                    for (a, b) in res_factors {
+                        let x = (p * a, b);
+                        if !factors.contains(&x) {
+                            factors.push(x);
+                        }
+                        let x = (a, p * b);
+                        if !factors.contains(&x) {
+                            factors.push(x);
+                        }
+                    }
+                }
+                i += 1;
+                p = prime_numbers.get_nth_prime(i);
+            }
+        }
+        cache.insert(n, factors.clone());
+        factors
+    }
+}
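A short usage sketch (the `n = 6` case from the doc comment; the result is sorted here because, as in the tests below, the output order is not specified):

```rust
use std::collections::HashMap;

use mvpoly::utils::{compute_all_two_factors_decomposition, PrimeNumberGenerator};

fn main() {
    let mut cache = HashMap::new();
    let mut prime_gen = PrimeNumberGenerator::new();
    let mut factors = compute_all_two_factors_decomposition(6, &mut cache, &mut prime_gen);
    factors.sort();
    assert_eq!(factors, vec![(1, 6), (2, 3), (3, 2), (6, 1)]);
    // The recursive call on the quotient 3 also populated the cache.
    assert!(cache.contains_key(&3));
}
```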
+
+/// Compute the list of indices to perform N nested loops of different size each.
+/// In other words, if we have to perform the 3 nested loops:
+/// ```rust
+/// let n1 = 3;
+/// let n2 = 3;
+/// let n3 = 5;
+/// for i in 0..n1 {
+///     for j in 0..n2 {
+///         for k in 0..n3 {
+///         }
+///     }
+/// }
+/// ```
+/// the output will be all the possible values of `i`, `j`, and `k`.
+/// The algorithm is as follows:
+/// ```rust
+/// let n1 = 3;
+/// let n2 = 3;
+/// let n3 = 5;
+/// (0..(n1 * n2 * n3)).map(|l| {
+///     let i = l % n1;
+///     let j = (l / n1) % n2;
+///     let k = (l / (n1 * n2)) % n3;
+///     (i, j, k)
+/// });
+/// ```
+/// For N nested loops, the algorithm is the same, with the division increasing
+/// by the factor `N_k` for the index `i_(k + 1)`
+///
+/// In the case of an empty list, the function will return a list containing a
+/// single element which is the empty list.
+///
+/// In the case of an empty loop (i.e. one value in the input list is 0), the
+/// expected output is the empty list.
+pub fn compute_indices_nested_loop(nested_loop_sizes: Vec<usize>) -> Vec<Vec<usize>> {
+    let n = nested_loop_sizes.iter().product();
+    (0..n)
+        .map(|i| {
+            let mut div = 1;
+            // Compute indices for the loop, step i
+            let indices: Vec<usize> = nested_loop_sizes
+                .iter()
+                .map(|n_i| {
+                    let k = (i / div) % n_i;
+                    div *= n_i;
+                    k
+                })
+                .collect();
+            indices
+        })
+        .collect()
+}
diff --git a/mvpoly/tests/monomials.rs b/mvpoly/tests/monomials.rs
new file mode 100644
index 0000000000..902a3fcfbc
--- /dev/null
+++ b/mvpoly/tests/monomials.rs
@@ -0,0 +1,714 @@
+use ark_ff::{Field, One, UniformRand, Zero};
+use kimchi::circuits::{
+    berkeley_columns::BerkeleyChallengeTerm,
+    expr::{ConstantExpr, Expr, ExprInner, Variable},
+    gate::CurrOrNext,
+};
+use mina_curves::pasta::Fp;
+use mvpoly::{monomials::Sparse, MVPoly};
+use rand::Rng;
+
+#[test]
+fn test_mul_by_one() {
+    mvpoly::pbt::test_mul_by_one::>();
+}
+
+#[test]
+fn test_mul_by_zero() {
+    mvpoly::pbt::test_mul_by_zero::>();
+}
+
+#[test]
+fn test_add_zero() {
+    mvpoly::pbt::test_add_zero::>();
+}
+
+#[test]
+fn test_double_is_add_twice() {
+    mvpoly::pbt::test_double_is_add_twice::>();
+}
+
+#[test]
+fn test_sub_zero() {
+    mvpoly::pbt::test_sub_zero::>();
+}
+
+#[test]
+fn test_neg() {
+    mvpoly::pbt::test_neg::>();
+}
+
+#[test]
+fn test_eval_pbt_add() {
+    mvpoly::pbt::test_eval_pbt_add::>();
+}
+
+#[test]
+fn test_eval_pbt_sub() {
+    mvpoly::pbt::test_eval_pbt_sub::>();
+}
+
+#[test]
+fn test_eval_pbt_mul_by_scalar() {
+    mvpoly::pbt::test_eval_pbt_mul_by_scalar::>();
+}
+
+#[test]
+fn test_eval_pbt_neg() {
+    mvpoly::pbt::test_eval_pbt_neg::>();
+}
+
+#[test]
+fn test_neg_ref() {
+    mvpoly::pbt::test_neg_ref::>();
+}
+
+#[test]
+fn test_mul_by_scalar() {
+    mvpoly::pbt::test_mul_by_scalar::>();
+}
+
+#[test]
+fn test_mul_by_scalar_with_zero() {
+    mvpoly::pbt::test_mul_by_scalar_with_zero::>();
+}
+
+#[test]
+fn test_mul_by_scalar_with_one() {
+    mvpoly::pbt::test_mul_by_scalar_with_one::>();
+}
+
+#[test]
+fn test_evaluation_zero_polynomial() {
+    mvpoly::pbt::test_evaluation_zero_polynomial::>();
+}
+
+#[test]
+fn test_evaluation_constant_polynomial() {
+    mvpoly::pbt::test_evaluation_constant_polynomial::>();
+}
+
+#[test]
+fn test_degree_constant() {
+    mvpoly::pbt::test_degree_constant::>();
+}
+
+#[test]
+fn test_degree_random_degree() {
+    mvpoly::pbt::test_degree_random_degree::>();
+    mvpoly::pbt::test_degree_random_degree::>();
+}
+
+#[test]
+fn test_is_constant() {
+    mvpoly::pbt::test_is_constant::>();
+}
+
+#[test]
+fn test_mvpoly_add_degree_pbt() {
+    mvpoly::pbt::test_mvpoly_add_degree_pbt::>();
+}
+
+#[test]
+fn test_mvpoly_sub_degree_pbt() {
+    mvpoly::pbt::test_mvpoly_sub_degree_pbt::>();
+}
+
+#[test]
+fn test_mvpoly_neg_degree_pbt() {
+    mvpoly::pbt::test_mvpoly_neg_degree_pbt::>();
+}
+
+#[test]
+fn test_mvpoly_mul_by_scalar_degree_pbt() {
+    mvpoly::pbt::test_mvpoly_mul_by_scalar_degree_pbt::>();
+}
+
+#[test]
+fn test_mvpoly_mul_degree_pbt() {
+    mvpoly::pbt::test_mvpoly_mul_degree_pbt::>();
+}
+
+#[test]
+fn test_mvpoly_mul_eval_pbt() {
+    mvpoly::pbt::test_mvpoly_mul_eval_pbt::>();
+}
+
+#[test]
+fn test_mvpoly_mul_pbt() {
+    mvpoly::pbt::test_mvpoly_mul_pbt::>();
+}
+
+#[test]
+fn test_can_be_printed_with_debug() {
+    mvpoly::pbt::test_can_be_printed_with_debug::>();
+}
+
+#[test]
+fn test_is_zero() {
+    mvpoly::pbt::test_is_zero::>();
+}
+
+#[test]
+fn test_homogeneous_eval() {
+    mvpoly::pbt::test_homogeneous_eval::>();
+}
+
+#[test]
+fn test_add_monomial() {
+    mvpoly::pbt::test_add_monomial::>();
+}
+
+#[test]
+fn
test_mvpoly_compute_cross_terms_degree_two_unit_test() { + let mut rng = o1_utils::tests::make_test_rng(None); + + { + // Homogeneous form is Y^2 + let p1 = Sparse::::from(Fp::from(1)); + + let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + + // We only have one cross-term in this case as degree 2 + assert_eq!(cross_terms.len(), 1); + // Cross term of constant is r * (2 u1 u2) + assert_eq!(cross_terms[&1], (u1 * u2).double()); + } +} + +#[test] +fn test_mvpoly_compute_cross_terms_degree_two() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Sparse::::random(&mut rng, None) }; + let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + // We only have one cross-term in this case + assert_eq!(cross_terms.len(), 1); + + let r = Fp::rand(&mut rng); + let random_lincomb: [Fp; 4] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]); + + let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2); + + let rhs = { + let eval1_hom = p1.homogeneous_eval(&random_eval1, u1); + let eval2_hom = p1.homogeneous_eval(&random_eval2, u2); + let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| { + acc + r.pow([*power as u64]) * term + }); + eval1_hom + r * r * eval2_hom + cross_terms_eval + }; + assert_eq!(lhs, rhs); +} + +#[test] +fn test_mvpoly_compute_cross_terms_degree_three() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Sparse::::random(&mut rng, None) }; + let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + + assert_eq!(cross_terms.len(), 2); + + let r = Fp::rand(&mut rng); + let random_lincomb: [Fp; 4] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]); + + let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2); + + let rhs = { + let eval1_hom = p1.homogeneous_eval(&random_eval1, u1); + let eval2_hom = p1.homogeneous_eval(&random_eval2, u2); + let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| { + acc + r.pow([*power as u64]) * term + }); + let r_cube = r.pow([3]); + eval1_hom + r_cube * eval2_hom + cross_terms_eval + }; + assert_eq!(lhs, rhs); +} + +#[test] +fn test_mvpoly_compute_cross_terms_degree_four() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Sparse::::random(&mut rng, None) }; + let random_eval1: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + + assert_eq!(cross_terms.len(), 3); + + let r = Fp::rand(&mut rng); + let random_lincomb: [Fp; 6] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]); + + let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2); + + let rhs = { + let 
eval1_hom = p1.homogeneous_eval(&random_eval1, u1); + let eval2_hom = p1.homogeneous_eval(&random_eval2, u2); + let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| { + acc + r.pow([*power as u64]) * term + }); + let r_four = r.pow([4]); + eval1_hom + r_four * eval2_hom + cross_terms_eval + }; + assert_eq!(lhs, rhs); +} + +#[test] +fn test_mvpoly_compute_cross_terms_degree_five() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Sparse::::random(&mut rng, None) }; + let random_eval1: [Fp; 3] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 3] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + + assert_eq!(cross_terms.len(), 4); + + let r = Fp::rand(&mut rng); + let random_lincomb: [Fp; 3] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]); + + let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2); + + let rhs = { + let eval1_hom = p1.homogeneous_eval(&random_eval1, u1); + let eval2_hom = p1.homogeneous_eval(&random_eval2, u2); + let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| { + acc + r.pow([*power as u64]) * term + }); + let r_five = r.pow([5]); + eval1_hom + r_five * eval2_hom + cross_terms_eval + }; + assert_eq!(lhs, rhs); +} + +#[test] +fn test_mvpoly_compute_cross_terms_degree_six() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Sparse::::random(&mut rng, None) }; + let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + + assert_eq!(cross_terms.len(), 5); + + let r = Fp::rand(&mut rng); + let random_lincomb: [Fp; 4] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]); + + let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2); + + let rhs = { + let eval1_hom = p1.homogeneous_eval(&random_eval1, u1); + let eval2_hom = p1.homogeneous_eval(&random_eval2, u2); + let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| { + acc + r.pow([*power as u64]) * term + }); + let r_six = r.pow([6]); + eval1_hom + r_six * eval2_hom + cross_terms_eval + }; + assert_eq!(lhs, rhs); +} + +// The cross-terms of a sum of polynomials is the sum of the cross-terms, per +// power. 
+#[test] +fn test_mvpoly_pbt_cross_terms_addition() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Sparse::::random(&mut rng, None) }; + let p2 = unsafe { Sparse::::random(&mut rng, None) }; + let p = p1.clone() + p2.clone(); + + let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + + let cross_terms1 = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + let cross_terms2 = p2.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + let cross_terms = p.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + + let cross_terms_sum = + cross_terms1 + .iter() + .fold(cross_terms2.clone(), |mut acc, (power, term)| { + acc.entry(*power) + .and_modify(|v| *v += term) + .or_insert(*term); + acc + }); + assert_eq!(cross_terms, cross_terms_sum); +} + +#[test] +fn test_mvpoly_compute_cross_terms_degree_seven() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Sparse::::random(&mut rng, None) }; + let random_eval1: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let random_eval2: [Fp; 4] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let u1 = Fp::rand(&mut rng); + let u2 = Fp::rand(&mut rng); + let cross_terms = p1.compute_cross_terms(&random_eval1, &random_eval2, u1, u2); + + assert_eq!(cross_terms.len(), 6); + + let r = Fp::rand(&mut rng); + let random_lincomb: [Fp; 4] = std::array::from_fn(|i| random_eval1[i] + r * random_eval2[i]); + + let lhs = p1.homogeneous_eval(&random_lincomb, u1 + r * u2); + + let rhs = { + let eval1_hom = p1.homogeneous_eval(&random_eval1, u1); + let eval2_hom = p1.homogeneous_eval(&random_eval2, u2); + let cross_terms_eval = cross_terms.iter().fold(Fp::zero(), |acc, (power, term)| { + acc + r.pow([*power as u64]) * term + }); + let r_seven = r.pow([7]); + eval1_hom + r_seven * eval2_hom + cross_terms_eval + }; + assert_eq!(lhs, rhs); +} + +#[test] +fn test_is_multilinear() { + mvpoly::pbt::test_is_multilinear::>(); +} + +#[test] +fn test_increase_number_of_variables() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Sparse = unsafe { Sparse::::random(&mut rng, None) }; + + let p2: Result, String> = p1.into(); + p2.unwrap(); +} + +#[test] +fn test_pbt_increase_number_of_variables_with_addition() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1: Sparse = unsafe { Sparse::::random(&mut rng, None) }; + let p2: Sparse = unsafe { Sparse::::random(&mut rng, None) }; + + let lhs: Sparse = { + let p: Result, String> = (p1.clone() + p2.clone()).into(); + p.unwrap() + }; + + let rhs: Sparse = { + let p1: Result, String> = p1.clone().into(); + let p2: Result, String> = p2.clone().into(); + p1.unwrap() + p2.unwrap() + }; + + assert_eq!(lhs, rhs); +} + +#[test] +fn test_pbt_increase_number_of_variables_zero_one_cst() { + let mut rng = o1_utils::tests::make_test_rng(None); + { + let lhs_zero: Sparse = { + let p: Result, String> = Sparse::::zero().into(); + p.unwrap() + }; + let rhs_zero: Sparse = Sparse::::zero(); + assert_eq!(lhs_zero, rhs_zero); + } + + { + let lhs_one: Sparse = { + let p: Result, String> = Sparse::::one().into(); + p.unwrap() + }; + let rhs_one: Sparse = Sparse::::one(); + assert_eq!(lhs_one, rhs_one); + } + + { + let c = Fp::rand(&mut rng); + let lhs: Sparse = { + let p: Result, String> = Sparse::::from(c).into(); + p.unwrap() + }; + let rhs: Sparse = Sparse::::from(c); + assert_eq!(lhs, rhs); + 
} +} + +#[test] +fn test_build_from_variable() { + #[derive(Clone, Copy, PartialEq)] + enum Column { + X(usize), + } + + impl From for usize { + fn from(val: Column) -> usize { + match val { + Column::X(i) => i, + } + } + } + + let mut rng = o1_utils::tests::make_test_rng(None); + let idx: usize = rng.gen_range(0..4); + let p = Sparse::::from_variable::( + Variable { + col: Column::X(idx), + row: CurrOrNext::Curr, + }, + None, + ); + + let eval: [Fp; 4] = std::array::from_fn(|_i| Fp::rand(&mut rng)); + + assert_eq!(p.eval(&eval), eval[idx]); +} + +#[test] +#[should_panic] +fn test_build_from_variable_next_row_without_offset_given() { + #[derive(Clone, Copy, PartialEq)] + enum Column { + X(usize), + } + + impl From for usize { + fn from(val: Column) -> usize { + match val { + Column::X(i) => i, + } + } + } + + let mut rng = o1_utils::tests::make_test_rng(None); + let idx: usize = rng.gen_range(0..4); + let _p = Sparse::::from_variable::( + Variable { + col: Column::X(idx), + row: CurrOrNext::Next, + }, + None, + ); +} + +#[test] +fn test_build_from_variable_next_row_with_offset_given() { + #[derive(Clone, Copy, PartialEq)] + enum Column { + X(usize), + } + + impl From for usize { + fn from(val: Column) -> usize { + match val { + Column::X(i) => i, + } + } + } + + let mut rng = o1_utils::tests::make_test_rng(None); + let idx: usize = rng.gen_range(0..4); + + // Using next + { + let p = Sparse::::from_variable::( + Variable { + col: Column::X(idx), + row: CurrOrNext::Next, + }, + Some(4), + ); + + let eval: [Fp; 8] = std::array::from_fn(|_i| Fp::rand(&mut rng)); + assert_eq!(p.eval(&eval), eval[idx + 4]); + } + + // Still using current + { + let p = Sparse::::from_variable::( + Variable { + col: Column::X(idx), + row: CurrOrNext::Curr, + }, + Some(4), + ); + + let eval: [Fp; 8] = std::array::from_fn(|_i| Fp::rand(&mut rng)); + assert_eq!(p.eval(&eval), eval[idx]); + } +} + +/// As a reminder, here are the equations to compute the addition of two +/// different points `P1 = (X1, Y1)` and `P2 = (X2, Y2)`. Let `P3 = (X3, +/// Y3) = P1 + P2`. 
+/// +/// ```text +/// - λ = (Y1 - Y2) / (X1 - X2) +/// - X3 = λ^2 - X1 - X2 +/// - Y3 = λ (X1 - X3) - Y1 +/// ``` +/// +/// Therefore, the addition of elliptic curve points can be computed using the +/// following degree-2 constraints +/// +/// ```text +/// - Constraint 1: λ (X1 - X2) - Y1 + Y2 = 0 +/// - Constraint 2: X3 + X1 + X2 - λ^2 = 0 +/// - Constraint 3: Y3 - λ (X1 - X3) + Y1 = 0 +/// ``` +#[test] +fn test_from_expr_ec_addition() { + // Simulate a real usecase + // The following lines/design look similar to the ones we use in + // o1vm/arrabiata + #[derive(Clone, Copy, PartialEq)] + enum Column { + X(usize), + } + + impl From for usize { + fn from(val: Column) -> usize { + match val { + Column::X(i) => i, + } + } + } + + struct Constraint { + idx: usize, + } + + trait Interpreter { + type Position: Clone + Copy; + + type Variable: Clone + + std::ops::Add + + std::ops::Sub + + std::ops::Mul; + + fn allocate(&mut self) -> Self::Position; + + // Simulate fetching/reading a value from outside + // In the case of the witness, it will be getting a value from the + // environment + fn fetch(&self, pos: Self::Position) -> Self::Variable; + } + + impl Interpreter for Constraint { + type Position = Column; + + type Variable = Expr, Column>; + + fn allocate(&mut self) -> Self::Position { + let col = Column::X(self.idx); + self.idx += 1; + col + } + + fn fetch(&self, col: Self::Position) -> Self::Variable { + Expr::Atom(ExprInner::Cell(Variable { + col, + row: CurrOrNext::Curr, + })) + } + } + + impl Constraint { + fn new() -> Self { + Self { idx: 0 } + } + } + + let mut interpreter = Constraint::new(); + // Constraints for elliptic curve addition, without handling the case of the + // point at infinity or double + let lambda = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + let x1 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + let x2 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + + let y1 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + let y2 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + + let x3 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + let y3 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + + // Check we can convert into a Sparse polynomial. + // We have 7 variables, maximum degree 2. + // We test by evaluating at a random point. 
+ let mut rng = o1_utils::tests::make_test_rng(None); + { + // - Constraint 1: λ (X1 - X2) - Y1 + Y2 = 0 + let expression = lambda.clone() * (x1.clone() - x2.clone()) - (y1.clone() - y2.clone()); + + let p = Sparse::::from_expr::(expression, None); + let random_evaluation: [Fp; 7] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let eval = p.eval(&random_evaluation); + let exp_eval = { + random_evaluation[0] * (random_evaluation[1] - random_evaluation[2]) + - (random_evaluation[3] - random_evaluation[4]) + }; + assert_eq!(eval, exp_eval); + } + + { + // - Constraint 2: X3 + X1 + X2 - λ^2 = 0 + let expr = x3.clone() + x1.clone() + x2.clone() - lambda.clone() * lambda.clone(); + let p = Sparse::::from_expr::(expr, None); + let random_evaluation: [Fp; 7] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let eval = p.eval(&random_evaluation); + let exp_eval = { + random_evaluation[5] + random_evaluation[1] + random_evaluation[2] + - random_evaluation[0] * random_evaluation[0] + }; + assert_eq!(eval, exp_eval); + } + { + // - Constraint 3: Y3 - λ (X1 - X3) + Y1 = 0 + let expr = y3.clone() - lambda.clone() * (x1.clone() - x3.clone()) + y1.clone(); + let p = Sparse::::from_expr::(expr, None); + let random_evaluation: [Fp; 7] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let eval = p.eval(&random_evaluation); + let exp_eval = { + random_evaluation[6] + - random_evaluation[0] * (random_evaluation[1] - random_evaluation[5]) + + random_evaluation[3] + }; + assert_eq!(eval, exp_eval); + } +} diff --git a/mvpoly/tests/prime.rs b/mvpoly/tests/prime.rs new file mode 100644 index 0000000000..17cc6a3eb1 --- /dev/null +++ b/mvpoly/tests/prime.rs @@ -0,0 +1,709 @@ +use ark_ff::{One, UniformRand, Zero}; +use kimchi::circuits::{ + berkeley_columns::BerkeleyChallengeTerm, + expr::{ConstantExpr, Expr, ExprInner, Variable}, + gate::CurrOrNext, +}; +use mina_curves::pasta::Fp; +use mvpoly::{prime::Dense, utils::PrimeNumberGenerator, MVPoly}; +use rand::Rng; + +#[test] +fn test_vector_space_dimension() { + assert_eq!(Dense::::dimension(), 6); + assert_eq!(Dense::::dimension(), 10); + assert_eq!(Dense::::dimension(), 11); +} + +#[test] +fn test_add() { + let p1 = Dense::::new(); + let p2 = Dense::::new(); + let _p3 = p1 + p2; +} + +#[test] +pub fn test_normalized_indices() { + let indices = Dense::::compute_normalized_indices(); + assert_eq!(indices.len(), 6); + assert_eq!(indices[0], 1); + assert_eq!(indices[1], 2); + assert_eq!(indices[2], 3); + assert_eq!(indices[3], 4); + assert_eq!(indices[4], 6); + assert_eq!(indices[5], 9); + + let indices = Dense::::compute_normalized_indices(); + assert_eq!(indices.len(), 10); + assert_eq!(indices[0], 1); + assert_eq!(indices[1], 2); + assert_eq!(indices[2], 3); + assert_eq!(indices[3], 4); + assert_eq!(indices[4], 5); + assert_eq!(indices[5], 6); + assert_eq!(indices[6], 9); + assert_eq!(indices[7], 10); + assert_eq!(indices[8], 15); + assert_eq!(indices[9], 25); +} + +#[test] +fn test_is_homogeneous() { + // X1 X2 + X1^2 + X1^2 + let coeffs: Vec = vec![ + Fp::zero(), + Fp::zero(), + Fp::zero(), + Fp::one(), + Fp::one(), + Fp::one(), + ]; + let p = Dense::::from_coeffs(coeffs); + assert!(p.is_homogeneous()); + + // X1 X2 + X1^2 + let coeffs: Vec = vec![ + Fp::zero(), + Fp::zero(), + Fp::zero(), + Fp::one(), + Fp::one(), + Fp::zero(), + ]; + let p = Dense::::from_coeffs(coeffs); + assert!(p.is_homogeneous()); + + // X1 X2 + X2^2 + let coeffs: Vec = vec![ + Fp::zero(), + Fp::zero(), + Fp::zero(), + Fp::one(), + Fp::zero(), + Fp::one(), + ]; + let p = 
Dense::::from_coeffs(coeffs); + assert!(p.is_homogeneous()); + + // X1 X2 + let coeffs: Vec = vec![ + Fp::zero(), + Fp::zero(), + Fp::zero(), + Fp::one(), + Fp::zero(), + Fp::zero(), + ]; + let p = Dense::::from_coeffs(coeffs); + assert!(p.is_homogeneous()); +} + +#[test] +fn test_is_not_homogeneous() { + // 1 + X1 X2 + X1^2 + X2^2 + let coeffs: Vec = vec![ + Fp::from(42_u32), + Fp::zero(), + Fp::zero(), + Fp::one(), + Fp::one(), + Fp::one(), + ]; + let p = Dense::::from_coeffs(coeffs); + assert!(!p.is_homogeneous()); + + let coeffs: Vec = vec![ + Fp::zero(), + Fp::zero(), + Fp::one(), + Fp::one(), + Fp::one(), + Fp::zero(), + ]; + let p = Dense::::from_coeffs(coeffs); + assert!(!p.is_homogeneous()); +} + +#[test] +fn test_mul() { + let coeff_p1 = vec![ + Fp::zero(), + Fp::from(2_u32), + Fp::one(), + Fp::zero(), + Fp::zero(), + Fp::zero(), + ]; + let coeff_p2 = vec![ + Fp::from(3_u32), + Fp::zero(), + Fp::one(), + Fp::zero(), + Fp::zero(), + Fp::zero(), + ]; + let coeff_p3 = vec![ + Fp::zero(), + Fp::from(6_u32), + Fp::from(3_u32), + Fp::zero(), + Fp::from(2_u32), + Fp::one(), + ]; + + let p1 = Dense::::from_coeffs(coeff_p1); + let p2 = Dense::::from_coeffs(coeff_p2); + let exp_p3 = Dense::::from_coeffs(coeff_p3); + let p3 = p1 * p2; + assert_eq!(p3, exp_p3); +} + +#[test] +fn test_mul_by_one() { + mvpoly::pbt::test_mul_by_one::>(); +} + +#[test] +fn test_mul_by_zero() { + mvpoly::pbt::test_mul_by_zero::>(); +} + +#[test] +fn test_add_zero() { + mvpoly::pbt::test_add_zero::>(); +} + +#[test] +fn test_double_is_add_twice() { + mvpoly::pbt::test_double_is_add_twice::>(); +} + +#[test] +fn test_sub_zero() { + mvpoly::pbt::test_sub_zero::>(); +} + +#[test] +fn test_neg() { + mvpoly::pbt::test_neg::>(); +} + +#[test] +fn test_eval_pbt_add() { + mvpoly::pbt::test_eval_pbt_add::>(); +} + +#[test] +fn test_eval_pbt_sub() { + mvpoly::pbt::test_eval_pbt_sub::>(); +} + +#[test] +fn test_eval_pbt_mul_by_scalar() { + mvpoly::pbt::test_eval_pbt_mul_by_scalar::>(); +} + +#[test] +fn test_eval_pbt_neg() { + mvpoly::pbt::test_eval_pbt_neg::>(); +} + +#[test] +fn test_neg_ref() { + mvpoly::pbt::test_neg_ref::>(); +} + +#[test] +fn test_mul_by_scalar() { + mvpoly::pbt::test_mul_by_scalar::>(); +} + +#[test] +fn test_mul_by_scalar_with_zero() { + mvpoly::pbt::test_mul_by_scalar_with_zero::>(); +} + +#[test] +fn test_mul_by_scalar_with_one() { + mvpoly::pbt::test_mul_by_scalar_with_one::>(); +} + +#[test] +fn test_mul_by_scalar_with_from() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p = unsafe { Dense::::random(&mut rng, None) }; + let c = Fp::rand(&mut rng); + + // Create a constant polynomial from the field element + let constant_poly = Dense::::from(c); + + // Multiply p by c using mul_by_scalar + let result1 = p.mul_by_scalar(c); + + // Multiply p by the constant polynomial + let result2 = p.clone() * constant_poly; + + // Check that both methods produce the same result + assert_eq!(result1, result2); +} + +#[test] +fn test_from_variable_column() { + // Simulate a real usecase + enum Column { + X(usize), + } + + impl From for usize { + fn from(val: Column) -> usize { + match val { + Column::X(i) => { + let mut prime_gen = PrimeNumberGenerator::new(); + prime_gen.get_nth_prime(i + 1) + } + } + } + } + + let p = Dense::::from_variable::( + Variable { + col: Column::X(0), + row: CurrOrNext::Curr, + }, + None, + ); + assert_eq!(p[0], Fp::zero()); + assert_eq!(p[1], Fp::one()); + assert_eq!(p[2], Fp::zero()); + assert_eq!(p[3], Fp::zero()); + assert_eq!(p[4], Fp::zero()); + assert_eq!(p[5], 
Fp::zero()); + + // Test for z variable (index 3) + let p = Dense::::from_variable::( + Variable { + col: Column::X(1), + row: CurrOrNext::Curr, + }, + None, + ); + assert_eq!(p[0], Fp::zero()); + assert_eq!(p[1], Fp::zero()); + assert_eq!(p[2], Fp::one()); + assert_eq!(p[3], Fp::zero()); + assert_eq!(p[4], Fp::zero()); + + // Test for w variable (index 5) + let p = Dense::::from_variable::( + Variable { + col: Column::X(2), + row: CurrOrNext::Curr, + }, + None, + ); + assert_eq!(p[0], Fp::zero()); + assert_eq!(p[1], Fp::zero()); + assert_eq!(p[2], Fp::zero()); + assert_eq!(p[3], Fp::zero()); + assert_eq!(p[4], Fp::one()); +} + +#[test] +fn test_evaluation_zero_polynomial() { + mvpoly::pbt::test_evaluation_zero_polynomial::>(); +} + +#[test] +fn test_evaluation_constant_polynomial() { + mvpoly::pbt::test_evaluation_constant_polynomial::>(); +} + +#[test] +fn test_evaluation_predefined_polynomial() { + // Evaluating at random points + let mut rng = o1_utils::tests::make_test_rng(None); + + let random_evaluation: [Fp; 2] = std::array::from_fn(|_| Fp::rand(&mut rng)); + // P(X1, X2) = 2 + 3X1 + 4X2 + 5X1^2 + 6X1 X2 + 7 X2^2 + let p = Dense::::from_coeffs(vec![ + Fp::from(2_u32), + Fp::from(3_u32), + Fp::from(4_u32), + Fp::from(5_u32), + Fp::from(6_u32), + Fp::from(7_u32), + ]); + let exp_eval = Fp::from(2_u32) + + Fp::from(3_u32) * random_evaluation[0] + + Fp::from(4_u32) * random_evaluation[1] + + Fp::from(5_u32) * random_evaluation[0] * random_evaluation[0] + + Fp::from(6_u32) * random_evaluation[0] * random_evaluation[1] + + Fp::from(7_u32) * random_evaluation[1] * random_evaluation[1]; + let evaluation = p.eval(&random_evaluation); + assert_eq!(evaluation, exp_eval); +} + +/// As a reminder, here are the equations to compute the addition of two +/// different points `P1 = (X1, Y1)` and `P2 = (X2, Y2)`. Let `P3 = (X3, +/// Y3) = P1 + P2`. 
+/// +/// ```text +/// - λ = (Y1 - Y2) / (X1 - X2) +/// - X3 = λ^2 - X1 - X2 +/// - Y3 = λ (X1 - X3) - Y1 +/// ``` +/// +/// Therefore, the addition of elliptic curve points can be computed using the +/// following degree-2 constraints +/// +/// ```text +/// - Constraint 1: λ (X1 - X2) - Y1 + Y2 = 0 +/// - Constraint 2: X3 + X1 + X2 - λ^2 = 0 +/// - Constraint 3: Y3 - λ (X1 - X3) + Y1 = 0 +/// ``` +#[test] +fn test_from_expr_ec_addition() { + // Simulate a real usecase + // The following lines/design look similar to the ones we use in + // o1vm/arrabiata + #[derive(Clone, Copy, PartialEq)] + enum Column { + X(usize), + } + + impl From for usize { + fn from(val: Column) -> usize { + match val { + Column::X(i) => { + let mut prime_gen = PrimeNumberGenerator::new(); + prime_gen.get_nth_prime(i + 1) + } + } + } + } + + struct Constraint { + idx: usize, + } + + trait Interpreter { + type Position: Clone + Copy; + + type Variable: Clone + + std::ops::Add + + std::ops::Sub + + std::ops::Mul; + + fn allocate(&mut self) -> Self::Position; + + // Simulate fetching/reading a value from outside + // In the case of the witness, it will be getting a value from the + // environment + fn fetch(&self, pos: Self::Position) -> Self::Variable; + } + + impl Interpreter for Constraint { + type Position = Column; + + type Variable = Expr, Column>; + + fn allocate(&mut self) -> Self::Position { + let col = Column::X(self.idx); + self.idx += 1; + col + } + + fn fetch(&self, col: Self::Position) -> Self::Variable { + Expr::Atom(ExprInner::Cell(Variable { + col, + row: CurrOrNext::Curr, + })) + } + } + + impl Constraint { + fn new() -> Self { + Self { idx: 0 } + } + } + + let mut interpreter = Constraint::new(); + // Constraints for elliptic curve addition, without handling the case of the + // point at infinity or double + let lambda = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + let x1 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + let x2 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + + let y1 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + let y2 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + + let x3 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + let y3 = { + let pos = interpreter.allocate(); + interpreter.fetch(pos) + }; + + // Check we can convert into a Dense polynomial using prime representation + // We have 7 variables, maximum degree 2 + // We test by evaluating at a random point. 
+ let mut rng = o1_utils::tests::make_test_rng(None); + { + // - Constraint 1: λ (X1 - X2) - Y1 + Y2 = 0 + let expression = lambda.clone() * (x1.clone() - x2.clone()) - (y1.clone() - y2.clone()); + + let p = Dense::::from_expr::(expression, None); + let random_evaluation: [Fp; 7] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let eval = p.eval(&random_evaluation); + let exp_eval = { + random_evaluation[0] * (random_evaluation[1] - random_evaluation[2]) + - (random_evaluation[3] - random_evaluation[4]) + }; + assert_eq!(eval, exp_eval); + } + + { + // - Constraint 2: X3 + X1 + X2 - λ^2 = 0 + let expr = x3.clone() + x1.clone() + x2.clone() - lambda.clone() * lambda.clone(); + let p = Dense::::from_expr::(expr, None); + let random_evaluation: [Fp; 7] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let eval = p.eval(&random_evaluation); + let exp_eval = { + random_evaluation[5] + random_evaluation[1] + random_evaluation[2] + - random_evaluation[0] * random_evaluation[0] + }; + assert_eq!(eval, exp_eval); + } + { + // - Constraint 3: Y3 - λ (X1 - X3) + Y1 = 0 + let expr = y3.clone() - lambda.clone() * (x1.clone() - x3.clone()) + y1.clone(); + let p = Dense::::from_expr::(expr, None); + let random_evaluation: [Fp; 7] = std::array::from_fn(|_| Fp::rand(&mut rng)); + let eval = p.eval(&random_evaluation); + let exp_eval = { + random_evaluation[6] + - random_evaluation[0] * (random_evaluation[1] - random_evaluation[5]) + + random_evaluation[3] + }; + assert_eq!(eval, exp_eval); + } +} + +#[test] +pub fn test_prime_increase_degree() { + let mut rng = o1_utils::tests::make_test_rng(None); + let p1 = unsafe { Dense::::random(&mut rng, None) }; + { + let p1_prime = p1.increase_degree::<3>(); + let random_evaluation: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng)); + assert_eq!( + p1.eval(&random_evaluation), + p1_prime.eval(&random_evaluation) + ); + } + { + let p1_prime = p1.increase_degree::<4>(); + let random_evaluation: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng)); + assert_eq!( + p1.eval(&random_evaluation), + p1_prime.eval(&random_evaluation) + ); + } + { + let p1_prime = p1.increase_degree::<5>(); + let random_evaluation: [Fp; 6] = std::array::from_fn(|_| Fp::rand(&mut rng)); + assert_eq!( + p1.eval(&random_evaluation), + p1_prime.eval(&random_evaluation) + ); + } + // When precompution of prime factor decomposition is done, increase degree + // in testing +} + +#[test] +fn test_degree_with_coeffs() { + let p = Dense::::from_coeffs(vec![ + Fp::from(2_u32), + Fp::from(3_u32), + Fp::from(4_u32), + Fp::from(5_u32), + Fp::from(6_u32), + Fp::from(7_u32), + ]); + let degree = unsafe { p.degree() }; + assert_eq!(degree, 2); +} + +#[test] +fn test_degree_constant() { + mvpoly::pbt::test_degree_constant::>(); +} + +#[test] +fn test_degree_random_degree() { + mvpoly::pbt::test_degree_random_degree::>(); + mvpoly::pbt::test_degree_random_degree::>(); +} + +#[test] +fn test_is_constant() { + mvpoly::pbt::test_is_constant::>(); +} + +#[test] +fn test_mvpoly_add_degree_pbt() { + mvpoly::pbt::test_mvpoly_add_degree_pbt::>(); +} + +#[test] +fn test_mvpoly_sub_degree_pbt() { + mvpoly::pbt::test_mvpoly_sub_degree_pbt::>(); +} + +#[test] +fn test_mvpoly_neg_degree_pbt() { + mvpoly::pbt::test_mvpoly_neg_degree_pbt::>(); +} + +#[test] +fn test_mvpoly_mul_by_scalar_degree_pbt() { + mvpoly::pbt::test_mvpoly_mul_by_scalar_degree_pbt::>(); +} + +#[test] +fn test_mvpoly_mul_degree_pbt() { + mvpoly::pbt::test_mvpoly_mul_degree_pbt::>(); +} + +#[test] +fn test_mvpoly_mul_eval_pbt() { + 
mvpoly::pbt::test_mvpoly_mul_eval_pbt::>(); +} + +#[test] +fn test_mvpoly_mul_pbt() { + mvpoly::pbt::test_mvpoly_mul_pbt::>(); +} + +#[test] +fn test_can_be_printed_with_debug() { + mvpoly::pbt::test_can_be_printed_with_debug::>(); +} + +#[test] +fn test_is_zero() { + mvpoly::pbt::test_is_zero::>(); +} + +#[test] +fn test_homogeneous_eval() { + mvpoly::pbt::test_homogeneous_eval::>(); +} + +#[test] +fn test_add_monomial() { + mvpoly::pbt::test_add_monomial::>(); +} + +#[test] +fn test_is_multilinear() { + mvpoly::pbt::test_is_multilinear::>(); +} + +#[test] +#[should_panic] +fn test_build_from_variable_next_row_without_offset_given() { + #[derive(Clone, Copy, PartialEq)] + enum Column { + X(usize), + } + + impl From for usize { + fn from(val: Column) -> usize { + match val { + Column::X(i) => { + let mut prime_gen = PrimeNumberGenerator::new(); + prime_gen.get_nth_prime(i + 1) + } + } + } + } + + let mut rng = o1_utils::tests::make_test_rng(None); + let idx: usize = rng.gen_range(0..4); + let _p = Dense::::from_variable::( + Variable { + col: Column::X(idx), + row: CurrOrNext::Next, + }, + None, + ); +} + +#[test] +fn test_build_from_variable_next_row_with_offset_given() { + #[derive(Clone, Copy, PartialEq)] + enum Column { + X(usize), + } + + impl From for usize { + fn from(val: Column) -> usize { + match val { + Column::X(i) => { + let mut prime_gen = PrimeNumberGenerator::new(); + prime_gen.get_nth_prime(i + 1) + } + } + } + } + + let mut rng = o1_utils::tests::make_test_rng(None); + let idx: usize = rng.gen_range(0..4); + + // Using next + { + let p = Dense::::from_variable::( + Variable { + col: Column::X(idx), + row: CurrOrNext::Next, + }, + Some(4), + ); + + let eval: [Fp; 8] = std::array::from_fn(|_i| Fp::rand(&mut rng)); + assert_eq!(p.eval(&eval), eval[idx + 4]); + } + + // Still using current + { + let p = Dense::::from_variable::( + Variable { + col: Column::X(idx), + row: CurrOrNext::Curr, + }, + Some(4), + ); + + let eval: [Fp; 8] = std::array::from_fn(|_i| Fp::rand(&mut rng)); + assert_eq!(p.eval(&eval), eval[idx]); + } +} diff --git a/mvpoly/tests/utils.rs b/mvpoly/tests/utils.rs new file mode 100644 index 0000000000..dad2abe4d3 --- /dev/null +++ b/mvpoly/tests/utils.rs @@ -0,0 +1,329 @@ +use std::collections::HashMap; + +use mvpoly::utils::{ + compute_all_two_factors_decomposition, compute_indices_nested_loop, get_mapping_with_primes, + is_prime, naive_prime_factors, PrimeNumberGenerator, +}; + +pub const FIRST_FIFTY_PRIMES: [usize; 50] = [ + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, + 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, + 197, 199, 211, 223, 227, 229, +]; + +#[test] +pub fn test_is_prime() { + FIRST_FIFTY_PRIMES + .iter() + .for_each(|&prime| assert!(is_prime(prime))); + + let mut prime_gen = PrimeNumberGenerator::new(); + FIRST_FIFTY_PRIMES + .iter() + .for_each(|&prime| assert!(prime_gen.is_prime(prime))) +} + +#[test] +pub fn test_is_not_prime() { + assert!(!is_prime(0)); + assert!(!is_prime(1)); + { + let random_even = 2 * (rand::random::() % 1000); + assert!(!is_prime(random_even)); + } + { + let random_product = + (2 + rand::random::() % 10000) * (2 + rand::random::() % 10000); + assert!(!is_prime(random_product)); + } +} + +#[test] +pub fn test_nth_prime() { + let mut prime_gen = PrimeNumberGenerator::new(); + // Repeated on purpose + assert_eq!(2, prime_gen.get_nth_prime(1)); + assert_eq!(2, prime_gen.get_nth_prime(1)); + // Repeated on purpose + 
assert_eq!(3, prime_gen.get_nth_prime(2)); + assert_eq!(5, prime_gen.get_nth_prime(3)); + // Repeated on purpose + assert_eq!(7, prime_gen.get_nth_prime(4)); + assert_eq!(7, prime_gen.get_nth_prime(4)); + assert_eq!(11, prime_gen.get_nth_prime(5)); + assert_eq!(13, prime_gen.get_nth_prime(6)); + assert_eq!(17, prime_gen.get_nth_prime(7)); + assert_eq!(19, prime_gen.get_nth_prime(8)); + assert_eq!(23, prime_gen.get_nth_prime(9)); + assert_eq!(29, prime_gen.get_nth_prime(10)); + assert_eq!(31, prime_gen.get_nth_prime(11)); + assert_eq!(37, prime_gen.get_nth_prime(12)); + assert_eq!(41, prime_gen.get_nth_prime(13)); + assert_eq!(43, prime_gen.get_nth_prime(14)); + assert_eq!(47, prime_gen.get_nth_prime(15)); + assert_eq!(53, prime_gen.get_nth_prime(16)); + assert_eq!(59, prime_gen.get_nth_prime(17)); + assert_eq!(61, prime_gen.get_nth_prime(18)); + assert_eq!(67, prime_gen.get_nth_prime(19)); + assert_eq!(71, prime_gen.get_nth_prime(20)); +} + +#[test] +pub fn test_mapping_variables_indexes_to_primes() { + { + let map = get_mapping_with_primes::<3>(); + assert_eq!(map[3], 2); + assert_eq!(map[4], 3); + assert_eq!(map[5], 5); + assert_eq!(map.len(), 6); + } +} + +#[test] +pub fn test_naive_prime_factors() { + let mut prime_gen = PrimeNumberGenerator::new(); + let n = 4; + let factors = naive_prime_factors(n, &mut prime_gen); + assert_eq!(factors, vec![(2, 2)]); + + let n = 6; + let factors = naive_prime_factors(n, &mut prime_gen); + assert_eq!(factors, vec![(2, 1), (3, 1)]); + + let n = 2; + let factors = naive_prime_factors(n, &mut prime_gen); + assert_eq!(factors, vec![(2, 1)]); + + let n = 10; + let factors = naive_prime_factors(n, &mut prime_gen); + assert_eq!(factors, vec![(2, 1), (5, 1)]); + + let n = 12; + let factors = naive_prime_factors(n, &mut prime_gen); + assert_eq!(factors, vec![(2, 2), (3, 1)]); + + let n = 40; + let factors = naive_prime_factors(n, &mut prime_gen); + assert_eq!(factors, vec![(2, 3), (5, 1)]); + + let n = 1023; + let factors = naive_prime_factors(n, &mut prime_gen); + assert_eq!(factors, vec![(3, 1), (11, 1), (31, 1)]); +} + +#[test] +pub fn test_get_next_prime() { + let mut prime_gen = PrimeNumberGenerator::new(); + assert_eq!(2, prime_gen.get_next_prime()); + assert_eq!(3, prime_gen.get_next_prime()); + assert_eq!(5, prime_gen.get_next_prime()); + assert_eq!(7, prime_gen.get_next_prime()); + assert_eq!(11, prime_gen.get_next_prime()); + assert_eq!(13, prime_gen.get_next_prime()); + assert_eq!(17, prime_gen.get_next_prime()); +} + +#[test] +pub fn test_iterator_on_prime_number_generator() { + let mut prime_gen = PrimeNumberGenerator::new(); + assert_eq!(2, prime_gen.next().unwrap()); + assert_eq!(3, prime_gen.next().unwrap()); + assert_eq!(5, prime_gen.next().unwrap()); + assert_eq!(7, prime_gen.next().unwrap()); + assert_eq!(11, prime_gen.next().unwrap()); + assert_eq!(13, prime_gen.next().unwrap()); + assert_eq!(17, prime_gen.next().unwrap()); +} + +#[test] +pub fn test_compute_all_two_factors_decomposition_prime() { + let mut prime_gen = PrimeNumberGenerator::new(); + let mut acc = HashMap::new(); + { + let mut res = compute_all_two_factors_decomposition(2, &mut acc, &mut prime_gen); + res.sort(); + assert_eq!(res, [(1, 2), (2, 1)].to_vec()); + } + { + let random_prime = prime_gen.get_nth_prime(rand::random::() % 1000); + let mut res = compute_all_two_factors_decomposition(random_prime, &mut acc, &mut prime_gen); + res.sort(); + assert_eq!(res, [(1, random_prime), (random_prime, 1)].to_vec()); + } +} + +#[test] +pub fn 
test_compute_all_two_factors_decomposition_special_case_one() {
+    let mut prime_gen = PrimeNumberGenerator::new();
+    let mut acc = HashMap::new();
+    let mut res = compute_all_two_factors_decomposition(1, &mut acc, &mut prime_gen);
+    res.sort();
+    assert_eq!(res, [(1, 1)].to_vec());
+}
+
+#[test]
+pub fn test_compute_all_two_factors_decomposition_no_multiplicity() {
+    let mut prime_gen = PrimeNumberGenerator::new();
+    let mut acc = HashMap::new();
+
+    let mut res = compute_all_two_factors_decomposition(6, &mut acc, &mut prime_gen);
+    res.sort();
+    assert_eq!(res, [(1, 6), (2, 3), (3, 2), (6, 1)].to_vec());
+}
+
+#[test]
+pub fn test_compute_all_two_factors_decomposition_with_multiplicity() {
+    let mut prime_gen = PrimeNumberGenerator::new();
+    let mut acc = HashMap::new();
+
+    let mut res = compute_all_two_factors_decomposition(4, &mut acc, &mut prime_gen);
+    res.sort();
+    assert_eq!(res, [(1, 4), (2, 2), (4, 1)].to_vec());
+
+    let mut res = compute_all_two_factors_decomposition(8, &mut acc, &mut prime_gen);
+    res.sort();
+    assert_eq!(res, [(1, 8), (2, 4), (4, 2), (8, 1)].to_vec());
+
+    let mut res = compute_all_two_factors_decomposition(16, &mut acc, &mut prime_gen);
+    res.sort();
+    assert_eq!(res, [(1, 16), (2, 8), (4, 4), (8, 2), (16, 1)].to_vec());
+
+    let mut res = compute_all_two_factors_decomposition(48, &mut acc, &mut prime_gen);
+    res.sort();
+    assert_eq!(
+        res,
+        [
+            (1, 48),
+            (2, 24),
+            (3, 16),
+            (4, 12),
+            (6, 8),
+            (8, 6),
+            (12, 4),
+            (16, 3),
+            (24, 2),
+            (48, 1)
+        ]
+        .to_vec()
+    );
+
+    let mut res = compute_all_two_factors_decomposition(100, &mut acc, &mut prime_gen);
+    res.sort();
+    assert_eq!(
+        res,
+        [
+            (1, 100),
+            (2, 50),
+            (4, 25),
+            (5, 20),
+            (10, 10),
+            (20, 5),
+            (25, 4),
+            (50, 2),
+            (100, 1)
+        ]
+        .to_vec()
+    );
+}
+
+#[test]
+pub fn test_compute_indices_nested_loop() {
+    let nested_loops = vec![2, 2];
+    // sorting to get the same order
+    let mut exp_indices = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
+    exp_indices.sort();
+    let mut comp_indices = compute_indices_nested_loop(nested_loops);
+    comp_indices.sort();
+    assert_eq!(exp_indices, comp_indices);
+
+    let nested_loops = vec![3, 2];
+    // sorting to get the same order
+    let mut exp_indices = vec![
+        vec![0, 0],
+        vec![0, 1],
+        vec![1, 0],
+        vec![1, 1],
+        vec![2, 0],
+        vec![2, 1],
+    ];
+    exp_indices.sort();
+    let mut comp_indices = compute_indices_nested_loop(nested_loops);
+    comp_indices.sort();
+    assert_eq!(exp_indices, comp_indices);
+
+    let nested_loops = vec![3, 3, 2, 2];
+    // sorting to get the same order
+    let mut exp_indices = vec![
+        vec![0, 0, 0, 0],
+        vec![0, 0, 0, 1],
+        vec![0, 0, 1, 0],
+        vec![0, 0, 1, 1],
+        vec![0, 1, 0, 0],
+        vec![0, 1, 0, 1],
+        vec![0, 1, 1, 0],
+        vec![0, 1, 1, 1],
+        vec![0, 2, 0, 0],
+        vec![0, 2, 0, 1],
+        vec![0, 2, 1, 0],
+        vec![0, 2, 1, 1],
+        vec![1, 0, 0, 0],
+        vec![1, 0, 0, 1],
+        vec![1, 0, 1, 0],
+        vec![1, 0, 1, 1],
+        vec![1, 1, 0, 0],
+        vec![1, 1, 0, 1],
+        vec![1, 1, 1, 0],
+        vec![1, 1, 1, 1],
+        vec![1, 2, 0, 0],
+        vec![1, 2, 0, 1],
+        vec![1, 2, 1, 0],
+        vec![1, 2, 1, 1],
+        vec![2, 0, 0, 0],
+        vec![2, 0, 0, 1],
+        vec![2, 0, 1, 0],
+        vec![2, 0, 1, 1],
+        vec![2, 1, 0, 0],
+        vec![2, 1, 0, 1],
+        vec![2, 1, 1, 0],
+        vec![2, 1, 1, 1],
+        vec![2, 2, 0, 0],
+        vec![2, 2, 0, 1],
+        vec![2, 2, 1, 0],
+        vec![2, 2, 1, 1],
+    ];
+    exp_indices.sort();
+    let mut comp_indices = compute_indices_nested_loop(nested_loops);
+    comp_indices.sort();
+    assert_eq!(exp_indices, comp_indices);
+
+    // Simple and single loop
+    let nested_loops = vec![3];
+    let exp_indices = vec![vec![0],
vec![1], vec![2]]; + let mut comp_indices = compute_indices_nested_loop(nested_loops); + comp_indices.sort(); + assert_eq!(exp_indices, comp_indices); + + // relatively large loops + let nested_loops = vec![10, 10]; + let comp_indices = compute_indices_nested_loop(nested_loops); + // Only checking the length as it would take too long to unroll the result + assert_eq!(comp_indices.len(), 100); + + // Non-uniform loop sizes, relatively large + let nested_loops = vec![5, 7, 3]; + let comp_indices = compute_indices_nested_loop(nested_loops); + assert_eq!(comp_indices.len(), 5 * 7 * 3); +} + +#[test] +fn test_compute_indices_nested_loop_edge_cases() { + let nested_loops = vec![]; + let comp_indices: Vec> = compute_indices_nested_loop(nested_loops); + let exp_output: Vec> = vec![vec![]]; + assert_eq!(comp_indices, exp_output); + + // With one empty loop. Should match the documentation + let nested_loops = vec![3, 0, 2]; + let comp_indices = compute_indices_nested_loop(nested_loops); + assert_eq!(comp_indices.len(), 0); +} diff --git a/o1vm/Cargo.toml b/o1vm/Cargo.toml new file mode 100644 index 0000000000..1e57ce7c0d --- /dev/null +++ b/o1vm/Cargo.toml @@ -0,0 +1,65 @@ +[package] +name = "o1vm" +version = "0.1.0" +description = "o1VM" +repository = "https://github.com/o1-labs/proof-systems" +homepage = "https://o1-labs.github.io/proof-systems/" +documentation = "https://o1-labs.github.io/proof-systems/rustdoc/" +readme = "README.md" +edition = "2021" +license = "Apache-2.0" + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "test_optimism_preimage_read" +path = "src/test_preimage_read.rs" + +[[bin]] +name = "legacy_o1vm" +path = "src/legacy/main.rs" + +[[bin]] +name = "pickles_o1vm" +path = "src/pickles/main.rs" + +[dependencies] +# FIXME: Only activate this when legacy_o1vm is built +ark-bn254.workspace = true +# FIXME: Only activate this when legacy_o1vm is built +folding.workspace = true +# FIXME: Only activate this when legacy_o1vm is built +kimchi = { workspace = true, features = [ "bn254" ] } +ark-ec.workspace = true +ark-ff.workspace = true +ark-poly.workspace = true +base64.workspace = true +clap.workspace = true +command-fds.workspace = true +elf.workspace = true +env_logger.workspace = true +groupmap.workspace = true +hex.workspace = true +itertools.workspace = true +kimchi-msm.workspace = true +libc.workspace = true +libflate.workspace = true +log.workspace = true +mina-curves.workspace = true +mina-poseidon.workspace = true +o1-utils.workspace = true +os_pipe.workspace = true +poly-commitment.workspace = true +rand.workspace = true +rayon.workspace = true +regex.workspace = true +rmp-serde.workspace = true +serde.workspace = true +serde_json.workspace = true +serde_with.workspace = true +stacker = "0.1" +strum.workspace = true +strum_macros.workspace = true +sha3.workspace = true +thiserror.workspace = true \ No newline at end of file diff --git a/o1vm/Dockerfile-e2e-testing-cache b/o1vm/Dockerfile-e2e-testing-cache new file mode 100644 index 0000000000..deeb1c29f6 --- /dev/null +++ b/o1vm/Dockerfile-e2e-testing-cache @@ -0,0 +1,12 @@ +FROM alpine + +# Add necessary utilities +RUN apk add --no-cache shadow unzip zip + +# Copy the files from the host to the Docker image +# Assuming the files and folder are located +# in the same directory as the current Dockerfile on the host +COPY o1vm-e2e-testing-cache.zip /tmp/cache-source/o1vm-e2e-testing-cache.zip + +# Command to copy the files to mounted volume +CMD ["/bin/sh", "-c", "cp -r /tmp/cache-source/* /tmp/cache/"] diff --git 
a/o1vm/README-optimism.md b/o1vm/README-optimism.md
new file mode 100644
index 0000000000..740d495a34
--- /dev/null
+++ b/o1vm/README-optimism.md
@@ -0,0 +1,191 @@
+## Run infra for OP Sepolia on Ubuntu
+
+Install the first dependencies:
+```
+sudo apt update
+# chrony will ensure the system clock is up to date
+sudo apt install build-essential git vim chrony ufw -y
+```
+
+Set up your firewall:
+```
+sudo ufw default deny incoming
+sudo ufw default allow outgoing
+sudo ufw allow OpenSSH
+sudo ufw allow 8545
+sudo ufw allow 8546
+sudo ufw allow 8551
+sudo ufw allow 30303
+sudo ufw show added
+sudo ufw enable
+```
+
+Set up a non-root user and its environment:
+```
+sudo useradd validator
+sudo mkdir /validator
+```
+
+Set up Go:
+```
+cd /validator
+curl -OL https://go.dev/dl/go1.22.3.linux-amd64.tar.gz
+sudo tar -C /usr/local -xvf go1.22.3.linux-amd64.tar.gz
+echo 'export PATH=$PATH:/usr/local/go/bin' >> ~/.profile
+source ~/.profile
+```
+
+Set up op-geth:
+```
+git clone https://github.com/ethereum-optimism/op-geth.git
+cd op-geth
+git checkout v1.101315.1 # or another more recent stable tag
+make all
+```
+
+Set up op-node:
+```
+cd /validator
+git clone https://github.com/ethereum-optimism/optimism.git
+cd optimism
+git checkout v1.7.6 # or another more recent stable tag
+cd op-node
+make op-node
+```
+
+Generate a JWT token:
+```
+cd /validator
+mkdir infra-for-optimism && cd infra-for-optimism
+openssl rand -hex 32 > jwt.txt
+```
+
+Generate genesis:
+```
+curl -o genesis.json -sL https://storage.googleapis.com/oplabs-network-data/Sepolia/genesis.json
+../op-geth/build/bin/geth init --datadir="op-geth-data" genesis.json
+```
+
+Create the script `/validator/infra-for-optimism/op-geth-run.sh` to run OP geth:
+```
+#!/bin/bash
+SEQUENCER_URL="https://sepolia-sequencer.optimism.io/"
+
+DATADIR="/validator/infra-for-optimism/op-geth-data"
+JWT_TOKEN="/validator/infra-for-optimism/jwt.txt"
+HTTP_PORT=8545
+WS_PORT=8546
+AUTHRPC_PORT=8551
+PORT=30303
+
+/validator/op-geth/build/bin/geth \
+  --datadir=${DATADIR} \
+  --verbosity=3 \
+  --http \
+  --http.corsdomain="*" \
+  --http.vhosts="*" \
+  --http.addr 0.0.0.0 \
+  --http.port ${HTTP_PORT} \
+  --http.api web3,debug,eth,txpool,net,engine \
+  --ws \
+  --ws.addr "0.0.0.0" \
+  --ws.port ${WS_PORT} \
+  --ws.origins="*" \
+  --ws.api=debug,eth,txpool,net,engine \
+  --authrpc.vhosts="*" \
+  --authrpc.addr "0.0.0.0" \
+  --authrpc.port ${AUTHRPC_PORT} \
+  --authrpc.jwtsecret=${JWT_TOKEN} \
+  --syncmode=full \
+  --op-network=op-sepolia \
+  --rollup.sequencerhttp=$SEQUENCER_URL \
+  --port=${PORT} \
+  --discovery.port=${PORT} \
+  --gcmode=archive
+```
+
+Create the service `/etc/systemd/system/op-geth.service`:
+```
+[Unit]
+Description=OP geth service
+
+[Service]
+Type=simple
+User=validator
+Restart=always
+ExecStart=/validator/infra-for-optimism/op-geth-run.sh
+
+[Install]
+WantedBy=default.target
+```
+
+Enable the op-geth service:
+```
+sudo systemctl enable op-geth
+```
+
+Create the folder to store op-node files:
+```
+mkdir /validator/infra-for-optimism/op-node-data
+```
+
+Create the script `/validator/infra-for-optimism/op-node-run.sh` to run the OP node:
+```
+#!/bin/bash
+L1_RPC_URL=https://ethereum-sepolia-rpc.publicnode.com
+L1_RPC_KIND=standard
+L1_BEACON_URL=https://ethereum-sepolia-beacon-api.publicnode.com
+L2_PORT=8551
+LISTENING_ADDR="0.0.0.0"
+
+JWT_FILE="/validator/infra-for-optimism/jwt.txt"
+
+cd /validator/infra-for-optimism/op-node-data
+
+/validator/optimism/op-node/bin/op-node \
+  --l1=${L1_RPC_URL} \
+  --l1.rpckind=${L1_RPC_KIND} \
+  --l1.beacon=${L1_BEACON_URL} \
+    --l2=ws://localhost:${L2_PORT} \
+    --l2.jwt-secret="${JWT_FILE}" \
+    --network=op-sepolia \
+    --syncmode=execution-layer
+```
+
+Create the service `/etc/systemd/system/op-node.service`:
+```
+[Unit]
+Description=OP node service
+
+[Service]
+Type=simple
+Restart=on-failure
+RestartSec=5s
+User=validator
+ExecStart=/validator/infra-for-optimism/op-node-run.sh
+
+[Install]
+WantedBy=multi-user.target
+```
+
+Enable the op-node service:
+```
+sudo systemctl enable op-node
+```
+
+Make sure the user `validator` can run the services smoothly:
+```
+sudo chown -R validator:validator /validator
+sudo loginctl enable-linger validator
+```
+
+Start and check the op-geth service:
+```
+sudo systemctl start op-geth
+journalctl -u op-geth -n 20 -f
+```
+
+Start and check the op-node service:
+```
+sudo systemctl start op-node
+journalctl -u op-node -n 20 -f
+```
diff --git a/o1vm/README.md b/o1vm/README.md
new file mode 100644
index 0000000000..0a7cc6540b
--- /dev/null
+++ b/o1vm/README.md
@@ -0,0 +1,177 @@
+# o1VM: a zero-knowledge virtual machine
+
+This crate contains an implementation of different components used to build a
+zero-knowledge virtual machine. For now, the implementation is specialised for
+the MIPS ISA used by [Cannon](https://github.com/ethereum-optimism/cannon). In
+the future, the codebase will be generalised to handle more ISAs and more
+programs.
+
+## Description
+
+The current version of o1vm depends on Optimism infrastructure to fetch
+blocks and transaction data (see [README-optimism.md](./README-optimism.md)).
+Currently, the only program that the codebase has been tested on is the
+[op-program](./ethereum-optimism/op-program), which contains code to verify
+Ethereum state transitions (EVM).
+
+`op-program` is first compiled into MIPS, using the Go compiler.
+From there, we fetch the latest Ethereum/Optimism network information (latest
+block, etc.), and execute the op-program using the MIPS VM provided by Optimism,
+named Cannon (`./run-cannon`).
+
+We can then execute o1vm using `run-vm.sh`. It builds all the data
+(the witness) required to produce a proof later.
+Note that everything is only local at the moment. Nothing is posted on-chain or
+anywhere else.
+
+Each of these steps can be run using `./run-code.sh`.
+
+## Pre-requisites
+
+o1vm compiles a certain version of the Optimism codebase (written in Go), and
+therefore you need to have a Go compiler installed on your machine. For now,
+at least Go 1.21 is required.
+
+You can use [gvm](https://github.com/moovweb/gvm) to install a Go compiler.
+Switch to Go 1.21 before continuing.
+
+```shell
+gvm install go1.21
+gvm use go1.21 [--default]
+```
+
+If you do not have any Go version installed, you will need to bootstrap
+through earlier versions to install 1.21:
+
+```shell
+gvm install go1.4 -B
+gvm use go1.4
+export GOROOT_BOOTSTRAP=$GOROOT
+gvm install go1.17.13
+gvm use go1.17.13
+export GOROOT_BOOTSTRAP=$GOROOT
+gvm install go1.21
+gvm use go1.21
+```
+
+You also will need to install the [Foundry](https://getfoundry.sh/) toolkit
+in order to use applications like `cast`.
+
+```shell
+foundryup
+```
+
+You will also need to install jq with your favorite package manager,
+e.g. on Ubuntu:
+
+```shell
+sudo apt-get install jq
+```
+
+## Running the Optimism demo
+
+Start by initializing the submodules:
+
+```bash
+git submodule init && git submodule update
+```
+
+Create an executable `rpcs.sh` file like:
+
+```bash
+#!/usr/bin/env bash
+export L1_RPC=http://xxxxxxxxx
+export L2_RPC=http://xxxxxxxxx
+export OP_NODE_RPC=http://xxxxxxxxx
+export L1_BEACON_RPC=http://xxxxxxxxx
+```
+
+If you just want to test the state transition between the latest finalized L2
+block and its predecessor:
+
+```bash
+./run-code.sh
+```
+
+By default, this will also create a script named `env-for-latest-l2-block.sh` with a
+snapshot of all the information that you need to rerun the same test again:
+
+```bash
+FILENAME=env-for-latest-l2-block.sh bash run-code.sh
+```
+
+Alternatively, you also have the option to test the state transition between a
+specific block and its predecessor:
+
+```bash
+# Set -n to the desired block transition you want to test.
+./setenv-for-l2-block.sh -n 12826645
+```
+
+In this case, you can run the demo using the following format:
+
+```bash
+FILENAME=env-for-l2-block-12826645.sh bash run-code.sh
+```
+
+In either case, `run-code.sh` will:
+
+1. Generate the initial state.
+2. Execute the OP program.
+3. Execute the OP program through the Cannon MIPS VM.
+4. Execute the OP program through the o1VM MIPS VM.
+
+## Flavors
+
+Different versions/flavors of the o1vm are available:
+
+- [legacy](./src/legacy/mod.rs) - to be deprecated.
+- [pickles](./src/pickles/mod.rs) (currently the default)
+
+You can select the flavor you want to run with `run-code.sh` by using the
+environment variable `O1VM_FLAVOR`.
+
+## Testing the preimage read
+
+Run:
+
+```bash
+./test_preimage_read.sh [OP_DB_DIRECTORY] [NETWORK_NAME]
+```
+
+If the parameter is omitted, `OP_DB_DIRECTORY` defaults to the directory used
+by `setenv-for-latest-l2-block.sh`.
+
+The `NETWORK_NAME` defaults to `sepolia`.
+
+## Running the o1vm with cached data
+
+If you want to run the o1vm with cached data, you can use the following steps:
+
+- Make sure you have [Docker Engine](https://docs.docker.com/engine/install/) and [Python3](https://www.python.org/downloads/) installed on your machine.
+- Fetch the cached data by executing the following command (it might take some time):
+
+```shell
+./fetch-e2e-testing-cache.sh
+```
+
+- Start the simple HTTP server (in background or in another terminal session):
+
+```shell
+python3 -m http.server 8765 &
+```
+
+- Then run the o1vm with the following command:
+
+```shell
+RUN_WITH_CACHED_DATA="y" FILENAME="env-for-latest-l2-block.sh" O1VM_FLAVOR="pickles" STOP_AT="=3000000" ./run-code.sh
+```
+
+- Don't forget to stop the HTTP server after you are done.
+
+- You can clean the cached data by executing the following command:
+
+```shell
+./clear-e2e-testing-cache.sh
+```
diff --git a/o1vm/clear-e2e-testing-cache.sh b/o1vm/clear-e2e-testing-cache.sh
new file mode 100755
index 0000000000..8379737f49
--- /dev/null
+++ b/o1vm/clear-e2e-testing-cache.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+START=$(date +%s)
+
+echo ""
+echo "Cleaning up the E2E testing cache file system..."
+echo "" +rm -rf op-program-db-for-latest-l2-block \ + env-for-latest-l2-block.sh \ + snapshot-state-3000000.json \ + state.json \ + meta.json \ + out.json \ + cpu.pprof + +RUNTIME=$(($(date +%s) - START)) +echo "" +echo "Execution time: ${RUNTIME} s" +echo "" diff --git a/o1vm/ethereum-optimism b/o1vm/ethereum-optimism new file mode 160000 index 0000000000..4a487b8920 --- /dev/null +++ b/o1vm/ethereum-optimism @@ -0,0 +1 @@ +Subproject commit 4a487b8920daa9dc4b496d691d5f283f9bb659b1 diff --git a/o1vm/fetch-e2e-testing-cache.sh b/o1vm/fetch-e2e-testing-cache.sh new file mode 100755 index 0000000000..b090d80c86 --- /dev/null +++ b/o1vm/fetch-e2e-testing-cache.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +START=$(date +%s) + +echo "" +echo "Fetching the E2E testing cache..." +echo "" +docker run --rm -it --name o1vm-e2e-testing-cache --pull=always -v ./:/tmp/cache o1labs/proof-systems:o1vm-e2e-testing-cache +unzip -q -o o1vm-e2e-testing-cache.zip -d ./ +rm -rf o1vm-e2e-testing-cache.zip + +RUNTIME=$(($(date +%s) - START)) +echo "" +echo "Execution time: ${RUNTIME} s" +echo "" diff --git a/o1vm/publish-e2e-testing-cache.sh b/o1vm/publish-e2e-testing-cache.sh new file mode 100755 index 0000000000..397617fd19 --- /dev/null +++ b/o1vm/publish-e2e-testing-cache.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +set -euo pipefail + +START=$(date +%s) +# ARCH="amd64" +DOCKER_FILE_NAME="Dockerfile-e2e-testing-cache" +DOCKER_HUB_USER_NAME="o1labs" +DOCKER_HUB_REPO_NAME="proof-systems" +DOCKER_HUB_IMAGE_TAG="o1vm-e2e-testing-cache" +DOCKER_IMAGE_FULL_NAME="${DOCKER_HUB_USER_NAME}/${DOCKER_HUB_REPO_NAME}:${DOCKER_HUB_IMAGE_TAG}" + +echo "" +echo "Preparing the file system..." +echo "" +zip -r -q o1vm-e2e-testing-cache.zip op-program-db-for-latest-l2-block \ + env-for-latest-l2-block.sh \ + snapshot-state-3000000.json \ + state.json \ + meta.json \ + out.json \ + cpu.pprof + +echo "" +echo "Building the '${DOCKER_IMAGE_FULL_NAME}' Docker image..." +echo "" +docker rmi -f ${DOCKER_IMAGE_FULL_NAME} || true +docker rmi -f ${DOCKER_HUB_IMAGE_TAG} || true +# docker build --platform linux/${ARCH} -t ${DOCKER_HUB_IMAGE_TAG} -f ${DOCKER_FILE_NAME} . +docker build -t ${DOCKER_HUB_IMAGE_TAG} -f ${DOCKER_FILE_NAME} . + +echo "" +echo "Publishing the '${DOCKER_IMAGE_FULL_NAME}' Docker image..." +echo "" +docker tag ${DOCKER_HUB_IMAGE_TAG} ${DOCKER_IMAGE_FULL_NAME} +docker push ${DOCKER_IMAGE_FULL_NAME} +echo "" + +echo "" +echo "Cleaning up the file system..." 
+echo "" +rm -rf o1vm-e2e-testing-cache.zip + +RUNTIME=$(($(date +%s) - START)) +echo "" +echo "Execution time: ${RUNTIME} s" +echo "" diff --git a/o1vm/resources/programs/riscv32im/fibonacci b/o1vm/resources/programs/riscv32im/fibonacci new file mode 100644 index 0000000000..aa6542bb3d Binary files /dev/null and b/o1vm/resources/programs/riscv32im/fibonacci differ diff --git a/o1vm/resources/programs/riscv32im/fibonacci-7 b/o1vm/resources/programs/riscv32im/fibonacci-7 new file mode 100644 index 0000000000..fad8dea9fd Binary files /dev/null and b/o1vm/resources/programs/riscv32im/fibonacci-7 differ diff --git a/o1vm/resources/programs/riscv32im/is_prime_naive b/o1vm/resources/programs/riscv32im/is_prime_naive new file mode 100755 index 0000000000..97567077a9 Binary files /dev/null and b/o1vm/resources/programs/riscv32im/is_prime_naive differ diff --git a/o1vm/resources/programs/riscv32im/no-action b/o1vm/resources/programs/riscv32im/no-action new file mode 100755 index 0000000000..d94f45752a Binary files /dev/null and b/o1vm/resources/programs/riscv32im/no-action differ diff --git a/o1vm/resources/tests/0x022107307879258577230c5aa2f90567bda40877a7e85dceb6ff1f37480fef3d.txt b/o1vm/resources/tests/0x022107307879258577230c5aa2f90567bda40877a7e85dceb6ff1f37480fef3d.txt new file mode 100644 index 0000000000..2fdbf76e25 --- /dev/null +++ b/o1vm/resources/tests/0x022107307879258577230c5aa2f90567bda40877a7e85dceb6ff1f37480fef3d.txt @@ -0,0 +1 @@ +f90163822080b9015d7ef90159a0132561da48c8ec09d76d3d711396d6b44acab25365908108f7db2920ba19b9ea94deaddeaddeaddeaddeaddeaddeaddeaddead00019442000000000000000000000000000000000000158080830f424080b90104015d8eb900000000000000000000000000000000000000000000000000000000004a71d6000000000000000000000000000000000000000000000000000000006579a848000000000000000000000000000000000000000000000000000000037900e0e18fec22bdc89a697fd9e8836ebece95c73f86b2ed15e4c50d6e66a5e69ed4472f00000000000000000000000000000000000000000000000000000000000000030000000000000000000000008f23bb38f531600e5d8fddaaec41f13fab46e98c00000000000000000000000000000000000000000000000000000000000000bc00000000000000000000000000000000000000000000000000000000000a6fe0 \ No newline at end of file diff --git a/o1vm/run-cannon.sh b/o1vm/run-cannon.sh new file mode 100755 index 0000000000..654714370b --- /dev/null +++ b/o1vm/run-cannon.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -euo pipefail + +make -C ./ethereum-optimism/op-program op-program +make -C ./ethereum-optimism/cannon cannon + +./ethereum-optimism/cannon/bin/cannon load-elf --path=./ethereum-optimism/op-program/bin/op-program-client.elf + +./ethereum-optimism/cannon/bin/cannon run \ + --pprof.cpu \ + --info-at "${INFO_AT:-%10000000}" \ + --proof-at never \ + --stop-at "${STOP_AT:-never}" \ + --input "${CANNON_STATE_FILENAME:-./state.json}" \ + -- \ + ./ethereum-optimism/op-program/bin/op-program \ + --log.level DEBUG \ + --l1 "${L1_RPC}" \ + --l2 "${L2_RPC}" \ + --network sepolia \ + --datadir "${OP_PROGRAM_DATA_DIR}" \ + --l1.head "${L1_HEAD}" \ + --l2.head "${L2_HEAD}" \ + --l2.outputroot "${STARTING_OUTPUT_ROOT}" \ + --l2.claim "${L2_CLAIM}" \ + --l2.blocknumber "${L2_BLOCK_NUMBER}" \ + --server diff --git a/o1vm/run-code.sh b/o1vm/run-code.sh new file mode 100755 index 0000000000..fba71d8d0c --- /dev/null +++ b/o1vm/run-code.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail + +set +u +if [ -z "${FILENAME}" ]; then + echo "Using the latest finalized l2 block information to configure the env variables..." 
+    FILENAME="$(./setenv-for-latest-l2-block.sh)"
+fi
+set -u
+
+source "${FILENAME}"
+
+if [ "${RUN_WITH_CACHED_DATA:-}" == "y" ]; then
+    echo "The op-program and Cannon apps were not executed because cached data usage was requested"
+else
+    ./run-op-program.sh
+    ./run-cannon.sh
+fi
+
+./run-vm.sh
diff --git a/o1vm/run-op-program.sh b/o1vm/run-op-program.sh
new file mode 100755
index 0000000000..a2d992c3d7
--- /dev/null
+++ b/o1vm/run-op-program.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+make -C ./ethereum-optimism/op-program op-program
+
+set -x
+./ethereum-optimism/op-program/bin/op-program \
+    --log.level DEBUG \
+    --l1 "${L1_RPC}" \
+    --l2 "${L2_RPC}" \
+    --network sepolia \
+    --datadir "${OP_PROGRAM_DATA_DIR}" \
+    --l1.head "${L1_HEAD}" \
+    --l1.beacon "${L1_BEACON_RPC}" \
+    --l2.head "${L2_HEAD}" \
+    --l2.outputroot "${STARTING_OUTPUT_ROOT}" \
+    --l2.claim "${L2_CLAIM}" \
+    --l2.blocknumber "${L2_BLOCK_NUMBER}"
diff --git a/o1vm/run-vm.sh b/o1vm/run-vm.sh
new file mode 100755
index 0000000000..34e61f4da7
--- /dev/null
+++ b/o1vm/run-vm.sh
@@ -0,0 +1,52 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+O1VM_FLAVOR="${O1VM_FLAVOR:-pickles}"
+
+case ${O1VM_FLAVOR} in
+"legacy")
+    BINARY_FLAVOR="legacy_o1vm"
+    ;;
+"pickles")
+    BINARY_FLAVOR="pickles_o1vm"
+    ;;
+*)
+    echo "${BASH_SOURCE[0]}:${LINENO}: only the flavors 'legacy' and \
+'pickles' (default) are supported for now. Set the environment variable \
+O1VM_FLAVOR to one of these values to run the flavor you would like"
+    exit 1
+    ;;
+esac
+
+if [ "${RUN_WITH_CACHED_DATA:-}" == "y" ]; then
+    echo "Setting the environment variables to use the cached data"
+    export OP_PROGRAM_DATA_DIR="$(pwd)/op-program-db-for-latest-l2-block"
+    export ZKVM_STATE_FILENAME="$(pwd)/snapshot-state-3000000.json"
+    # We need to set the L1 and L2 RPC endpoints for op-program to run successfully
+    export L1_RPC="http://localhost:8765"
+    export L2_RPC="http://localhost:8765"
+fi
+
+cargo run --bin ${BINARY_FLAVOR} \
+    --all-features \
+    --release \
+    -p o1vm -- \
+    --pprof.cpu \
+    --info-at "${INFO_AT:-%10000000}" \
+    --snapshot-state-at "${SNAPSHOT_STATE_AT:-%10000000}" \
+    --proof-at never \
+    --stop-at "${STOP_AT:-never}" \
+    --input "${ZKVM_STATE_FILENAME:-./state.json}" \
+    -- \
+    ./ethereum-optimism/op-program/bin/op-program \
+    --log.level DEBUG \
+    --l1 "${L1_RPC}" \
+    --l2 "${L2_RPC}" \
+    --network sepolia \
+    --datadir "${OP_PROGRAM_DATA_DIR}" \
+    --l1.head "${L1_HEAD}" \
+    --l2.head "${L2_HEAD}" \
+    --l2.outputroot "${STARTING_OUTPUT_ROOT}" \
+    --l2.claim "${L2_CLAIM}" \
+    --l2.blocknumber "${L2_BLOCK_NUMBER}" \
+    --server
diff --git a/o1vm/setenv-for-l2-block.sh b/o1vm/setenv-for-l2-block.sh
new file mode 100755
index 0000000000..a369fb528d
--- /dev/null
+++ b/o1vm/setenv-for-l2-block.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+
+while getopts ":n:" opt; do
+    case "$opt" in
+    n) L2_BLOCK_NUMBER=$OPTARG
+        ;;
+    *)
+        echo "Usage: $0 -n L2_BLOCK_NUMBER" 1>&2
+        exit 1
+        ;;
+    esac
+done
+
+shift $((OPTIND-1))
+[ "${1:-}" = "--" ] && shift
+
+if [ -z "${L2_BLOCK_NUMBER}" ]; then
+    # You can check for a block at:
+    # https://sepolia-optimism.etherscan.io/block/L2_BLOCK_NUMBER
+    echo "Usage: $0 -n L2_BLOCK_NUMBER" 1>&2
+    exit 1
+fi
+
+re='^[0-9]+$'
+if ! [[ $L2_BLOCK_NUMBER =~ $re ]] ; then
+    echo "The -n argument must be a number" 1>&2
+    exit 1
+fi
+
+source rpcs.sh
+
+OUTPUT_L2_BLOCK=$(cast rpc optimism_outputAtBlock $(cast th $L2_BLOCK_NUMBER) --rpc-url "${OP_NODE_RPC}")
+OUTPUT_L2_PREVIOUS_BLOCK=$(cast rpc optimism_outputAtBlock $(cast th $((L2_BLOCK_NUMBER-1))) --rpc-url "${OP_NODE_RPC}")
+
+# We use the STARTING_OUTPUT_ROOT as the output state just prior to the
+# L2_BLOCK_NUMBER
+STARTING_OUTPUT_ROOT=$(echo $OUTPUT_L2_PREVIOUS_BLOCK | jq -r .outputRoot)
+
+# Claim of the transition from STARTING_OUTPUT_ROOT up to L2_BLOCK_NUMBER
+L2_CLAIM=$(echo $OUTPUT_L2_BLOCK | jq -r .outputRoot)
+
+# Get the finalized heads
+L1_HEAD=$(echo $OUTPUT_L2_BLOCK | jq -r .blockRef.l1origin.hash)
+L2_HEAD=$(echo $OUTPUT_L2_PREVIOUS_BLOCK | jq -r .blockRef.hash)
+
+echo "The claim of the transition to block ${L2_BLOCK_NUMBER} is $L2_CLAIM" 1>&2
+
+FILENAME=env-for-l2-block-${L2_BLOCK_NUMBER}.sh
+# Delete all lines in the file if it already exists
+cat /dev/null > ${FILENAME}
+
+OP_PROGRAM_DATA_DIR=$(pwd)/op-program-db-for-l2-block-${L2_BLOCK_NUMBER}
+
+echo "export L1_HEAD=${L1_HEAD}" >> "${FILENAME}"
+echo "export L2_HEAD=${L2_HEAD}" >> "${FILENAME}"
+echo "export L2_BLOCK_NUMBER=${L2_BLOCK_NUMBER}" >> "${FILENAME}"
+echo "export STARTING_OUTPUT_ROOT=${STARTING_OUTPUT_ROOT}" >> "${FILENAME}"
+echo "export L2_CLAIM=${L2_CLAIM}" >> "${FILENAME}"
+echo "export OP_PROGRAM_DATA_DIR=${OP_PROGRAM_DATA_DIR}" >> "${FILENAME}"
+echo "export L1_RPC=${L1_RPC}" >> "${FILENAME}"
+echo "export L2_RPC=${L2_RPC}" >> "${FILENAME}"
+echo "export L1_BEACON_RPC=${L1_BEACON_RPC}" >> "${FILENAME}"
+
+# echo "Env variables for block ${L2_BLOCK_NUMBER} can be loaded using
+# ./${FILENAME}"
+echo "${FILENAME}"
diff --git a/o1vm/setenv-for-latest-l2-block.sh b/o1vm/setenv-for-latest-l2-block.sh
new file mode 100755
index 0000000000..f3d3376637
--- /dev/null
+++ b/o1vm/setenv-for-latest-l2-block.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+source rpcs.sh
+
+# AnchorStateRegistry (or ASR) on L1 Sepolia
+ANCHOR_STATE_REGISTRY=0x218CD9489199F321E1177b56385d333c5B598629
+
+# ASR returns the output root and the L2 Block Number of the latest finalized
+# output root.
+ANCHOR_OUTPUT=$(cast call --rpc-url "${L1_RPC}" ${ANCHOR_STATE_REGISTRY} 'anchors(uint32) returns(bytes32,uint256)' 0)
+L2_CLAIM=$(echo ${ANCHOR_OUTPUT} | cut -d' ' -f 1)
+L2_BLOCK_NUMBER=$(echo ${ANCHOR_OUTPUT} | cut -d' ' -f 2)
+
+OUTPUT_L2_PREVIOUS_BLOCK=$(cast rpc optimism_outputAtBlock $(cast th $((L2_BLOCK_NUMBER-1))) --rpc-url "${OP_NODE_RPC}")
+# We use the STARTING_OUTPUT_ROOT as the output state just prior to the
+# L2_BLOCK_NUMBER
+STARTING_OUTPUT_ROOT=$(echo $OUTPUT_L2_PREVIOUS_BLOCK | jq -r .outputRoot)
+L2_HEAD=$(echo $OUTPUT_L2_PREVIOUS_BLOCK | jq -r .blockRef.hash)
+
+L1_FINALIZED_NUMBER=$(cast block finalized --rpc-url "${L1_RPC}" -f number)
+L1_FINALIZED_HASH=$(cast block "${L1_FINALIZED_NUMBER}" --rpc-url "${L1_RPC}" -f hash)
+L1_HEAD=$L1_FINALIZED_HASH
+
+echo "The claim of the transition to block ${L2_BLOCK_NUMBER} is $L2_CLAIM" 1>&2
+
+FILENAME=env-for-latest-l2-block.sh
+
+# Delete all lines in the file if it already exists
+cat /dev/null > ${FILENAME}
+
+OP_PROGRAM_DATA_DIR=$(pwd)/op-program-db-for-latest-l2-block
+
+echo "export L1_HEAD=${L1_HEAD}" >> "${FILENAME}"
+echo "export L2_HEAD=${L2_HEAD}" >> "${FILENAME}"
+echo "export L2_BLOCK_NUMBER=${L2_BLOCK_NUMBER}" >> "${FILENAME}"
+echo "export STARTING_OUTPUT_ROOT=${STARTING_OUTPUT_ROOT}" >> "${FILENAME}"
+echo "export L2_CLAIM=${L2_CLAIM}" >> "${FILENAME}"
+echo "export OP_PROGRAM_DATA_DIR=${OP_PROGRAM_DATA_DIR}" >> "${FILENAME}"
+echo "export L1_RPC=${L1_RPC}" >> "${FILENAME}"
+echo "export L2_RPC=${L2_RPC}" >> "${FILENAME}"
+echo "export L1_BEACON_RPC=${L1_BEACON_RPC}" >> "${FILENAME}"
+
+echo "${FILENAME}"
diff --git a/o1vm/src/cannon.rs b/o1vm/src/cannon.rs
new file mode 100644
index 0000000000..1a47a3041f
--- /dev/null
+++ b/o1vm/src/cannon.rs
@@ -0,0 +1,488 @@
+// Data structures and helpers for compatibility with Cannon
+
+use base64::{engine::general_purpose, Engine as _};
+
+use libflate::zlib::{Decoder, Encoder};
+use regex::Regex;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use std::io::{Read, Write};
+
+pub const PAGE_ADDRESS_SIZE: u32 = 12;
+pub const PAGE_SIZE: u32 = 1 << PAGE_ADDRESS_SIZE;
+pub const PAGE_ADDRESS_MASK: u32 = PAGE_SIZE - 1;
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Page {
+    pub index: u32,
+    #[serde(deserialize_with = "from_base64", serialize_with = "to_base64")]
+    pub data: Vec<u8>,
+}
+
+fn from_base64<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let s: String = Deserialize::deserialize(deserializer)?;
+    let b64_decoded = general_purpose::STANDARD.decode(s).unwrap();
+    let mut decoder = Decoder::new(&b64_decoded[..]).unwrap();
+    let mut data = Vec::new();
+    decoder.read_to_end(&mut data).unwrap();
+    assert_eq!(data.len(), PAGE_SIZE as usize);
+    Ok(data)
+}
+
+fn to_base64<S>(v: &[u8], serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+{
+    let encoded_v = Vec::new();
+    let mut encoder = Encoder::new(encoded_v).unwrap();
+    encoder.write_all(v).unwrap();
+    let res = encoder.finish().into_result().unwrap();
+    let b64_encoded = general_purpose::STANDARD.encode(res);
+    serializer.serialize_str(&b64_encoded)
+}
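The page constants above drive all of Cannon-compatible memory addressing. As a small illustration (hypothetical helper, not part of this diff) of how an address splits into a page index and an in-page offset with these constants:

```rust
// Illustrative only: split a 32-bit address into (page index, offset)
// using the constants above (PAGE_SIZE = 1 << 12 = 4096).
fn split_address(addr: u32) -> (u32, u32) {
    let page_index = addr >> PAGE_ADDRESS_SIZE; // which 4 KiB page
    let offset = addr & PAGE_ADDRESS_MASK; // position inside that page
    (page_index, offset)
}

// e.g. split_address(0x1234) == (0x1, 0x234)
```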
"nextPC")] + pub next_pc: u32, + pub lo: u32, + pub hi: u32, + pub heap: u32, + pub exit: u8, + pub exited: bool, + pub step: u64, + pub registers: [u32; 32], + pub last_hint: Option>, + pub preimage: Option>, +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ParsePreimageKeyError(String); + +#[derive(Debug, PartialEq)] +pub struct PreimageKey(pub [u8; 32]); + +use std::str::FromStr; + +impl FromStr for PreimageKey { + type Err = ParsePreimageKeyError; + + fn from_str(s: &str) -> Result { + let parts = s.split('x').collect::>(); + let hex_value: &str = if parts.len() == 1 { + parts[0] + } else { + if parts.len() != 2 { + return Err(ParsePreimageKeyError( + format!("Badly structured value to convert {s}").to_string(), + )); + }; + parts[1] + }; + // We only handle a hexadecimal representations of exactly 32 bytes (no auto-padding) + if hex_value.len() == 64 { + hex::decode(hex_value).map_or_else( + |_| { + Err(ParsePreimageKeyError( + format!("Could not hex decode {hex_value}").to_string(), + )) + }, + |h| { + h.clone().try_into().map_or_else( + |_| { + Err(ParsePreimageKeyError( + format!("Could not cast vector {:#?} into 32 bytes array", h) + .to_string(), + )) + }, + |res| Ok(PreimageKey(res)), + ) + }, + ) + } else { + Err(ParsePreimageKeyError( + format!("{hex_value} is not 32-bytes long").to_string(), + )) + } + } +} + +fn deserialize_preimage_key<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error> +where + D: Deserializer<'de>, +{ + let s: String = Deserialize::deserialize(deserializer)?; + let p = PreimageKey::from_str(s.as_str()) + .unwrap_or_else(|_| panic!("Parsing {s} as preimage key failed")); + Ok(p.0) +} + +fn serialize_preimage_key(v: &[u8], serializer: S) -> Result +where + S: Serializer, +{ + let s: String = format!("0x{}", hex::encode(v)); + serializer.serialize_str(&s) +} + +#[derive(Clone, Debug, PartialEq)] +pub enum StepFrequency { + Never, + Always, + Exactly(u64), + Every(u64), + Range(u64, Option), +} + +// Simple parser for Cannon's "frequency format" +// A frequency input is either +// - never/always +// - = (only at step n) +// - % (every steps multiple of n) +// - n..[m] (from n on, until m excluded if specified, until the end otherwise) +pub fn step_frequency_parser(s: &str) -> std::result::Result { + use StepFrequency::*; + + let mod_re = Regex::new(r"^%([0-9]+)").unwrap(); + let eq_re = Regex::new(r"^=([0-9]+)").unwrap(); + let ival_re = Regex::new(r"^([0-9]+)..([0-9]+)?").unwrap(); + + match s { + "never" => Ok(Never), + "always" => Ok(Always), + s => { + if let Some(m) = mod_re.captures(s) { + Ok(Every(m[1].parse::().unwrap())) + } else if let Some(m) = eq_re.captures(s) { + Ok(Exactly(m[1].parse::().unwrap())) + } else if let Some(m) = ival_re.captures(s) { + let lo = m[1].parse::().unwrap(); + let hi_opt = m.get(2).map(|x| x.as_str().parse::().unwrap()); + Ok(Range(lo, hi_opt)) + } else { + Err(format!("Unknown frequency format {}", s)) + } + } + } +} + +impl ToString for State { + // A very debatable and incomplete, but serviceable, `to_string` implementation. 
+
+impl ToString for State {
+    // A very debatable and incomplete, but serviceable, `to_string` implementation.
+    fn to_string(&self) -> String {
+        format!(
+            "memory_size (length): {}\nfirst page size: {}\npreimage key: {:#?}\npreimage offset:{}\npc: {}\nlo: {}\nhi: {}\nregisters:{:#?} ",
+            self.memory.len(),
+            self.memory[0].data.len(),
+            self.preimage_key,
+            self.preimage_offset,
+            self.pc,
+            self.lo,
+            self.hi,
+            self.registers
+        )
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct HostProgram {
+    pub name: String,
+    pub arguments: Vec<String>,
+}
+
+#[derive(Debug, Clone)]
+pub struct VmConfiguration {
+    pub input_state_file: String,
+    pub output_state_file: String,
+    pub metadata_file: String,
+    pub proof_at: StepFrequency,
+    pub stop_at: StepFrequency,
+    pub snapshot_state_at: StepFrequency,
+    pub info_at: StepFrequency,
+    pub proof_fmt: String,
+    pub snapshot_fmt: String,
+    pub pprof_cpu: bool,
+    pub host: Option<HostProgram>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Start {
+    pub time: std::time::Instant,
+    pub step: usize,
+}
+
+impl Start {
+    pub fn create(step: usize) -> Start {
+        Start {
+            time: std::time::Instant::now(),
+            step,
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Clone, Deserialize)]
+pub struct Symbol {
+    pub name: String,
+    pub start: u32,
+    pub size: usize,
+}
+
+#[derive(Debug, PartialEq, Clone, Deserialize)]
+pub struct Meta {
+    #[serde(deserialize_with = "filtered_ordered")]
+    pub symbols: Vec<Symbol>, // Needs to be in ascending order w.r.t. start address
+}
+
+// Make sure that deserialized data are ordered in ascending order and that we
+// have removed 0-size symbols
+fn filtered_ordered<'de, D>(deserializer: D) -> Result<Vec<Symbol>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let v: Vec<Symbol> = Deserialize::deserialize(deserializer)?;
+    let mut filtered: Vec<Symbol> = v.into_iter().filter(|e| e.size != 0).collect();
+    filtered.sort_by(|a, b| a.start.cmp(&b.start));
+    Ok(filtered)
+}
+
+impl Meta {
+    pub fn find_address_symbol(&self, address: u32) -> Option<String> {
+        use std::cmp::Ordering;
+
+        self.symbols
+            .binary_search_by(
+                |Symbol {
+                     start,
+                     size,
+                     name: _,
+                 }| {
+                    if address < *start {
+                        Ordering::Greater
+                    } else {
+                        let end = *start + *size as u32;
+                        if address >= end {
+                            Ordering::Less
+                        } else {
+                            Ordering::Equal
+                        }
+                    }
+                },
+            )
+            .map_or_else(|_| None, |idx| Some(self.symbols[idx].name.to_string()))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use std::{
+        fs::File,
+        io::{BufReader, Write},
+    };
+
+    #[test]
+    fn sp_parser() {
+        use StepFrequency::*;
+        assert_eq!(step_frequency_parser("never"), Ok(Never));
+        assert_eq!(step_frequency_parser("always"), Ok(Always));
+        assert_eq!(step_frequency_parser("=123"), Ok(Exactly(123)));
+        assert_eq!(step_frequency_parser("%123"), Ok(Every(123)));
+        assert_eq!(step_frequency_parser("1..3"), Ok(Range(1, Some(3))));
+        assert_eq!(step_frequency_parser("1.."), Ok(Range(1, None)));
+        assert!(step_frequency_parser("@123").is_err());
+    }
+
+    // This sample is a subset taken from a Cannon-generated "meta.json" file.
+    // Interestingly, it contains 0-size symbols - these are removed by
+    // deserialization.
+    const META_SAMPLE: &str = r#"{
+    "symbols": [
+        {
+            "name": "go.go",
+            "start": 0,
+            "size": 0
+        },
+        {
+            "name": "internal/cpu.processOptions",
+            "start": 69632,
+            "size": 1872
+        },
+        {
+            "name": "runtime.text",
+            "start": 69632,
+            "size": 0
+        },
+        {
+            "name": "runtime/internal/atomic.(*Uint8).Load",
+            "start": 71504,
+            "size": 28
+        },
+        {
+            "name": "runtime/internal/atomic.(*Uint8).Store",
+            "start": 71532,
+            "size": 28
+        },
+        {
+            "name": "runtime/internal/atomic.(*Uint8).And",
+            "start": 71560,
+            "size": 88
+        },
+        {
+            "name": "runtime/internal/atomic.(*Uint8).Or",
+            "start": 71648,
+            "size": 72
+        }]}"#;
+
+    fn deserialize_meta_sample() -> Meta {
+        serde_json::from_str::<Meta>(META_SAMPLE).unwrap()
+    }
+
+    #[test]
+    fn test_serialize_deserialize_page() {
+        let value: &str = r#"{"index":16,"data":"eJztlkFoE0EUht8k21ZEtFYFg1FCTW0qSGoTS6pFJU3TFlNI07TEQJHE1kJMmhwi1ihaRJCqiAdBKR5Ez4IXvQk5eBaP4iEWpAchV0Hoof5vd14SoQcvve0H/5s3O//OzuzMLHtvNBZVDkUNHLQLUdHugSTKINJgnDoNZB60+MhFBq63Q0G4LCFYQptZoKR9r0hpEc1r4bopy8WRtdptmCJqM+t89RHiY60Xc39M8b26XXUjHLdEbf4qdTyMIWvn9vnyxhTy7eBxGwvGoRWU23ASIqNE5MT4H2DslogOa/EY+f38LxiNKYyrEwW02sV9CJLfgdjnMOfLc0+6biMKHohJFLe2fqO0qLl4Hui0AfcB1H0EzEFTc73GtSfIBO0jnhvnDvpx5CLVIJoKoS7Ic59C2pdfoRpEe+KoC+J7CWnf8leqQf/CbcwbiHP2rcO3TuENfr+C9HcGYp+T15nXnMjdOl/JOyDtc3tUt9tDzto31AXprwuyfCc2SfVsohZ8j7ogPh4Lr7NT+fxV1Yv9pXJ11AXxHYUsX99aVfnWqkT11vcsvk8QnstWJD4EUr0Igt4HqodD0wdP59kIUkH76DvU9IXOXSfnr0tIBe1T5zlAJmrY+xHFICRIG+8p5Lq/YW+djt1tfX/S314ODV/67Wc6eOEZUkF8CxwavqWfSWo/9QWpoH2UhXjtHDhn+E6wzO+EIL4RnEk+nOzDnmWZayRYDyJ6BzkgE3Vjv5faYrjV9F6DuD/eMx+gxvlQlbnndMDdh1TA2G1sbGxsbGxsbGx2Co9Sqvk/2gL/r05DxlgRP8bZK0O50cJQPjMxO5HKhCOlQr8/sVy5uRTuD5RGKuXFaDgYSQ+E/LOlsZlEIZ8NBqKlcmby8mIpPOjPpWYmxwPF06lI+mpqPB+O35ou0l+FGHpe"}"#;
+        let decoded_page: Page = serde_json::from_str(value).unwrap();
+        let res = serde_json::to_string(&decoded_page).unwrap();
+        assert_eq!(res, value);
+    }
+
+    #[test]
+    fn test_preimage_key_serialisation() {
+        #[derive(Serialize, Deserialize)]
+        struct TestPreimageKeyStruct {
+            #[serde(
+                rename = "preimageKey",
+                deserialize_with = "deserialize_preimage_key",
+                serialize_with = "serialize_preimage_key"
+            )]
+            pub preimage_key: [u8; 32],
+        }
+
+        let preimage_key: &str = r#"{"preimageKey":"0x0000000000000000000000000000000000000000000000000000000000000000"}"#;
+        let s: TestPreimageKeyStruct = serde_json::from_str(preimage_key).unwrap();
+        let res = serde_json::to_string(&s).unwrap();
+        assert_eq!(preimage_key, res);
+    }
+
+    #[test]
+    fn test_meta_deserialize_from_file() {
+        let path = "meta_test.json";
+        let mut output =
+            File::create(path).unwrap_or_else(|_| panic!("Could not create file {path}"));
+        write!(output, "{}", META_SAMPLE)
+            .unwrap_or_else(|_| panic!("Could not write to file {path}"));
+
+        let input = File::open(path).unwrap_or_else(|_| panic!("Could not open file {path}"));
+        let buffered = BufReader::new(input);
+        let read: Meta = serde_json::from_reader(buffered)
+            .unwrap_or_else(|_| panic!("Failed to deserialize metadata from file {path}"));
+
+        let expected = Meta {
+            symbols: vec![
+                Symbol {
+                    name: "internal/cpu.processOptions".to_string(),
+                    start: 69632,
+                    size: 1872,
+                },
+                Symbol {
+                    name: "runtime/internal/atomic.(*Uint8).Load".to_string(),
+                    start: 71504,
+                    size: 28,
+                },
+                Symbol {
+                    name: "runtime/internal/atomic.(*Uint8).Store".to_string(),
+                    start: 71532,
+                    size: 28,
+                },
+                Symbol {
+                    name: "runtime/internal/atomic.(*Uint8).And".to_string(),
+                    start: 71560,
+                    size: 88,
+                },
+                Symbol {
+                    name: "runtime/internal/atomic.(*Uint8).Or".to_string(),
+                    start: 71648,
+                    size: 72,
+                },
+            ],
+        };
+
+        assert_eq!(read, expected);
+    }
+
+    #[test]
+    fn test_find_address_symbol() {
+        let meta = deserialize_meta_sample();
+
+        assert_eq!(
+            meta.find_address_symbol(69633),
+            Some("internal/cpu.processOptions".to_string())
+        );
+        assert_eq!(
+            meta.find_address_symbol(69632),
+            Some("internal/cpu.processOptions".to_string())
+        );
+        assert_eq!(meta.find_address_symbol(42), None);
+    }
+
+    #[test]
+    fn test_parse_preimagekey() {
+        assert_eq!(
+            PreimageKey::from_str(
+                "0x0000000000000000000000000000000000000000000000000000000000000000"
+            ),
+            Ok(PreimageKey([0; 32]))
+        );
+        assert_eq!(
+            PreimageKey::from_str(
+                "0x0000000000000000000000000000000000000000000000000000000000000001"
+            ),
+            Ok(PreimageKey([
+                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0, 0, 1
+            ]))
+        );
+        assert!(PreimageKey::from_str("0x01").is_err());
+    }
+}
+
+pub const HINT_CLIENT_READ_FD: i32 = 3;
+pub const HINT_CLIENT_WRITE_FD: i32 = 4;
+pub const PREIMAGE_CLIENT_READ_FD: i32 = 5;
+pub const PREIMAGE_CLIENT_WRITE_FD: i32 = 6;
+
+pub struct Preimage(Vec<u8>);
+
+impl Preimage {
+    pub fn create(v: Vec<u8>) -> Self {
+        Preimage(v)
+    }
+
+    pub fn get(self) -> Vec<u8> {
+        self.0
+    }
+}
+
+pub struct Hint(Vec<u8>);
+
+impl Hint {
+    pub fn create(v: Vec<u8>) -> Self {
+        Hint(v)
+    }
+
+    pub fn get(self) -> Vec<u8> {
+        self.0
+    }
+}
diff --git a/o1vm/src/cannon_cli.rs b/o1vm/src/cannon_cli.rs
new file mode 100644
index 0000000000..d894570a5f
--- /dev/null
+++ b/o1vm/src/cannon_cli.rs
@@ -0,0 +1,112 @@
+use crate::cannon::*;
+use clap::{arg, value_parser, Arg, ArgAction};
+
+pub fn main_cli() -> clap::Command {
+    let app_name = "o1vm";
+    clap::Command::new(app_name)
+        .version("0.1")
+        .about("o1vm - a general-purpose zero-knowledge virtual machine")
+        .arg(arg!(--input <FILE> "initial state file").default_value("state.json"))
+        .arg(arg!(--output <FILE> "output state file").default_value("out.json"))
+        .arg(arg!(--meta <FILE> "metadata file").default_value("meta.json"))
+        // The CLI arguments below this line are ignored at this point
+        .arg(
+            Arg::new("proof-at")
+                .short('p')
+                .long("proof-at")
+                .value_name("FREQ")
+                .default_value("never")
+                .value_parser(step_frequency_parser),
+        )
+        .arg(
+            Arg::new("proof-fmt")
+                .long("proof-fmt")
+                .value_name("FORMAT")
+                .default_value("proof-%d.json"),
+        )
+        .arg(
+            Arg::new("snapshot-fmt")
+                .long("snapshot-fmt")
+                .value_name("FORMAT")
+                .default_value("state-%d.json"),
+        )
+        .arg(
+            Arg::new("stop-at")
+                .long("stop-at")
+                .value_name("FREQ")
+                .default_value("never")
+                .value_parser(step_frequency_parser),
+        )
+        .arg(
+            Arg::new("info-at")
+                .long("info-at")
+                .value_name("FREQ")
+                .default_value("never")
+                .value_parser(step_frequency_parser),
+        )
+        .arg(
+            Arg::new("pprof-cpu")
+                .long("pprof.cpu")
+                .action(ArgAction::SetTrue),
+        )
+        .arg(
+            arg!(host: [HOST] "host program specification [host program arguments]")
+                .num_args(1..)
+                .last(true)
+                .value_parser(value_parser!(String)),
+        )
+        .arg(
+            Arg::new("snapshot-state-at")
+                .long("snapshot-state-at")
+                .value_name("FREQ")
+                .default_value("never")
+                .value_parser(step_frequency_parser),
+        )
+}
+
+pub fn read_configuration(cli: &clap::ArgMatches) -> VmConfiguration {
+    let input_state_file = cli.get_one::<String>("input").unwrap();
+    let output_state_file = cli.get_one::<String>("output").unwrap();
+    let metadata_file = cli.get_one::<String>("meta").unwrap();
+
+    let proof_at = cli.get_one::<StepFrequency>("proof-at").unwrap();
+    let info_at = cli.get_one::<StepFrequency>("info-at").unwrap();
+    let stop_at = cli.get_one::<StepFrequency>("stop-at").unwrap();
+    let snapshot_state_at = cli.get_one::<StepFrequency>("snapshot-state-at").unwrap();
+
+    let proof_fmt = cli.get_one::<String>("proof-fmt").unwrap();
+    let snapshot_fmt = cli.get_one::<String>("snapshot-fmt").unwrap();
+    let pprof_cpu = cli.get_one::<bool>("pprof-cpu").unwrap();
+
+    let host_spec = cli
+        .get_many::<String>("host")
+        .map(|vals| vals.collect::<Vec<_>>())
+        .unwrap_or_default();
+
+    let host = if host_spec.is_empty() {
+        None
+    } else {
+        Some(HostProgram {
+            name: host_spec[0].to_string(),
+            arguments: host_spec[1..]
+                .iter()
+                .map(|x| x.to_string())
+                .collect(),
+        })
+    };
+
+    VmConfiguration {
+        input_state_file: input_state_file.to_string(),
+        output_state_file: output_state_file.to_string(),
+        metadata_file: metadata_file.to_string(),
+        proof_at: proof_at.clone(),
+        stop_at: stop_at.clone(),
+        snapshot_state_at: snapshot_state_at.clone(),
+        info_at: info_at.clone(),
+        proof_fmt: proof_fmt.to_string(),
+        snapshot_fmt: snapshot_fmt.to_string(),
+        pprof_cpu: *pprof_cpu,
+        host,
+    }
+}
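A short sketch of how the two functions above are meant to be wired together in a binary (the wrapper function name is hypothetical; `main_cli` and `read_configuration` are the functions defined above):

```rust
// Sketch: parse the process arguments and build the VM configuration.
fn parse_vm_configuration() -> VmConfiguration {
    let cli = main_cli().get_matches();
    read_configuration(&cli)
}
```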
diff --git a/o1vm/src/elf_loader.rs b/o1vm/src/elf_loader.rs
new file mode 100644
index 0000000000..0fc26a4788
--- /dev/null
+++ b/o1vm/src/elf_loader.rs
@@ -0,0 +1,145 @@
+use crate::cannon::{Page, State, PAGE_SIZE};
+use elf::{endian::LittleEndian, section::SectionHeader, ElfBytes};
+use log::debug;
+use std::{collections::HashMap, path::Path};
+
+/// Parse an ELF file and return the parsed data as a structure that is expected
+/// by the o1vm RISC-V 32-bit edition.
+// FIXME: parametrize by an architecture. We should return a state depending on the
+// architecture. In the meantime, we can have parse_riscv32i and parse_mips.
+// FIXME: for now, we return a State structure, either for RISC-V 32i or MIPS.
+// We should return a structure specifically built for the o1vm, and not tied
+// to Cannon. It will be done in a future PR to avoid breaking the current code
+// and creating a huge diff.
+pub fn parse_riscv32(path: &Path) -> Result<State, String> {
+    debug!("Start parsing the ELF file to load a RISC-V 32i compatible state");
+    let file_data = std::fs::read(path).expect("Could not read file.");
+    let slice = file_data.as_slice();
+    let file = ElfBytes::<LittleEndian>::minimal_parse(slice).expect("Open ELF file failed.");
+
+    // Checking it is RISC-V
+    assert_eq!(file.ehdr.e_machine, 243);
+
+    let (shdrs_opt, strtab_opt) = file
+        .section_headers_with_strtab()
+        .expect("shdrs offsets should be valid");
+    let (shdrs, strtab) = (
+        shdrs_opt.expect("Should have shdrs"),
+        strtab_opt.expect("Should have strtab"),
+    );
+
+    // Parse the shdrs and collect them into a map keyed on their zero-copied name
+    let sections_by_name: HashMap<&str, SectionHeader> = shdrs
+        .iter()
+        .map(|shdr| {
+            (
+                strtab
+                    .get(shdr.sh_name as usize)
+                    .expect("Failed to get section name"),
+                shdr,
+            )
+        })
+        .collect();
+
+    debug!("Loading the text section, which contains the executable code.");
+    // Getting the executable code.
+    let text_section = sections_by_name
+        .get(".text")
+        .expect("Should have .text section");
+
+    let (text_section_data, _) = file
+        .section_data(text_section)
+        .expect("Failed to read data from .text section");
+
+    let code_section_starting_address = text_section.sh_addr as usize;
+    let code_section_size = text_section.sh_size as usize;
+    let code_section_end_address = code_section_starting_address + code_section_size;
+    debug!(
+        "The executable code starts at address {}, has size {} bytes, and ends at address {}.",
+        code_section_starting_address, code_section_size, code_section_end_address
+    );
+
+    // Building the memory pages
+    let mut memory: Vec<Page> = vec![];
+    let page_size_usize: usize = PAGE_SIZE.try_into().unwrap();
+    // Padding to get the right number of pages. We suppose that the memory
+    // index starts at 0.
+    let start_page_address: usize =
+        (code_section_starting_address / page_size_usize) * page_size_usize;
+    // Round the end address up to the next page boundary.
+    let end_page_address =
+        ((code_section_end_address + page_size_usize - 1) / page_size_usize) * page_size_usize;
+
+    let first_page_index = start_page_address / page_size_usize;
+    let last_page_index = (end_page_address - 1) / page_size_usize;
+    let mut data_offset = 0;
+    (first_page_index..=last_page_index).for_each(|page_index| {
+        let mut data = vec![0; page_size_usize];
+        // Special case of only one page
+        if first_page_index == last_page_index {
+            let data_length = code_section_end_address - code_section_starting_address;
+            let page_offset = code_section_starting_address - start_page_address;
+            data[page_offset..page_offset + data_length]
+                .copy_from_slice(&text_section_data[0..data_length]);
+            data_offset += data_length;
+        } else {
+            // Only the first page can start at a nonzero in-page offset
+            let page_offset = if page_index == first_page_index {
+                code_section_starting_address - start_page_address
+            } else {
+                0
+            };
+            let data_length = if page_index == last_page_index {
+                code_section_end_address - (page_index * page_size_usize)
+            } else {
+                page_size_usize - page_offset
+            };
+            data[page_offset..page_offset + data_length]
+                .copy_from_slice(&text_section_data[data_offset..data_offset + data_length]);
+            data_offset += data_length;
+        }
+        let page = Page {
+            index: page_index as u32,
+            data,
+        };
+        memory.push(page);
+    });
+
+    // FIXME: add data section into memory for static data saved in the binary
+
+    // FIXME: we're lucky that RISC-V 32i and MIPS have the same number of registers.
+    let registers: [u32; 32] = [0; 32];
+
+    // FIXME: it is only because we share the same structure for the state.
+    let preimage_key: [u8; 32] = [0; 32];
+    // FIXME: it is only because we share the same structure for the state.
+    let preimage_offset = 0;
+
+    // Entry point of the program
+    let pc: u32 = file.ehdr.e_entry as u32;
+    assert!(pc != 0, "Entry point is 0. The documentation of the ELF library says that it means the ELF doesn't have an entry point. This is not supported.");
+    let next_pc: u32 = pc + 4u32;
+
+    let state = State {
+        memory,
+        // FIXME: only because Cannon related
+        preimage_key,
+        // FIXME: only because Cannon related
+        preimage_offset,
+        pc,
+        next_pc,
+        // FIXME: only because Cannon related
+        lo: 0,
+        // FIXME: only because Cannon related
+        hi: 0,
+        heap: 0,
+        exit: 0,
+        exited: false,
+        step: 0,
+        registers,
+        // FIXME: only because Cannon related
+        last_hint: None,
+        // FIXME: only because Cannon related
+        preimage: None,
+    };
+
+    Ok(state)
+}
diff --git a/o1vm/src/interpreters/keccak/column.rs b/o1vm/src/interpreters/keccak/column.rs
new file mode 100644
index 0000000000..fe2e1cd9b1
--- /dev/null
+++ b/o1vm/src/interpreters/keccak/column.rs
@@ -0,0 +1,409 @@
+//! This module defines the custom columns used in the Keccak witness, which
+//! are aliases for the actual Keccak witness columns also defined here.
+use self::{Absorbs::*, Sponges::*, Steps::*};
+use crate::interpreters::keccak::{ZKVM_KECCAK_COLS_CURR, ZKVM_KECCAK_COLS_NEXT};
+use kimchi::circuits::polynomials::keccak::constants::{
+    CHI_SHIFTS_B_LEN, CHI_SHIFTS_B_OFF, CHI_SHIFTS_SUM_LEN, CHI_SHIFTS_SUM_OFF, PIRHO_DENSE_E_LEN,
+    PIRHO_DENSE_E_OFF, PIRHO_DENSE_ROT_E_LEN, PIRHO_DENSE_ROT_E_OFF, PIRHO_EXPAND_ROT_E_LEN,
+    PIRHO_EXPAND_ROT_E_OFF, PIRHO_QUOTIENT_E_LEN, PIRHO_QUOTIENT_E_OFF, PIRHO_REMAINDER_E_LEN,
+    PIRHO_REMAINDER_E_OFF, PIRHO_SHIFTS_E_LEN, PIRHO_SHIFTS_E_OFF, QUARTERS, RATE_IN_BYTES,
+    SPONGE_BYTES_LEN, SPONGE_BYTES_OFF, SPONGE_COLS, SPONGE_NEW_STATE_LEN, SPONGE_NEW_STATE_OFF,
+    SPONGE_SHIFTS_LEN, SPONGE_SHIFTS_OFF, SPONGE_ZEROS_LEN, SPONGE_ZEROS_OFF, STATE_LEN,
+    THETA_DENSE_C_LEN, THETA_DENSE_C_OFF, THETA_DENSE_ROT_C_LEN, THETA_DENSE_ROT_C_OFF,
+    THETA_EXPAND_ROT_C_LEN, THETA_EXPAND_ROT_C_OFF, THETA_QUOTIENT_C_LEN, THETA_QUOTIENT_C_OFF,
+    THETA_REMAINDER_C_LEN, THETA_REMAINDER_C_OFF, THETA_SHIFTS_C_LEN, THETA_SHIFTS_C_OFF,
+};
+use kimchi_msm::{
+    columns::{Column, ColumnIndexer},
+    witness::Witness,
+};
+use std::ops::{Index, IndexMut};
+use strum::{EnumCount, IntoEnumIterator};
+use strum_macros::{EnumCount, EnumIter};
+
+/// The maximum total number of witness columns used by the Keccak circuit.
+/// Note that in round steps, the columns used to store padding information are not needed.
+pub const N_ZKVM_KECCAK_REL_COLS: usize = STATUS_LEN + CURR_LEN + NEXT_LEN + RC_LEN;
+
+/// The number of columns required for the Keccak selectors. They are located after the relation columns.
+pub const N_ZKVM_KECCAK_SEL_COLS: usize = 6;
+
+/// Total number of columns used in Keccak, including relation and selectors
+pub const N_ZKVM_KECCAK_COLS: usize = N_ZKVM_KECCAK_REL_COLS + N_ZKVM_KECCAK_SEL_COLS;
+
+const STATUS_OFF: usize = 0; // The offset of the columns reserved for the status indices
+const STATUS_LEN: usize = 3; // The number of columns used by the Keccak circuit to represent the status flags.
+
+const CURR_OFF: usize = STATUS_OFF + STATUS_LEN; // The offset of the curr chunk inside the witness columns
+const CURR_LEN: usize = ZKVM_KECCAK_COLS_CURR; // The length of the curr chunk inside the witness columns
+const NEXT_OFF: usize = CURR_OFF + CURR_LEN; // The offset of the next chunk inside the witness columns
+const NEXT_LEN: usize = ZKVM_KECCAK_COLS_NEXT; // The length of the next chunk inside the witness columns
+
+/// The number of sparse round constants used per round
+pub(crate) const ROUND_CONST_LEN: usize = QUARTERS;
+const RC_OFF: usize = NEXT_OFF + NEXT_LEN; // The offset of the round constants inside the witness columns
+const RC_LEN: usize = ROUND_CONST_LEN + 1; // The round constants plus the round number
+
+const PAD_FLAGS_OFF: usize = STATUS_LEN + SPONGE_COLS; // Offset of the Pad flags inside the witness columns. Starts after the sponge columns are finished.
+const PAD_LEN_OFF: usize = 0; // Offset of the PadLength column inside the sponge coefficients
+const PAD_TWO_OFF: usize = 1; // Offset of the TwoToPad column inside the sponge coefficients
+const PAD_SUFFIX_OFF: usize = 2; // Offset of the PadSuffix column inside the sponge coefficients
+/// The padding suffix of 1088 bits is stored as 5 field elements: 1x12 + 4x31 bytes
+pub(crate) const PAD_SUFFIX_LEN: usize = 5;
+const PAD_BYTES_OFF: usize = PAD_SUFFIX_OFF + PAD_SUFFIX_LEN; // Offset of the PadBytesFlags inside the sponge coefficients
+/// The maximum number of padding bytes involved
+pub(crate) const PAD_BYTES_LEN: usize = RATE_IN_BYTES;
+
+const FLAG_ROUND_OFF: usize = N_ZKVM_KECCAK_REL_COLS; // Offset of the Round selector inside DynamicSelector
+const FLAG_FST_OFF: usize = FLAG_ROUND_OFF + 1; // Offset of the Absorb(First) selector inside DynamicSelector
+const FLAG_MID_OFF: usize = FLAG_ROUND_OFF + 2; // Offset of the Absorb(Middle) selector inside DynamicSelector
+const FLAG_LST_OFF: usize = FLAG_ROUND_OFF + 3; // Offset of the Absorb(Last) selector inside DynamicSelector
+const FLAG_ONE_OFF: usize = FLAG_ROUND_OFF + 4; // Offset of the Absorb(Only) selector inside DynamicSelector
+const FLAG_SQUEEZE_OFF: usize = FLAG_ROUND_OFF + 5; // Offset of the Squeeze selector inside DynamicSelector
+
+/// Column aliases used by the Keccak circuit.
+/// The number of aliases is not necessarily equal to the actual number of
+/// columns.
+/// Each alias will be mapped to a column index depending on the step kind
+/// (Sponge or Round) that is currently being executed.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub enum ColumnAlias {
+    /// Hash identifier to distinguish inside the syscalls communication channel
+    HashIndex,
+    /// Block index inside the hash to enumerate preimage bytes
+    BlockIndex,
+    /// Hash step identifier to distinguish inside interstep communication
+    StepIndex,
+
+    Input(usize),  // Curr[0..100) either ThetaStateA or SpongeOldState
+    Output(usize), // Next[0..100) either IotaStateG or SpongeXorState
+
+    ThetaShiftsC(usize),    // Round Curr[100..180)
+    ThetaDenseC(usize),     // Round Curr[180..200)
+    ThetaQuotientC(usize),  // Round Curr[200..205)
+    ThetaRemainderC(usize), // Round Curr[205..225)
+    ThetaDenseRotC(usize),  // Round Curr[225..245)
+    ThetaExpandRotC(usize), // Round Curr[245..265)
+    PiRhoShiftsE(usize),    // Round Curr[265..665)
+    PiRhoDenseE(usize),     // Round Curr[665..765)
+    PiRhoQuotientE(usize),  // Round Curr[765..865)
+    PiRhoRemainderE(usize), // Round Curr[865..965)
+    PiRhoDenseRotE(usize),  // Round Curr[965..1065)
+    PiRhoExpandRotE(usize), // Round Curr[1065..1165)
+    ChiShiftsB(usize),      // Round Curr[1165..1565)
+    ChiShiftsSum(usize),    // Round Curr[1565..1965)
+
+    SpongeNewState(usize), // Sponge Curr[100..200)
+    SpongeZeros(usize),    // Sponge Curr[168..200)
+    SpongeBytes(usize),    // Sponge Curr[200..400)
+    SpongeShifts(usize),   // Sponge Curr[400..800)
+
+    RoundNumber,           // Only nonzero when Selector(Flag::Round) = 1 : Round 0 | 1 ..=23
+    RoundConstants(usize), // Only nonzero when Selector(Flag::Round) = 1 : Round constants
+
+    PadLength,            // Only nonzero when Selector(Flag::Pad) = 1 : Length 0 | 1 ..=136
+    TwoToPad,             // Only nonzero when Selector(Flag::Pad) = 1 : 2^PadLength
+    PadSuffix(usize),     // Only nonzero when Selector(Flag::Pad) = 1 : 5 field elements
+    PadBytesFlags(usize), // Only nonzero when Selector(Flag::Pad) = 1 : 136 boolean values
+}
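To make the offset bookkeeping above concrete, here is what the alias-to-index conversion (defined further below in this file) yields for the first columns; illustrative assertions only:

```rust
// Illustrative: the three status columns occupy indices 0..3, and the first
// input column starts right after them (CURR_OFF == STATUS_OFF + STATUS_LEN == 3).
assert_eq!(usize::from(ColumnAlias::HashIndex), 0);
assert_eq!(usize::from(ColumnAlias::StepIndex), 2);
assert_eq!(usize::from(ColumnAlias::Input(0)), 3);
```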
+
+/// Variants of Keccak steps available for the interpreter.
+/// These selectors determine the specific behaviour so that Keccak steps
+/// can be split into different instances for folding
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, EnumIter, EnumCount)]
+pub enum Steps {
+    /// Current step performs a round of the permutation.
+    /// The round number stored in the Step is only used for the environment execution.
+    Round(u64),
+    /// Current step is a sponge
+    Sponge(Sponges),
+}
+
+/// Variants of Keccak sponges
+#[derive(
+    Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, EnumIter, EnumCount, Default,
+)]
+pub enum Sponges {
+    Absorb(Absorbs),
+    #[default]
+    Squeeze,
+}
+
+/// Order of absorb steps in the computation depending on the number of blocks to absorb
+#[derive(
+    Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, EnumIter, EnumCount, Default,
+)]
+pub enum Absorbs {
+    First,  // Also known as the root absorb
+    Middle, // Any other absorb
+    Last,   // Also known as the padding absorb
+    #[default]
+    Only, // Only one block to absorb (preimage data is less than 136 bytes), both root and pad
+}
+
+impl IntoIterator for Steps {
+    type Item = Steps;
+    type IntoIter = std::vec::IntoIter<Self::Item>;
+
+    /// Iterate over the instruction variants
+    fn into_iter(self) -> Self::IntoIter {
+        match self {
+            Steps::Round(_) => vec![Steps::Round(0)].into_iter(),
+            Steps::Sponge(_) => {
+                let mut iter_contents = Vec::with_capacity(Absorbs::COUNT + 1);
+                iter_contents
+                    .extend(Absorbs::iter().map(|absorb| Steps::Sponge(Sponges::Absorb(absorb))));
+                iter_contents.push(Steps::Sponge(Sponges::Squeeze));
+                iter_contents.into_iter()
+            }
+        }
+    }
+}
+
+impl From<Steps> for usize {
+    /// Returns the index of the column corresponding to the given selector.
+    /// They are located at the end of the witness columns.
+    fn from(step: Steps) -> usize {
+        match step {
+            Round(_) => FLAG_ROUND_OFF,
+            Sponge(sponge) => match sponge {
+                Absorb(absorb) => match absorb {
+                    First => FLAG_FST_OFF,
+                    Middle => FLAG_MID_OFF,
+                    Last => FLAG_LST_OFF,
+                    Only => FLAG_ONE_OFF,
+                },
+                Squeeze => FLAG_SQUEEZE_OFF,
+            },
+        }
+    }
+}
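Correspondingly, each step kind selects exactly one dynamic-selector column, laid out right after the relation columns; illustrative assertions relying on the conversion just above:

```rust
// Illustrative: Round is the first selector column; Squeeze is the last one.
assert_eq!(usize::from(Steps::Round(0)), N_ZKVM_KECCAK_REL_COLS);
assert_eq!(
    usize::from(Steps::Sponge(Sponges::Squeeze)),
    N_ZKVM_KECCAK_REL_COLS + N_ZKVM_KECCAK_SEL_COLS - 1
);
```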
+
+// The columns used by the Keccak circuit.
+// The Keccak circuit is split into two main modes: Round and Sponge (split into Root, Absorb, Pad, RootPad, Squeeze).
+// The columns are shared between the Sponge and Round steps
+// (the total number of columns refers to the maximum of columns used by each mode).
+// The hash, block, and step indices are shared between both modes.
+impl From<ColumnAlias> for usize {
+    /// Returns the witness column index for the given alias
+    fn from(alias: ColumnAlias) -> usize {
+        match alias {
+            ColumnAlias::HashIndex => STATUS_OFF,
+            ColumnAlias::BlockIndex => STATUS_OFF + 1,
+            ColumnAlias::StepIndex => STATUS_OFF + 2,
+
+            ColumnAlias::Input(i) => {
+                assert!(i < STATE_LEN);
+                CURR_OFF + i
+            }
+            ColumnAlias::Output(i) => {
+                assert!(i < STATE_LEN);
+                NEXT_OFF + i
+            }
+
+            ColumnAlias::ThetaShiftsC(i) => {
+                assert!(i < THETA_SHIFTS_C_LEN);
+                CURR_OFF + THETA_SHIFTS_C_OFF + i
+            }
+            ColumnAlias::ThetaDenseC(i) => {
+                assert!(i < THETA_DENSE_C_LEN);
+                CURR_OFF + THETA_DENSE_C_OFF + i
+            }
+            ColumnAlias::ThetaQuotientC(i) => {
+                assert!(i < THETA_QUOTIENT_C_LEN);
+                CURR_OFF + THETA_QUOTIENT_C_OFF + i
+            }
+            ColumnAlias::ThetaRemainderC(i) => {
+                assert!(i < THETA_REMAINDER_C_LEN);
+                CURR_OFF + THETA_REMAINDER_C_OFF + i
+            }
+            ColumnAlias::ThetaDenseRotC(i) => {
+                assert!(i < THETA_DENSE_ROT_C_LEN);
+                CURR_OFF + THETA_DENSE_ROT_C_OFF + i
+            }
+            ColumnAlias::ThetaExpandRotC(i) => {
+                assert!(i < THETA_EXPAND_ROT_C_LEN);
+                CURR_OFF + THETA_EXPAND_ROT_C_OFF + i
+            }
+            ColumnAlias::PiRhoShiftsE(i) => {
+                assert!(i < PIRHO_SHIFTS_E_LEN);
+                CURR_OFF + PIRHO_SHIFTS_E_OFF + i
+            }
+            ColumnAlias::PiRhoDenseE(i) => {
+                assert!(i < PIRHO_DENSE_E_LEN);
+                CURR_OFF + PIRHO_DENSE_E_OFF + i
+            }
+            ColumnAlias::PiRhoQuotientE(i) => {
+                assert!(i < PIRHO_QUOTIENT_E_LEN);
+                CURR_OFF + PIRHO_QUOTIENT_E_OFF + i
+            }
+            ColumnAlias::PiRhoRemainderE(i) => {
+                assert!(i < PIRHO_REMAINDER_E_LEN);
+                CURR_OFF + PIRHO_REMAINDER_E_OFF + i
+            }
+            ColumnAlias::PiRhoDenseRotE(i) => {
+                assert!(i < PIRHO_DENSE_ROT_E_LEN);
+                CURR_OFF + PIRHO_DENSE_ROT_E_OFF + i
+            }
+            ColumnAlias::PiRhoExpandRotE(i) => {
+                assert!(i < PIRHO_EXPAND_ROT_E_LEN);
+                CURR_OFF + PIRHO_EXPAND_ROT_E_OFF + i
+            }
+            ColumnAlias::ChiShiftsB(i) => {
+                assert!(i < CHI_SHIFTS_B_LEN);
+                CURR_OFF + CHI_SHIFTS_B_OFF + i
+            }
+            ColumnAlias::ChiShiftsSum(i) => {
+                assert!(i < CHI_SHIFTS_SUM_LEN);
+                CURR_OFF + CHI_SHIFTS_SUM_OFF + i
+            }
+
+            ColumnAlias::SpongeNewState(i) => {
+                assert!(i < SPONGE_NEW_STATE_LEN);
+                CURR_OFF + SPONGE_NEW_STATE_OFF + i
+            }
+            ColumnAlias::SpongeZeros(i) => {
+                assert!(i < SPONGE_ZEROS_LEN);
+                CURR_OFF + SPONGE_ZEROS_OFF + i
+            }
+            ColumnAlias::SpongeBytes(i) => {
+                assert!(i < SPONGE_BYTES_LEN);
+                CURR_OFF + SPONGE_BYTES_OFF + i
+            }
+            ColumnAlias::SpongeShifts(i) => {
+                assert!(i < SPONGE_SHIFTS_LEN);
+                CURR_OFF + SPONGE_SHIFTS_OFF + i
+            }
+
+            ColumnAlias::RoundNumber => RC_OFF,
+            ColumnAlias::RoundConstants(i) => {
+                assert!(i < ROUND_CONST_LEN);
+                RC_OFF + 1 + i
+            }
+
+            ColumnAlias::PadLength => PAD_FLAGS_OFF + PAD_LEN_OFF,
+            ColumnAlias::TwoToPad => PAD_FLAGS_OFF + PAD_TWO_OFF,
+            ColumnAlias::PadSuffix(i) => {
+                assert!(i < PAD_SUFFIX_LEN);
+                PAD_FLAGS_OFF + PAD_SUFFIX_OFF + i
+            }
+            ColumnAlias::PadBytesFlags(i) => {
+                assert!(i < PAD_BYTES_LEN);
+                PAD_FLAGS_OFF + PAD_BYTES_OFF + i
+            }
+        }
+    }
+}
+
+/// The witness columns used by the Keccak circuit.
+/// The Keccak circuit is split into two main modes: Sponge and Round.
+/// The columns are shared between the Sponge and Round steps.
+/// The hash and step indices are shared between both modes.
+/// The row is split into the following entries:
+/// - hash_index: Which hash this is inside the circuit
+/// - block_index: Which block this is inside the hash
+/// - step_index: Which step this is inside the hash
+/// - curr: Contains 1965 witnesses used in the current step including Input
+/// - next: Contains the 100 Output witnesses
+/// - round_flags: contains 5 elements with information about the current round step
+/// - pad_flags: PadLength, TwoToPad, PadBytesFlags, PadSuffix
+/// - mode_flags: what kind of mode is running: round, root, absorb, pad, rootpad, squeeze. Only 1 of them can be active.
+///
+/// Keccak Witness Columns: KeccakWitness.cols
+/// -------------------------------------------------------
+/// | 0 | 1 | 2 | 3..=1967 | 1968..=2067 | 2068..=2072 |
+/// -------------------------------------------------------
+/// 0 -> hash_index
+/// 1 -> block_index
+/// 2 -> step_index
+/// 3..=1967 -> curr
+/// 3                                                                      1967
+/// <--------------------------------if_round<-------------------------------->
+/// <-------------if_sponge-------------->
+/// 3                                  802
+/// -> SPONGE:                            | -> ROUND:
+/// -> 3..=102: Input == SpongeOldState   | -> 3..=102: Input == ThetaStateA
+/// -> 103..=202: SpongeNewState          | -> 103..=182: ThetaShiftsC
+///    : 170..=202 -> SpongeZeros         | -> 183..=202: ThetaDenseC
+/// -> 203..=402: SpongeBytes             | -> 203..=207: ThetaQuotientC
+/// -> 403..=802: SpongeShifts            | -> 208..=227: ThetaRemainderC
+///                                       | -> 228..=247: ThetaDenseRotC
+///                                       | -> 248..=267: ThetaExpandRotC
+///                                       | -> 268..=667: PiRhoShiftsE
+///                                       | -> 668..=767: PiRhoDenseE
+///                                       | -> 768..=867: PiRhoQuotientE
+///                                       | -> 868..=967: PiRhoRemainderE
+///                                       | -> 968..=1067: PiRhoDenseRotE
+///                                       | -> 1068..=1167: PiRhoExpandRotE
+///                                       | -> 1168..=1567: ChiShiftsB
+///                                       | -> 1568..=1967: ChiShiftsSum
+/// 1968..=2067 -> next
+/// -> 1968..=2067: Output (if Round, then IotaStateG, if Sponge then SpongeXorState)
+///
+/// 2068..=2072 -> round_flags
+/// -> 2068: RoundNumber
+/// -> 2069..=2072: RoundConstants
+///
+/// 2073..=2078 -> selectors
+///
+/// 803..=945 -> pad_flags
+/// -> 803: PadLength
+/// -> 804: TwoToPad
+/// -> 805..=809: PadSuffix
+/// -> 810..=945: PadBytesFlags
+pub type KeccakWitness<T> = Witness<N_ZKVM_KECCAK_COLS, T>;
+
+/// IMPLEMENTATIONS FOR COLUMN ALIAS
+
+impl<T: Clone> Index<ColumnAlias> for KeccakWitness<T> {
+    type Output = T;
+
+    /// Map the column alias to the actual column index.
+    /// Note that the column index depends on the step kind (Sponge or Round).
+    /// For instance, column 800 represents PadLength in the Sponge step, while it
+    /// is used by intermediary values when executing the Round step.
+    fn index(&self, index: ColumnAlias) -> &Self::Output {
+        &self.cols[usize::from(index)]
+    }
+}
+
+impl<T: Clone> IndexMut<ColumnAlias> for KeccakWitness<T> {
+    fn index_mut(&mut self, index: ColumnAlias) -> &mut Self::Output {
+        &mut self.cols[usize::from(index)]
+    }
+}
+
+impl ColumnIndexer for ColumnAlias {
+    const N_COL: usize = N_ZKVM_KECCAK_REL_COLS + N_ZKVM_KECCAK_SEL_COLS;
+    fn to_column(self) -> Column {
+        Column::Relation(usize::from(self))
+    }
+}
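As a usage sketch (hypothetical helper, not part of this diff), the `Index` implementation above lets witness code address columns by alias instead of raw offsets:

```rust
// Hypothetical helper: read the three status columns of a Keccak witness row
// through the alias-based indexing defined above.
fn status_columns<T: Clone>(row: &KeccakWitness<T>) -> (T, T, T) {
    (
        row[ColumnAlias::HashIndex].clone(),
        row[ColumnAlias::BlockIndex].clone(),
        row[ColumnAlias::StepIndex].clone(),
    )
}
```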
+
+// IMPLEMENTATIONS FOR SELECTOR
+
+impl<T: Clone> Index<Steps> for KeccakWitness<T> {
+    type Output = T;
+
+    /// Map the step selector to the actual column index.
+    /// The selector columns are located at the end of the witness relation columns.
+    fn index(&self, index: Steps) -> &Self::Output {
+        &self.cols[usize::from(index)]
+    }
+}
+
+impl<T: Clone> IndexMut<Steps> for KeccakWitness<T> {
+    fn index_mut(&mut self, index: Steps) -> &mut Self::Output {
+        &mut self.cols[usize::from(index)]
+    }
+}
+
+impl ColumnIndexer for Steps {
+    const N_COL: usize = N_ZKVM_KECCAK_REL_COLS + N_ZKVM_KECCAK_SEL_COLS;
+    fn to_column(self) -> Column {
+        Column::DynamicSelector(usize::from(self) - N_ZKVM_KECCAK_REL_COLS)
+    }
+}
diff --git a/o1vm/src/interpreters/keccak/constraints.rs b/o1vm/src/interpreters/keccak/constraints.rs
new file mode 100644
index 0000000000..092ed9c6be
--- /dev/null
+++ b/o1vm/src/interpreters/keccak/constraints.rs
@@ -0,0 +1,82 @@
+//! This module contains the constraints for one Keccak step.
+use crate::{
+    interpreters::keccak::{
+        helpers::{ArithHelpers, BoolHelpers, LogupHelpers},
+        interpreter::{Interpreter, KeccakInterpreter},
+        Constraint, KeccakColumn,
+    },
+    lookups::Lookup,
+    E,
+};
+use ark_ff::{Field, One};
+use kimchi::{
+    circuits::{
+        expr::{ConstantTerm::Literal, Expr, ExprInner, Operations, Variable},
+        gate::CurrOrNext,
+    },
+    o1_utils::Two,
+};
+use kimchi_msm::columns::ColumnIndexer;
+
+/// This struct contains all that needs to be kept track of during the execution of the Keccak step interpreter
+#[derive(Clone, Debug)]
+pub struct Env<F> {
+    /// Constraints that are added to the circuit
+    pub constraints: Vec<E<F>>,
+    /// Variables that are looked up in the circuit
+    pub lookups: Vec<Lookup<E<F>>>,
+}
+
+impl<F: Field> Default for Env<F> {
+    fn default() -> Self {
+        Self {
+            constraints: Vec::new(),
+            lookups: Vec::new(),
+        }
+    }
+}
+
+impl<F: Field> ArithHelpers<F> for Env<F> {
+    fn two_pow(x: u64) -> Self::Variable {
+        Self::constant_field(F::two_pow(x))
+    }
+}
+
+impl<F: Field> BoolHelpers<F> for Env<F> {}
+
+impl<F: Field> LogupHelpers<F> for Env<F> {}
+
+impl<F: Field> Interpreter<F> for Env<F> {
+    type Variable = E<F>;
+
+    fn constant(x: u64) -> Self::Variable {
+        Self::constant_field(F::from(x))
+    }
+
+    fn constant_field(x: F) -> Self::Variable {
+        Self::Variable::constant(Operations::from(Literal(x)))
+    }
+
+    fn variable(&self, column: KeccakColumn) -> Self::Variable {
+        // Despite `KeccakWitness` containing both `curr` and `next` fields,
+        // the Keccak step spans across one row only.
+        Expr::Atom(ExprInner::Cell(Variable {
+            col: column.to_column(),
+            row: CurrOrNext::Curr,
+        }))
+    }
+
+    fn constrain(&mut self, _tag: Constraint, if_true: Self::Variable, x: Self::Variable) {
+        if if_true == Self::Variable::one() {
+            self.constraints.push(x);
+        }
+    }
+
+    fn add_lookup(&mut self, if_true: Self::Variable, lookup: Lookup<Self::Variable>) {
+        if if_true == Self::Variable::one() {
+            self.lookups.push(lookup);
+        }
+    }
+}
+
+impl<F: Field> KeccakInterpreter<F> for Env<F> {}
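One design note worth spelling out: `constrain` and `add_lookup` record a constraint or lookup only when the `if_true` gate is the literal one, which lets the interpreter share a single code path across all step kinds. A hedged sketch of instantiating the environment (the field is the caller's choice; `ark_bn254::Fr` is used here purely for illustration):

```rust
// Sketch: a fresh constraint environment records nothing until the
// interpreter walks a step whose gate evaluates to the literal one.
fn fresh_env() -> Env<ark_bn254::Fr> {
    let env = Env::<ark_bn254::Fr>::default();
    assert!(env.constraints.is_empty());
    assert!(env.lookups.is_empty());
    env
}
```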
+use crate::{
+    interpreters::keccak::{
+        column::{
+            Absorbs::{self, *},
+            KeccakWitness,
+            Sponges::{self, *},
+            Steps,
+            Steps::*,
+            PAD_SUFFIX_LEN,
+        },
+        constraints::Env as ConstraintsEnv,
+        grid_index,
+        interpreter::KeccakInterpreter,
+        pad_blocks, standardize,
+        witness::Env as WitnessEnv,
+        KeccakColumn, DIM, HASH_BYTELENGTH, QUARTERS, WORDS_IN_HASH,
+    },
+    lookups::Lookup,
+    E,
+};
+
+use ark_ff::Field;
+use kimchi::{
+    circuits::polynomials::keccak::{
+        constants::*,
+        witness::{Chi, Iota, PiRho, Theta},
+        Keccak,
+    },
+    o1_utils::Two,
+};
+use std::array;
+
+/// This struct contains everything that needs to be kept track of during the
+/// execution of the Keccak step interpreter
+#[derive(Clone, Debug)]
+pub struct KeccakEnv<F> {
+    /// Environment for the constraints (includes lookups)
+    pub constraints_env: ConstraintsEnv<F>,
+    /// Environment for the witness (includes multiplicities)
+    pub witness_env: WitnessEnv<F>,
+    /// Current step of the hash being executed; can be None if the hash just ended
+    pub step: Option<Steps>,
+
+    /// Hash index in the circuit
+    pub(crate) hash_idx: u64,
+    /// Step counter of the total number of steps executed so far in the current hash (starts with 0)
+    pub(crate) step_idx: u64,
+    /// Current block of preimage data
+    pub(crate) block_idx: u64,
+
+    /// Expanded block of previous step
+    pub(crate) prev_block: Vec<u64>,
+    /// How many blocks are left to absorb (including current absorb)
+    pub(crate) blocks_left_to_absorb: u64,
+
+    /// Padded preimage data
+    pub(crate) padded: Vec<u8>,
+    /// Byte-length of the 10*1 pad (<=136)
+    pub(crate) pad_len: u64,
+
+    /// Precomputed 2^pad_len
+    two_to_pad: [F; RATE_IN_BYTES],
+    /// Precomputed suffixes for the padding blocks
+    pad_suffixes: [[F; PAD_SUFFIX_LEN]; RATE_IN_BYTES],
+}
+
+impl<F: Field> Default for KeccakEnv<F> {
+    fn default() -> Self {
+        Self {
+            constraints_env: ConstraintsEnv::default(),
+            witness_env: WitnessEnv::default(),
+            step: None,
+            hash_idx: 0,
+            step_idx: 0,
+            block_idx: 0,
+            prev_block: vec![],
+            blocks_left_to_absorb: 0,
+            padded: vec![],
+            pad_len: 0,
+            two_to_pad: array::from_fn(|i| F::two_pow(1 + i as u64)),
+            pad_suffixes: array::from_fn(|i| pad_blocks::<F>(1 + i)),
+        }
+    }
+}
+
+impl<F: Field> KeccakEnv<F> {
+    /// Starts a new Keccak environment for a given hash index and bytestring of preimage data
+    pub fn new(hash_idx: u64, preimage: &[u8]) -> Self {
+        // The flags must be updated at each step from the witness interpretation
+        let mut env = KeccakEnv::<F> {
+            hash_idx,
+            ..Default::default()
+        };
+
+        // Store hash index in the witness
+        env.write_column(KeccakColumn::HashIndex, env.hash_idx);
+
+        // Update the number of blocks left to be absorbed depending on the length of the preimage
+        env.blocks_left_to_absorb = Keccak::num_blocks(preimage.len()) as u64;
+
+        // Configure first step depending on number of blocks remaining, updating the selector for the row
+        env.step = if env.blocks_left_to_absorb == 1 {
+            Some(Sponge(Absorb(Only)))
+        } else {
+            Some(Sponge(Absorb(First)))
+        };
+        env.step_idx = 0;
+
+        // Root state (all zeros) shall be used for the first step
+        env.prev_block = vec![0u64; STATE_LEN];
+
+        // Pad preimage with the 10*1 padding rule
+        env.padded = Keccak::pad(preimage);
+        env.block_idx = 0;
+        env.pad_len = (env.padded.len() - preimage.len()) as u64;
+
+        env
+    }
+
+    /// Writes an integer value to a column of the Keccak witness
+    pub fn write_column(&mut self, column: KeccakColumn, value: u64) {
+        self.write_column_field(column, F::from(value));
+    }
+
+    /// Writes a field value to a column of the Keccak witness
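+    /// For example (illustrative), `env.write_column_field(KeccakColumn::TwoToPad, F::two_pow(2))`
+    /// would store the field element 4 in the `TwoToPad` column of the current row.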
+    pub fn write_column_field(&mut self, column: KeccakColumn, value: F) {
+        self.witness_env.witness[column] = value;
+    }
+
+    /// Nullifies the Witness and Constraint environments by resetting them to default values (except for table-related state).
+    /// This way, each row only adds the constraints of that step (the step itself is not nullified).
+    pub fn null_state(&mut self) {
+        self.witness_env.witness = KeccakWitness::default();
+        self.witness_env.errors = vec![];
+        // The multiplicities are not reset.
+        // The fixed tables are not modified.
+        self.constraints_env.constraints = vec![];
+        self.constraints_env.lookups = vec![];
+        // Step is not reset between iterations
+    }
+
+    /// Returns the selector of the current step in standardized form
+    pub fn selector(&self) -> Steps {
+        standardize(self.step.unwrap())
+    }
+
+    /// Entrypoint for the interpreter. It executes one step of the Keccak circuit (one row),
+    /// and updates the environment accordingly (including the witness and inter-step lookups).
+    /// When it finishes, it updates the value of the current step, so that the next call to
+    /// the `step()` function executes the next step.
+    pub fn step(&mut self) {
+        // Reset columns to zeros to avoid conflicts between steps
+        self.null_state();
+
+        match self.step.unwrap() {
+            Sponge(typ) => self.run_sponge(typ),
+            Round(i) => self.run_round(i),
+        }
+        self.write_column(KeccakColumn::StepIndex, self.step_idx);
+
+        self.update_step();
+    }
+
+    /// This function updates the next step of the environment depending on the current step
+    pub fn update_step(&mut self) {
+        match self.step {
+            Some(step) => match step {
+                Sponge(sponge) => match sponge {
+                    Absorb(_) => self.step = Some(Round(0)),
+                    Squeeze => self.step = None,
+                },
+                Round(round) => {
+                    if round < ROUNDS as u64 - 1 {
+                        self.step = Some(Round(round + 1));
+                    } else {
+                        self.blocks_left_to_absorb -= 1;
+                        match self.blocks_left_to_absorb {
+                            0 => self.step = Some(Sponge(Squeeze)),
+                            1 => self.step = Some(Sponge(Absorb(Last))),
+                            _ => self.step = Some(Sponge(Absorb(Middle))),
+                        }
+                    }
+                }
+            },
+            None => panic!("No step to update"),
+        }
+        self.step_idx += 1;
+    }
+
+    /// Updates the witness corresponding to the round value in [0..24)
+    fn set_flag_round(&mut self, round: u64) {
+        assert!(round < ROUNDS as u64);
+        self.write_column(KeccakColumn::RoundNumber, round);
+    }
+
+    /// Updates the padding flags and any other sponge flag depending on the kind of absorb step (root, padding, both).
+    fn set_flag_absorb(&mut self, absorb: Absorbs) {
+        match absorb {
+            Last | Only => {
+                // Step flag has been updated already
+                self.set_flags_pad();
+            }
+            First | Middle => (), // Step flag has been updated already
+        }
+    }
+    /// Sets the flag columns related to padding flags such as `PadLength`, `TwoToPad`, `PadBytesFlags`, and `PadSuffix`.
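+    /// For instance (illustrative, matching `test_pad_blocks`): for a pad of length 1,
+    /// this sets `PadLength = 1`, `TwoToPad = 2`, the last entry of `PadBytesFlags` to 1,
+    /// and the `PadSuffix` chunks to `[0, 0, 0, 0, 0x81]`.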
+ fn set_flags_pad(&mut self) { + // Initialize padding columns with precomputed values to speed up interpreter + self.write_column(KeccakColumn::PadLength, self.pad_len); + self.write_column_field( + KeccakColumn::TwoToPad, + self.two_to_pad[self.pad_len as usize - 1], + ); + let pad_range = RATE_IN_BYTES - self.pad_len as usize..RATE_IN_BYTES; + for i in pad_range { + self.write_column(KeccakColumn::PadBytesFlags(i), 1); + } + let pad_suffix = self.pad_suffixes[self.pad_len as usize - 1]; + for (idx, value) in pad_suffix.iter().enumerate() { + self.write_column_field(KeccakColumn::PadSuffix(idx), *value); + } + } + + /// Assigns the witness values needed in a sponge step (absorb or squeeze) + fn run_sponge(&mut self, sponge: Sponges) { + // Keep track of the round number for ease of debugging + match sponge { + Absorb(absorb) => self.run_absorb(absorb), + Squeeze => self.run_squeeze(), + } + } + /// Assigns the witness values needed in an absorb step (root, padding, or middle) + fn run_absorb(&mut self, absorb: Absorbs) { + self.set_flag_absorb(absorb); + + // Compute witness values + let ini_idx = RATE_IN_BYTES * self.block_idx as usize; + let mut block = self.padded[ini_idx..ini_idx + RATE_IN_BYTES].to_vec(); + self.write_column(KeccakColumn::BlockIndex, self.block_idx); + + // Pad with zeros + block.append(&mut vec![0; CAPACITY_IN_BYTES]); + + // Round + Mode of Operation (Sponge) + // state -> permutation(state) -> state' + // ----> either [0] or state' + // | new state = Exp(block) + // | ------------------------ + // Absorb: state + [ block + 0...0 ] + // 1088 bits 512 + // ---------------------------------- + // XOR STATE + let old_state = self.prev_block.clone(); + let new_state = Keccak::expand_state(&block); + let xor_state = old_state + .iter() + .zip(new_state.clone()) + .map(|(x, y)| x + y) + .collect::>(); + + let shifts = Keccak::shift(&new_state); + let bytes = block.iter().map(|b| *b as u64).collect::>(); + + // Write absorb-related columns + for idx in 0..STATE_LEN { + self.write_column(KeccakColumn::Input(idx), old_state[idx]); + self.write_column(KeccakColumn::SpongeNewState(idx), new_state[idx]); + self.write_column(KeccakColumn::Output(idx), xor_state[idx]); + } + for (idx, value) in bytes.iter().enumerate() { + self.write_column(KeccakColumn::SpongeBytes(idx), *value); + } + for (idx, value) in shifts.iter().enumerate() { + self.write_column(KeccakColumn::SpongeShifts(idx), *value); + } + // Rest is zero thanks to null_state + + // Update environment + self.prev_block = xor_state; + self.block_idx += 1; // To be used in next absorb (if any) + } + /// Assigns the witness values needed in a squeeze step + fn run_squeeze(&mut self) { + // Squeeze step is already updated + + // Compute witness values + let state = self.prev_block.clone(); + let shifts = Keccak::shift(&state); + let dense = Keccak::collapse(&Keccak::reset(&shifts)); + let bytes = Keccak::bytestring(&dense); + + // Write squeeze-related columns + for (idx, value) in state.iter().enumerate() { + self.write_column(KeccakColumn::Input(idx), *value); + } + for (idx, value) in bytes.iter().enumerate().take(HASH_BYTELENGTH) { + self.write_column(KeccakColumn::SpongeBytes(idx), *value); + } + for idx in 0..WORDS_IN_HASH * QUARTERS { + self.write_column(KeccakColumn::SpongeShifts(idx), shifts[idx]); + self.write_column(KeccakColumn::SpongeShifts(100 + idx), shifts[100 + idx]); + self.write_column(KeccakColumn::SpongeShifts(200 + idx), shifts[200 + idx]); + self.write_column(KeccakColumn::SpongeShifts(300 + idx), 
shifts[300 + idx]); + } + + // Rest is zero thanks to null_state + } + /// Assigns the witness values needed in the round step for the given round index + fn run_round(&mut self, round: u64) { + self.set_flag_round(round); + + let state_a = self.prev_block.clone(); + let state_e = self.run_theta(&state_a); + let state_b = self.run_pirho(&state_e); + let state_f = self.run_chi(&state_b); + let state_g = self.run_iota(&state_f, round as usize); + + // Update block for next step with the output of the round + self.prev_block = state_g; + } + /// Assigns the witness values needed in the theta algorithm + /// ```text + /// for x in 0…4 + /// C[x] = A[x,0] xor A[x,1] xor \ + /// A[x,2] xor A[x,3] xor \ + /// A[x,4] + /// for x in 0…4 + /// D[x] = C[x-1] xor rot(C[x+1],1) + /// for (x,y) in (0…4,0…4) + /// A[x,y] = A[x,y] xor D[x] + /// ``` + fn run_theta(&mut self, state_a: &[u64]) -> Vec { + let theta = Theta::create(state_a); + + // Write Theta-related columns + for x in 0..DIM { + self.write_column(KeccakColumn::ThetaQuotientC(x), theta.quotient_c(x)); + for q in 0..QUARTERS { + let idx = grid_index(QUARTERS * DIM, 0, 0, x, q); + self.write_column(KeccakColumn::ThetaDenseC(idx), theta.dense_c(x, q)); + self.write_column(KeccakColumn::ThetaRemainderC(idx), theta.remainder_c(x, q)); + self.write_column(KeccakColumn::ThetaDenseRotC(idx), theta.dense_rot_c(x, q)); + self.write_column(KeccakColumn::ThetaExpandRotC(idx), theta.expand_rot_c(x, q)); + for y in 0..DIM { + let idx = grid_index(THETA_STATE_A_LEN, 0, y, x, q); + self.write_column(KeccakColumn::Input(idx), state_a[idx]); + } + for i in 0..QUARTERS { + let idx = grid_index(THETA_SHIFTS_C_LEN, i, 0, x, q); + self.write_column(KeccakColumn::ThetaShiftsC(idx), theta.shifts_c(i, x, q)); + } + } + } + theta.state_e() + } + /// Assigns the witness values needed in the pirho algorithm + /// ```text + /// for (x,y) in (0…4,0…4) + /// B[y,2*x+3*y] = rot(A[x,y], r[x,y]) + /// ``` + fn run_pirho(&mut self, state_e: &[u64]) -> Vec { + let pirho = PiRho::create(state_e); + + // Write PiRho-related columns + for y in 0..DIM { + for x in 0..DIM { + for q in 0..QUARTERS { + let idx = grid_index(STATE_LEN, 0, y, x, q); + self.write_column(KeccakColumn::PiRhoDenseE(idx), pirho.dense_e(y, x, q)); + self.write_column(KeccakColumn::PiRhoQuotientE(idx), pirho.quotient_e(y, x, q)); + self.write_column( + KeccakColumn::PiRhoRemainderE(idx), + pirho.remainder_e(y, x, q), + ); + self.write_column( + KeccakColumn::PiRhoDenseRotE(idx), + pirho.dense_rot_e(y, x, q), + ); + self.write_column( + KeccakColumn::PiRhoExpandRotE(idx), + pirho.expand_rot_e(y, x, q), + ); + for i in 0..QUARTERS { + self.write_column( + KeccakColumn::PiRhoShiftsE(grid_index(PIRHO_SHIFTS_E_LEN, i, y, x, q)), + pirho.shifts_e(i, y, x, q), + ); + } + } + } + } + pirho.state_b() + } + /// Assigns the witness values needed in the chi algorithm + /// ```text + /// for (x,y) in (0…4,0…4) + /// A[x, y] = B[x,y] xor ((not B[x+1,y]) and B[x+2,y]) + /// ``` + fn run_chi(&mut self, state_b: &[u64]) -> Vec { + let chi = Chi::create(state_b); + + // Write Chi-related columns + for i in 0..SHIFTS { + for y in 0..DIM { + for x in 0..DIM { + for q in 0..QUARTERS { + let idx = grid_index(SHIFTS_LEN, i, y, x, q); + self.write_column(KeccakColumn::ChiShiftsB(idx), chi.shifts_b(i, y, x, q)); + self.write_column( + KeccakColumn::ChiShiftsSum(idx), + chi.shifts_sum(i, y, x, q), + ); + } + } + } + } + chi.state_f() + } + /// Assigns the witness values needed in the iota algorithm + /// ```text + /// A[0,0] = A[0,0] 
xor RC
+    /// ```
+    fn run_iota(&mut self, state_f: &[u64], round: usize) -> Vec<u64> {
+        let iota = Iota::create(state_f, round);
+        let state_g = iota.state_g();
+
+        // Update columns
+        for (idx, g) in state_g.iter().enumerate() {
+            self.write_column(KeccakColumn::Output(idx), *g);
+        }
+        for idx in 0..QUARTERS {
+            self.write_column(KeccakColumn::RoundConstants(idx), iota.round_constants(idx));
+        }
+
+        state_g
+    }
+
+    /// Returns the list of constraints used in a specific Keccak step
+    pub(crate) fn constraints_of(step: Steps) -> Vec<E<F>> {
+        let mut env = ConstraintsEnv {
+            constraints: vec![],
+            lookups: vec![],
+        };
+        env.constraints(step);
+        env.constraints
+    }
+
+    /// Returns the list of lookups used in a specific Keccak step
+    pub(crate) fn lookups_of(step: Steps) -> Vec<Lookup<E<F>>> {
+        let mut env = ConstraintsEnv {
+            constraints: vec![],
+            lookups: vec![],
+        };
+        env.lookups(step);
+        env.lookups
+    }
+}
diff --git a/o1vm/src/interpreters/keccak/helpers.rs b/o1vm/src/interpreters/keccak/helpers.rs
new file mode 100644
index 0000000000..0babb2927f
--- /dev/null
+++ b/o1vm/src/interpreters/keccak/helpers.rs
@@ -0,0 +1,140 @@
+use crate::{
+    interpreters::keccak::interpreter::Interpreter,
+    lookups::{Lookup, LookupTableIDs::*},
+};
+use ark_ff::{One, Zero};
+use std::fmt::Debug;
+
+/// This trait contains helper functions for the lookups used in the Keccak circuit
+/// using the zkVM lookup tables
+pub trait LogupHelpers<F: One + Debug + Zero>
+where
+    Self: Interpreter<F>,
+{
+    /// Adds a lookup to the RangeCheck16 table
+    fn lookup_rc16(&mut self, flag: Self::Variable, value: Self::Variable) {
+        self.add_lookup(flag, Lookup::read_one(RangeCheck16Lookup, vec![value]));
+    }
+
+    /// Adds a lookup to the Reset table
+    fn lookup_reset(
+        &mut self,
+        flag: Self::Variable,
+        dense: Self::Variable,
+        sparse: Self::Variable,
+    ) {
+        self.add_lookup(flag, Lookup::read_one(ResetLookup, vec![dense, sparse]));
+    }
+
+    /// Adds a lookup to the Sparse table
+    fn lookup_sparse(&mut self, flag: Self::Variable, value: Self::Variable) {
+        self.add_lookup(flag, Lookup::read_one(SparseLookup, vec![value]));
+    }
+
+    /// Adds a lookup to the Byte table
+    fn lookup_byte(&mut self, flag: Self::Variable, value: Self::Variable) {
+        self.add_lookup(flag, Lookup::read_one(ByteLookup, vec![value]));
+    }
+
+    /// Adds a lookup to the Pad table
+    fn lookup_pad(&mut self, flag: Self::Variable, value: Vec<Self::Variable>) {
+        self.add_lookup(flag, Lookup::read_one(PadLookup, value));
+    }
+
+    /// Adds a lookup to the RoundConstants table
+    fn lookup_round_constants(&mut self, flag: Self::Variable, value: Vec<Self::Variable>) {
+        self.add_lookup(flag, Lookup::read_one(RoundConstantsLookup, value));
+    }
+
+    fn read_syscall(&mut self, flag: Self::Variable, value: Vec<Self::Variable>) {
+        self.add_lookup(flag, Lookup::read_one(SyscallLookup, value));
+    }
+
+    fn write_syscall(&mut self, flag: Self::Variable, value: Vec<Self::Variable>) {
+        self.add_lookup(flag, Lookup::write_one(SyscallLookup, value));
+    }
+}
+
+/// This trait contains helper functions for boolean operations used in the Keccak circuit
+pub trait BoolHelpers<F: One + Debug + Zero>
+where
+    Self: Interpreter<F>,
+{
+    /// Degree-2 variable encoding whether the input is a boolean value (0 = yes)
+    fn is_boolean(x: Self::Variable) -> Self::Variable {
+        x.clone() * (x - Self::Variable::one())
+    }
+
+    /// Degree-1 variable encoding the negation of the input
+    /// Note: it only works as expected if the input is a boolean value
+    fn not(x: Self::Variable) -> Self::Variable {
+        Self::Variable::one() - x
+    }
+
+    /// Degree-1 variable encoding whether the input is the value one (0 = yes)
+    fn is_one(x: Self::Variable) -> Self::Variable {
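+        // e.g. is_one(1) = not(1) = 0, so the encoding is zero exactly when x is one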
+        Self::not(x)
+    }
+
+    /// Degree-2 variable encoding whether the first input is nonzero (0 = yes).
+    /// It requires the second input to be the multiplicative inverse of the first.
+    /// Note: if the first input is zero, there is no multiplicative inverse.
+    fn is_nonzero(x: Self::Variable, x_inv: Self::Variable) -> Self::Variable {
+        Self::is_one(x * x_inv)
+    }
+
+    // The following two degree-2 constraints define the is_zero function meaning (1 = yes):
+    // - if x = 0 then is_zero(x) = 1
+    // - if x ≠ 0 then is_zero(x) = 0
+    // It is obtained by combining these two equations:
+    //   x * xinv = 1 - z
+    //   x * z = 0
+    // When x = 0, xinv could be anything, like 1.
+    fn is_zero(
+        x: Self::Variable,
+        x_inv: Self::Variable,
+        z: Self::Variable,
+    ) -> (Self::Variable, Self::Variable) {
+        (
+            x.clone() * x_inv.clone() - z.clone() + Self::Variable::one(),
+            x * z,
+        )
+    }
+
+    /// Degree-2 variable encoding the XOR of two variables which should be boolean (1 = true)
+    fn xor(x: Self::Variable, y: Self::Variable) -> Self::Variable {
+        x.clone() + y.clone() - Self::constant(2) * x * y
+    }
+
+    /// Degree-2 variable encoding the OR of two variables, which should be boolean (1 = true)
+    fn or(x: Self::Variable, y: Self::Variable) -> Self::Variable {
+        x.clone() + y.clone() - x * y
+    }
+
+    /// Degree-2 variable encoding whether at least one of the two inputs is zero (0 = yes)
+    fn either_zero(x: Self::Variable, y: Self::Variable) -> Self::Variable {
+        x * y
+    }
+}
+
+/// This trait contains helper functions for arithmetic operations used in the Keccak circuit
+pub trait ArithHelpers<F: One + Debug + Zero>
+where
+    Self: Interpreter<F>,
+{
+    /// Returns a variable representing the value zero
+    fn zero() -> Self::Variable {
+        Self::constant(0)
+    }
+    /// Returns a variable representing the value one
+    fn one() -> Self::Variable {
+        Self::constant(1)
+    }
+    /// Returns a variable representing the value two
+    fn two() -> Self::Variable {
+        Self::constant(2)
+    }
+
+    /// Returns a variable representing the value 2^x
+    fn two_pow(x: u64) -> Self::Variable;
+}
diff --git a/o1vm/src/interpreters/keccak/interpreter.rs b/o1vm/src/interpreters/keccak/interpreter.rs
new file mode 100644
index 0000000000..1fb7eaaf21
--- /dev/null
+++ b/o1vm/src/interpreters/keccak/interpreter.rs
@@ -0,0 +1,1103 @@
+//! This module defines the Keccak interpreter in charge of triggering the Keccak workflow
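+//!
+//! Roughly speaking, the `Interpreter` trait below is implemented twice: the
+//! constraints environment builds symbolic expressions out of the variables,
+//! while the witness environment assigns them concrete field values.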
+
+use crate::{
+    interpreters::keccak::{
+        column::{PAD_BYTES_LEN, PAD_SUFFIX_LEN, ROUND_CONST_LEN},
+        grid_index,
+        helpers::{ArithHelpers, BoolHelpers, LogupHelpers},
+        Absorbs::*,
+        Constraint::{self, *},
+        KeccakColumn,
+        Sponges::*,
+        Steps::{self, *},
+        WORDS_IN_HASH,
+    },
+    lookups::{Lookup, LookupTableIDs::*},
+};
+use ark_ff::{One, Zero};
+use kimchi::{
+    auto_clone_array,
+    circuits::polynomials::keccak::{
+        constants::{
+            CHI_SHIFTS_B_LEN, CHI_SHIFTS_SUM_LEN, DIM, PIRHO_DENSE_E_LEN, PIRHO_DENSE_ROT_E_LEN,
+            PIRHO_EXPAND_ROT_E_LEN, PIRHO_QUOTIENT_E_LEN, PIRHO_REMAINDER_E_LEN,
+            PIRHO_SHIFTS_E_LEN, QUARTERS, RATE_IN_BYTES, SHIFTS, SHIFTS_LEN, SPONGE_BYTES_LEN,
+            SPONGE_SHIFTS_LEN, SPONGE_ZEROS_LEN, STATE_LEN, THETA_DENSE_C_LEN,
+            THETA_DENSE_ROT_C_LEN, THETA_EXPAND_ROT_C_LEN, THETA_QUOTIENT_C_LEN,
+            THETA_REMAINDER_C_LEN, THETA_SHIFTS_C_LEN, THETA_STATE_A_LEN,
+        },
+        OFF,
+    },
+    grid,
+};
+use std::{array, fmt::Debug};
+
+/// This trait includes functionalities needed to obtain the variables of the Keccak circuit needed for constraints and witness
+pub trait Interpreter<F: One + Debug + Zero> {
+    type Variable: std::ops::Mul<Self::Variable, Output = Self::Variable>
+        + std::ops::Add<Self::Variable, Output = Self::Variable>
+        + std::ops::Sub<Self::Variable, Output = Self::Variable>
+        + Clone
+        + Debug
+        + One
+        + Zero;
+
+    /// Creates a variable from a constant integer
+    fn constant(x: u64) -> Self::Variable;
+
+    /// Creates a variable from a constant field element
+    fn constant_field(x: F) -> Self::Variable;
+
+    /// Returns the variable corresponding to a given column alias.
+    fn variable(&self, column: KeccakColumn) -> Self::Variable;
+
+    /// Adds one KeccakConstraint to the environment if the selector holds
+    fn constrain(&mut self, tag: Constraint, if_true: Self::Variable, x: Self::Variable);
+
+    /// Adds a given Lookup to the environment if the condition holds
+    fn add_lookup(&mut self, if_true: Self::Variable, lookup: Lookup<Self::Variable>);
+}
+
+pub trait KeccakInterpreter<F: One + Debug + Zero>
+where
+    Self: Interpreter<F> + LogupHelpers<F> + BoolHelpers<F> + ArithHelpers<F>,
+{
+    /// Adds all 879 constraints/checks to the environment:
+    /// - 733 constraints of degree 1
+    /// - 146 constraints of degree 2
+    /// Where:
+    /// - if Steps::Round(_) -> only 389 constraints added
+    /// - if Steps::Sponge::Absorb::First -> only 332 constraints added (232 + 100)
+    /// - if Steps::Sponge::Absorb::Middle -> only 232 constraints added
+    /// - if Steps::Sponge::Absorb::Last -> only 374 constraints added (232 + 136 + 6)
+    /// - if Steps::Sponge::Absorb::Only -> only 474 constraints added (232 + 136 + 100 + 6)
+    /// - if Steps::Sponge::Squeeze -> only 16 constraints added
+    /// So:
+    /// - At most, 474 constraints are added per row
+    /// In particular, after folding:
+    /// - 136 columns should be added for the degree-2 constraints of the flags
+    /// - 5 columns should be added for the degree-2 constraints of the round
+    /// - 10 columns should be added for the degree-2 constraints of the sponge
+    /// - for each of the 5 constraints, 2 columns are added for block_in_padding
+    fn constraints(&mut self, step: Steps) {
+        // CORRECTNESS OF FLAGS: 136 CONSTRAINTS
+        // - 136 constraints of degree 2
+        // Of which:
+        // - 136 constraints are added only if is_pad() holds
+        self.constrain_flags(step);
+
+        // SPONGE CONSTRAINTS: 32 + 3*100 + 16 + 6 = 354 CONSTRAINTS
+        // - 349 of degree 1
+        // - 5 of degree 2
+        // Of which:
+        // - 232 constraints are added only if is_absorb() holds
+        // - 100 constraints are added only if is_root() holds
+        // - 6 constraints are added only if is_pad() holds
+        // - 
16 constraints are added only if is_squeeze() holds + self.constrain_sponge(step); + + // ROUND CONSTRAINTS: 35 + 150 + 200 + 4 = 389 CONSTRAINTS + // - 384 constraints of degree 1 + // - 5 constraints of degree 2 + // Of which: + // - 389 constraints are added only if is_round() holds + self.constrain_round(step); + } + + /// Constrains 136 checks of correctness of mode flags + /// - 136 constraints of degree 2 + /// Of which: + /// - 136 constraints are added only if is_pad() holds + fn constrain_flags(&mut self, step: Steps) + where + Self: Interpreter, + { + // Booleanity of sponge flags: + // - 136 constraints of degree 2 + self.constrain_booleanity(step); + } + + /// Constrains 136 checks of booleanity for some mode flags. + /// - 136 constraints of degree 2 + /// Of which, + /// - 136 constraints are added only if is_pad() holds + fn constrain_booleanity(&mut self, step: Steps) + where + Self: Interpreter, + { + for i in 0..RATE_IN_BYTES { + // Bytes are either involved on padding or not + self.constrain( + BooleanityPadding(i), + self.is_pad(step), + Self::is_boolean(self.in_padding(i)), + ); + } + } + + /// Constrains 354 checks of sponge steps + /// - 349 of degree 1 + /// - 5 of degree 2 + /// Of which: + /// - 232 constraints are added only if is_absorb() holds + /// - 100 constraints are added only if is_root() holds + /// - 6 constraints are added only if is_pad() holds + /// - 16 constraints are added only if is_squeeze() holds + fn constrain_sponge(&mut self, step: Steps) { + self.constrain_absorb(step); + self.constrain_padding(step); + self.constrain_squeeze(step); + } + + /// Constrains 332 checks of absorb sponges + /// - 332 of degree 1 + /// Of which: + /// - 232 constraints are added only if is_absorb() holds + /// - 100 constraints are added only if is_root() holds + fn constrain_absorb(&mut self, step: Steps) { + for (i, zero) in self.sponge_zeros().iter().enumerate() { + // Absorb phase pads with zeros the new state + self.constrain(AbsorbZeroPad(i), self.is_absorb(step), zero.clone()); + } + for i in 0..QUARTERS * DIM * DIM { + // In first absorb, root state is all zeros + self.constrain( + AbsorbRootZero(i), + self.is_root(step), + self.old_state(i).clone(), + ); + // Absorbs the new block by performing XOR with the old state + self.constrain( + AbsorbXor(i), + self.is_absorb(step), + self.xor_state(i).clone() - (self.old_state(i).clone() + self.new_state(i).clone()), + ); + // In absorb, Check shifts correspond to the decomposition of the new state + self.constrain( + AbsorbShifts(i), + self.is_absorb(step), + self.new_state(i).clone() + - Self::from_shifts(&self.vec_sponge_shifts(), Some(i), None, None, None), + ); + } + } + + /// Constrains 6 checks of padding absorb sponges + /// - 1 of degree 1 + /// - 5 of degree 2 + /// Of which: + /// - 6 constraints are added only if is_pad() holds + fn constrain_padding(&mut self, step: Steps) { + // Check that the padding is located at the end of the message + let pad_at_end = (0..RATE_IN_BYTES).fold(Self::zero(), |acc, i| { + acc * Self::two() + self.in_padding(i) + }); + self.constrain( + PadAtEnd, + self.is_pad(step), + self.two_to_pad() - Self::one() - pad_at_end, + ); + // Check that the padding value is correct + for i in 0..PAD_SUFFIX_LEN { + self.constrain( + PaddingSuffix(i), + self.is_pad(step), + self.block_in_padding(i) - self.pad_suffix(i), + ); + } + } + + /// Constrains 16 checks of squeeze sponges + /// - 16 of degree 1 + /// Of which: + /// - 16 constraints are added only if is_squeeze() holds + fn 
constrain_squeeze(&mut self, step: Steps) { + let sponge_shifts = self.vec_sponge_shifts(); + for i in 0..QUARTERS * WORDS_IN_HASH { + // In squeeze, check shifts correspond to the 256-bit prefix digest of the old state (current) + self.constrain( + SqueezeShifts(i), + self.is_squeeze(step), + self.old_state(i).clone() + - Self::from_shifts(&sponge_shifts, Some(i), None, None, None), + ); + } + } + + /// Constrains 389 checks of round steps + /// - 384 constraints of degree 1 + /// - 5 constraints of degree 2 + /// Of which: + /// - 389 constraints are added only if is_round() holds + fn constrain_round(&mut self, step: Steps) { + // STEP theta: 5 * ( 3 + 4 * 1 ) = 35 constraints + // - 30 constraints of degree 1 + // - 5 constraints of degree 2 + let state_e = self.constrain_theta(step); + + // STEP pirho: 5 * 5 * (2 + 4 * 1) = 150 constraints + // - 150 of degree 1 + let state_b = self.constrain_pirho(step, state_e); + + // STEP chi: 4 * 5 * 5 * 2 = 200 constraints + // - 200 of degree 1 + let state_f = self.constrain_chi(step, state_b); + + // STEP iota: 4 constraints + // - 4 of degree 1 + self.constrain_iota(step, state_f); + } + + /// Constrains 35 checks of the theta algorithm in round steps + /// - 30 constraints of degree 1 + /// - 5 constraints of degree 2 + fn constrain_theta(&mut self, step: Steps) -> Vec>> { + // Define vectors storing expressions which are not in the witness layout for efficiency + let mut state_c = vec![vec![Self::zero(); QUARTERS]; DIM]; + let mut state_d = vec![vec![Self::zero(); QUARTERS]; DIM]; + let mut state_e = vec![vec![vec![Self::zero(); QUARTERS]; DIM]; DIM]; + + for x in 0..DIM { + let word_c = Self::from_quarters(&self.vec_dense_c(), None, x); + let rem_c = Self::from_quarters(&self.vec_remainder_c(), None, x); + let rot_c = Self::from_quarters(&self.vec_dense_rot_c(), None, x); + + self.constrain( + ThetaWordC(x), + self.is_round(step), + word_c * Self::two_pow(1) + - (self.quotient_c(x) * Self::two_pow(64) + rem_c.clone()), + ); + self.constrain( + ThetaRotatedC(x), + self.is_round(step), + rot_c - (self.quotient_c(x) + rem_c), + ); + self.constrain( + ThetaQuotientC(x), + self.is_round(step), + Self::is_boolean(self.quotient_c(x)), + ); + + for q in 0..QUARTERS { + state_c[x][q] = self.state_a(0, x, q) + + self.state_a(1, x, q) + + self.state_a(2, x, q) + + self.state_a(3, x, q) + + self.state_a(4, x, q); + self.constrain( + ThetaShiftsC(x, q), + self.is_round(step), + state_c[x][q].clone() + - Self::from_shifts(&self.vec_shifts_c(), None, None, Some(x), Some(q)), + ); + + state_d[x][q] = + self.shifts_c(0, (x + DIM - 1) % DIM, q) + self.expand_rot_c((x + 1) % DIM, q); + + for (y, column_e) in state_e.iter_mut().enumerate() { + column_e[x][q] = self.state_a(y, x, q) + state_d[x][q].clone(); + } + } + } + state_e + } + + /// Constrains 150 checks of the pirho algorithm in round steps + /// - 150 of degree 1 + fn constrain_pirho( + &mut self, + step: Steps, + state_e: Vec>>, + ) -> Vec>> { + // Define vectors storing expressions which are not in the witness layout for efficiency + let mut state_b = vec![vec![vec![Self::zero(); QUARTERS]; DIM]; DIM]; + + for (y, col) in OFF.iter().enumerate() { + for (x, off) in col.iter().enumerate() { + let word_e = Self::from_quarters(&self.vec_dense_e(), Some(y), x); + let quo_e = Self::from_quarters(&self.vec_quotient_e(), Some(y), x); + let rem_e = Self::from_quarters(&self.vec_remainder_e(), Some(y), x); + let rot_e = Self::from_quarters(&self.vec_dense_rot_e(), Some(y), x); + + self.constrain( + 
PiRhoWordE(y, x), + self.is_round(step), + word_e * Self::two_pow(*off) + - (quo_e.clone() * Self::two_pow(64) + rem_e.clone()), + ); + self.constrain( + PiRhoRotatedE(y, x), + self.is_round(step), + rot_e - (quo_e.clone() + rem_e), + ); + + for q in 0..QUARTERS { + self.constrain( + PiRhoShiftsE(y, x, q), + self.is_round(step), + state_e[y][x][q].clone() + - Self::from_shifts( + &self.vec_shifts_e(), + None, + Some(y), + Some(x), + Some(q), + ), + ); + state_b[(2 * x + 3 * y) % DIM][y][q] = self.expand_rot_e(y, x, q); + } + } + } + state_b + } + + /// Constrains 200 checks of the chi algorithm in round steps + /// - 200 of degree 1 + fn constrain_chi( + &mut self, + step: Steps, + state_b: Vec>>, + ) -> Vec>> { + // Define vectors storing expressions which are not in the witness layout for efficiency + let mut state_f = vec![vec![vec![Self::zero(); QUARTERS]; DIM]; DIM]; + + for q in 0..QUARTERS { + for x in 0..DIM { + for y in 0..DIM { + let not = Self::constant(0x1111111111111111u64) + - self.shifts_b(0, y, (x + 1) % DIM, q); + let sum = not + self.shifts_b(0, y, (x + 2) % DIM, q); + let and = self.shifts_sum(1, y, x, q); + + self.constrain( + ChiShiftsB(y, x, q), + self.is_round(step), + state_b[y][x][q].clone() + - Self::from_shifts( + &self.vec_shifts_b(), + None, + Some(y), + Some(x), + Some(q), + ), + ); + self.constrain( + ChiShiftsSum(y, x, q), + self.is_round(step), + sum - Self::from_shifts( + &self.vec_shifts_sum(), + None, + Some(y), + Some(x), + Some(q), + ), + ); + state_f[y][x][q] = self.shifts_b(0, y, x, q) + and; + } + } + } + state_f + } + + /// Constrains 4 checks of the iota algorithm in round steps + /// - 4 of degree 1 + fn constrain_iota(&mut self, step: Steps, state_f: Vec>>) { + for (q, c) in self.round_constants().to_vec().iter().enumerate() { + self.constrain( + IotaStateG(q), + self.is_round(step), + self.state_g(q).clone() - (state_f[0][0][q].clone() + c.clone()), + ); + } + } + + //////////////////////// + // LOOKUPS OPERATIONS // + //////////////////////// + + /// Creates all possible lookups to the Keccak constraints environment: + /// - 2225 lookups for the step row + /// - 2 lookups for the inter-step channel + /// - 136 lookups for the syscall channel (preimage bytes) + /// - 1 lookups for the syscall channel (hash) + /// Of which: + /// - 1623 lookups if Step::Round (1621 + 2) + /// - 537 lookups if Step::Absorb::First (400 + 1 + 136) + /// - 538 lookups if Step::Absorb::Middle (400 + 2 + 136) + /// - 539 lookups if Step::Absorb::Last (401 + 2 + 136) + /// - 538 lookups if Step::Absorb::Only (401 + 1 + 136) + /// - 602 lookups if Step::Squeeze (600 + 1 + 1) + fn lookups(&mut self, step: Steps) { + // SPONGE LOOKUPS + // -> adds 400 lookups if is_sponge + // -> adds +200 lookups if is_squeeze + // -> adds +1 lookups if is_pad + self.lookups_sponge(step); + + // ROUND LOOKUPS + // -> adds 1621 lookups if is_round + { + // THETA LOOKUPS + self.lookups_round_theta(step); + // PIRHO LOOKUPS + self.lookups_round_pirho(step); + // CHI LOOKUPS + self.lookups_round_chi(step); + // IOTA LOOKUPS + self.lookups_round_iota(step); + } + + // INTER-STEP CHANNEL + // Write outputs for next step if not a squeeze and read inputs of curr step if not a root + // -> adds 1 lookup if is_root + // -> adds 1 lookup if is_squeeze + // -> adds 2 lookups otherwise + self.lookup_steps(step); + + // COMMUNICATION CHANNEL: read bytes of current block + // -> adds 136 lookups if is_absorb + self.lookup_syscall_preimage(step); + + // COMMUNICATION CHANNEL: Write hash output + // -> 
adds 1 lookup if is_squeeze
+        self.lookup_syscall_hash(step);
+    }
+
+    /// When in Absorb mode, reads Lookups containing the 136 bytes of the block of the preimage
+    /// - if is_absorb, adds 136 lookups
+    /// - otherwise, adds 0 lookups
+    // TODO: optimize this by using a single lookup reusing PadSuffix
+    fn lookup_syscall_preimage(&mut self, step: Steps) {
+        for i in 0..RATE_IN_BYTES {
+            self.read_syscall(
+                self.is_absorb(step),
+                vec![
+                    self.hash_index(),
+                    self.block_index() * Self::constant(RATE_IN_BYTES as u64)
+                        + Self::constant(i as u64),
+                    self.sponge_byte(i),
+                ],
+            );
+        }
+    }
+
+    /// When in Squeeze mode, writes a Lookup containing the 31-byte output of
+    /// the hash (excludes the MSB)
+    /// - if is_squeeze, adds 1 lookup
+    /// - otherwise, adds 0 lookups
+    /// NOTE: this is excluding the MSB (which is then substituted with the
+    /// file descriptor).
+    fn lookup_syscall_hash(&mut self, step: Steps) {
+        let bytes31 = (1..32).fold(Self::zero(), |acc, i| {
+            acc * Self::two_pow(8) + self.sponge_byte(i)
+        });
+        self.write_syscall(self.is_squeeze(step), vec![self.hash_index(), bytes31]);
+    }
+
+    /// Reads a Lookup containing the input of the current step
+    /// and writes a Lookup containing its output (which is the input of the next step)
+    /// - if is_root, only adds 1 lookup
+    /// - if is_squeeze, only adds 1 lookup
+    /// - otherwise, adds 2 lookups
+    fn lookup_steps(&mut self, step: Steps) {
+        // (if not a root) Output of previous step is input of current step
+        self.add_lookup(
+            Self::not(self.is_root(step)),
+            Lookup::read_one(KeccakStepLookup, self.input_of_step()),
+        );
+        // (if not a squeeze) Input for next step is output of current step
+        self.add_lookup(
+            Self::not(self.is_squeeze(step)),
+            Lookup::write_one(KeccakStepLookup, self.output_of_step()),
+        );
+    }
+
+    /// Adds the 601 lookups required for the sponge
+    /// - 400 lookups if is_sponge()
+    /// - 200 extra lookups if is_squeeze()
+    /// - 1 extra lookup if is_pad()
+    fn lookups_sponge(&mut self, step: Steps) {
+        // PADDING LOOKUPS
+        // Power of two corresponds to 2^pad_length
+        // Pad suffixes correspond to 10*1 rule
+        self.lookup_pad(
+            self.is_pad(step),
+            vec![
+                self.pad_length(),
+                self.two_to_pad(),
+                self.pad_suffix(0),
+                self.pad_suffix(1),
+                self.pad_suffix(2),
+                self.pad_suffix(3),
+                self.pad_suffix(4),
+            ],
+        );
+        // BYTES LOOKUPS
+        // Checking the 200 bytes of the absorb phase together with the length
+        // bytes is performed in the SyscallReadPreimage rows (4 byte lookups
+        // per row).
+        // Here, this only checks the 200 bytes of the squeeze phase (digest)
+        // TODO: could this be just 32 and we check the shifts only for those?
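+        // For instance (illustrative), each iteration below adds the lookup
+        // (is_squeeze(step), ByteLookup, [sponge_byte(i)]), i.e. the byte check
+        // only counts towards squeeze rows.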
+ for i in 0..200 { + // Bytes are <2^8 + self.lookup_byte(self.is_squeeze(step), self.sponge_byte(i)); + } + // SHIFTS LOOKUPS + for i in 100..SHIFTS_LEN { + // Shifts1, Shifts2, Shifts3 are in the Sparse table + self.lookup_sparse(self.is_sponge(step), self.sponge_shifts(i)); + } + for i in 0..STATE_LEN { + // Shifts0 together with Bytes composition by pairs are in the Reset table + let dense = self.sponge_byte(2 * i) + self.sponge_byte(2 * i + 1) * Self::two_pow(8); + self.lookup_reset(self.is_sponge(step), dense, self.sponge_shifts(i)); + } + } + + /// Adds the 120 lookups required for Theta in the round + fn lookups_round_theta(&mut self, step: Steps) { + for q in 0..QUARTERS { + for x in 0..DIM { + // Check that ThetaRemainderC < 2^64 + self.lookup_rc16(self.is_round(step), self.remainder_c(x, q)); + // Check ThetaExpandRotC is the expansion of ThetaDenseRotC + self.lookup_reset( + self.is_round(step), + self.dense_rot_c(x, q), + self.expand_rot_c(x, q), + ); + // Check ThetaShiftC0 is the expansion of ThetaDenseC + self.lookup_reset( + self.is_round(step), + self.dense_c(x, q), + self.shifts_c(0, x, q), + ); + // Check that the rest of ThetaShiftsC are in the Sparse table + for i in 1..SHIFTS { + self.lookup_sparse(self.is_round(step), self.shifts_c(i, x, q)); + } + } + } + } + + /// Adds the 700 lookups required for PiRho in the round + fn lookups_round_pirho(&mut self, step: Steps) { + for q in 0..QUARTERS { + for x in 0..DIM { + for y in 0..DIM { + // Check that PiRhoRemainderE < 2^64 and PiRhoQuotientE < 2^64 + self.lookup_rc16(self.is_round(step), self.remainder_e(y, x, q)); + self.lookup_rc16(self.is_round(step), self.quotient_e(y, x, q)); + // Check PiRhoExpandRotE is the expansion of PiRhoDenseRotE + self.lookup_reset( + self.is_round(step), + self.dense_rot_e(y, x, q), + self.expand_rot_e(y, x, q), + ); + // Check PiRhoShift0E is the expansion of PiRhoDenseE + self.lookup_reset( + self.is_round(step), + self.dense_e(y, x, q), + self.shifts_e(0, y, x, q), + ); + // Check that the rest of PiRhoShiftsE are in the Sparse table + for i in 1..SHIFTS { + self.lookup_sparse(self.is_round(step), self.shifts_e(i, y, x, q)); + } + } + } + } + } + + /// Adds the 800 lookups required for Chi in the round + fn lookups_round_chi(&mut self, step: Steps) { + let shifts_b = self.vec_shifts_b(); + let shifts_sum = self.vec_shifts_sum(); + for i in 0..SHIFTS_LEN { + // Check ChiShiftsB and ChiShiftsSum are in the Sparse table + self.lookup_sparse(self.is_round(step), shifts_b[i].clone()); + self.lookup_sparse(self.is_round(step), shifts_sum[i].clone()); + } + } + + /// Adds the 1 lookup required for Iota in the round + fn lookups_round_iota(&mut self, step: Steps) { + // Check round constants correspond with the current round + let round_constants = self.round_constants(); + self.lookup_round_constants( + self.is_round(step), + vec![ + self.round(), + round_constants[3].clone(), + round_constants[2].clone(), + round_constants[1].clone(), + round_constants[0].clone(), + ], + ); + } + + /////////////////////////// + /// SELECTOR OPERATIONS /// + /////////////////////////// + + /// Returns a degree-2 variable that encodes whether the current step is a sponge (1 = yes) + fn is_sponge(&self, step: Steps) -> Self::Variable { + Self::xor(self.is_absorb(step), self.is_squeeze(step)) + } + /// Returns a variable that encodes whether the current step is an absorb sponge (1 = yes) + fn is_absorb(&self, step: Steps) -> Self::Variable { + Self::or( + Self::or(self.mode_root(step), self.mode_rootpad(step)), 
+ Self::or(self.mode_pad(step), self.mode_absorb(step)), + ) + } + /// Returns a variable that encodes whether the current step is a squeeze sponge (1 = yes) + fn is_squeeze(&self, step: Steps) -> Self::Variable { + self.mode_squeeze(step) + } + /// Returns a variable that encodes whether the current step is the first absorb sponge (1 = yes) + fn is_root(&self, step: Steps) -> Self::Variable { + Self::or(self.mode_root(step), self.mode_rootpad(step)) + } + /// Returns a degree-1 variable that encodes whether the current step is the last absorb sponge (1 = yes) + fn is_pad(&self, step: Steps) -> Self::Variable { + Self::or(self.mode_pad(step), self.mode_rootpad(step)) + } + /// Returns a variable that encodes whether the current step is a permutation round (1 = yes) + fn is_round(&self, step: Steps) -> Self::Variable { + self.mode_round(step) + } + + /// Returns a variable that encodes whether the current step is an absorb sponge (1 = yes) + fn mode_absorb(&self, step: Steps) -> Self::Variable { + match step { + Sponge(Absorb(Middle)) => Self::one(), + _ => Self::zero(), + } + } + + /// Returns a variable that encodes whether the current step is a squeeze sponge (1 = yes) + fn mode_squeeze(&self, step: Steps) -> Self::Variable { + match step { + Sponge(Squeeze) => Self::one(), + _ => Self::zero(), + } + } + + /// Returns a variable that encodes whether the current step is the first absorb sponge (1 = yes) + fn mode_root(&self, step: Steps) -> Self::Variable { + match step { + Sponge(Absorb(First)) => Self::one(), + _ => Self::zero(), + } + } + + /// Returns a degree-1 variable that encodes whether the current step is the last absorb sponge (1 = yes) + fn mode_pad(&self, step: Steps) -> Self::Variable { + match step { + Sponge(Absorb(Last)) => Self::one(), + _ => Self::zero(), + } + } + + /// Returns a degree-1 variable that encodes whether the current step is the first and last absorb sponge (1 = yes) + fn mode_rootpad(&self, step: Steps) -> Self::Variable { + match step { + Sponge(Absorb(Only)) => Self::one(), + _ => Self::zero(), + } + } + + /// Returns a variable that encodes whether the current step is a permutation round (1 = yes) + fn mode_round(&self, step: Steps) -> Self::Variable { + // The actual round number in the selector carries no information for witness nor constraints + // because in the witness, any usize is mapped to the same index inside the mode flags + match step { + Round(_) => Self::one(), + _ => Self::zero(), + } + } + + ///////////////////////// + /// COLUMN OPERATIONS /// + ///////////////////////// + + /// This function returns the composed sparse variable from shifts of any correct length: + /// - When the length is 400, two index configurations are possible: + /// - If `i` is `Some`, then this sole index could range between [0..400) + /// - If `i` is `None`, then `y`, `x` and `q` must be `Some` and + /// - `y` must range between [0..5) + /// - `x` must range between [0..5) + /// - `q` must range between [0..4) + /// - When the length is 80, both `i` and `y` should be `None`, and `x` and `q` must be `Some` with: + /// - `x` must range between [0..5) + /// - `q` must range between [0..4) + fn from_shifts( + shifts: &[Self::Variable], + i: Option, + y: Option, + x: Option, + q: Option, + ) -> Self::Variable { + match shifts.len() { + 400 => { + if let Some(i) = i { + auto_clone_array!(shifts); + shifts(i) + + Self::two_pow(1) * shifts(100 + i) + + Self::two_pow(2) * shifts(200 + i) + + Self::two_pow(3) * shifts(300 + i) + } else { + let shifts = grid!(400, 
shifts); + shifts(0, y.unwrap(), x.unwrap(), q.unwrap()) + + Self::two_pow(1) * shifts(1, y.unwrap(), x.unwrap(), q.unwrap()) + + Self::two_pow(2) * shifts(2, y.unwrap(), x.unwrap(), q.unwrap()) + + Self::two_pow(3) * shifts(3, y.unwrap(), x.unwrap(), q.unwrap()) + } + } + 80 => { + let shifts = grid!(80, shifts); + shifts(0, x.unwrap(), q.unwrap()) + + Self::two_pow(1) * shifts(1, x.unwrap(), q.unwrap()) + + Self::two_pow(2) * shifts(2, x.unwrap(), q.unwrap()) + + Self::two_pow(3) * shifts(3, x.unwrap(), q.unwrap()) + } + _ => panic!("Invalid length of shifts"), + } + } + + /// This function returns the composed variable from dense quarters of any correct length: + /// - When `y` is `Some`, then the length must be 100 and: + /// - `y` must range between [0..5) + /// - `x` must range between [0..5) + /// - When `y` is `None`, then the length must be 20 and: + /// - `x` must range between [0..5) + fn from_quarters(quarters: &[Self::Variable], y: Option, x: usize) -> Self::Variable { + if let Some(y) = y { + assert!(quarters.len() == 100, "Invalid length of quarters"); + let quarters = grid!(100, quarters); + quarters(y, x, 0) + + Self::two_pow(16) * quarters(y, x, 1) + + Self::two_pow(32) * quarters(y, x, 2) + + Self::two_pow(48) * quarters(y, x, 3) + } else { + assert!(quarters.len() == 20, "Invalid length of quarters"); + let quarters = grid!(20, quarters); + quarters(x, 0) + + Self::two_pow(16) * quarters(x, 1) + + Self::two_pow(32) * quarters(x, 2) + + Self::two_pow(48) * quarters(x, 3) + } + } + + /// Returns a variable that encodes the current round number [0..24) + fn round(&self) -> Self::Variable { + self.variable(KeccakColumn::RoundNumber) + } + + /// Returns a variable that encodes the bytelength of the padding if any [0..136) + fn pad_length(&self) -> Self::Variable { + self.variable(KeccakColumn::PadLength) + } + /// Returns a variable that encodes the value 2^pad_length + fn two_to_pad(&self) -> Self::Variable { + self.variable(KeccakColumn::TwoToPad) + } + + /// Returns a variable that encodes whether the `idx`-th byte of the new block is involved in the padding (1 = yes) + fn in_padding(&self, idx: usize) -> Self::Variable { + self.variable(KeccakColumn::PadBytesFlags(idx)) + } + + /// Returns a variable that encodes the `idx`-th chunk of the padding suffix + /// - if `idx` = 0, then the length is 12 bytes at most + /// - if `idx` = [1..5), then the length is 31 bytes at most + fn pad_suffix(&self, idx: usize) -> Self::Variable { + self.variable(KeccakColumn::PadSuffix(idx)) + } + + /// Returns a variable that encodes the `idx`-th block of bytes of the new block + /// by composing the bytes variables, with `idx` in [0..5) + fn bytes_block(&self, idx: usize) -> Vec { + let sponge_bytes = self.sponge_bytes(); + match idx { + 0 => sponge_bytes[0..12].to_vec(), + 1..=4 => sponge_bytes[12 + (idx - 1) * 31..12 + idx * 31].to_vec(), + _ => panic!("No more blocks of bytes can be part of padding"), + } + } + + /// Returns the 136 flags indicating which bytes of the new block are involved in the padding, as variables + fn pad_bytes_flags(&self) -> [Self::Variable; PAD_BYTES_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::PadBytesFlags(idx))) + } + + /// Returns a vector of pad bytes flags as variables, with `idx` in [0..5) + /// - if `idx` = 0, then the length of the block is at most 12 + /// - if `idx` = [1..5), then the length of the block is at most 31 + fn flags_block(&self, idx: usize) -> Vec { + let pad_bytes_flags = self.pad_bytes_flags(); + match idx { + 0 => 
pad_bytes_flags[0..12].to_vec(), + 1..=4 => pad_bytes_flags[12 + (idx - 1) * 31..12 + idx * 31].to_vec(), + _ => panic!("No more blocks of flags can be part of padding"), + } + } + + /// This function returns a degree-2 variable that is computed as the accumulated value of the + /// operation `byte * flag * 2^8` for each byte block and flag block of the new block. + /// This function will be used in constraints to determine whether the padding is located + /// at the end of the preimage data, as consecutive bits that are involved in the padding. + fn block_in_padding(&self, idx: usize) -> Self::Variable { + let bytes = self.bytes_block(idx); + let flags = self.flags_block(idx); + assert_eq!(bytes.len(), flags.len()); + let pad = bytes + .iter() + .zip(flags) + .fold(Self::zero(), |acc, (byte, flag)| { + acc * Self::two_pow(8) + byte.clone() * flag.clone() + }); + + pad + } + + /// Returns the 4 expanded quarters that encode the round constant, as variables + fn round_constants(&self) -> [Self::Variable; ROUND_CONST_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::RoundConstants(idx))) + } + + /// Returns the `idx`-th old state expanded quarter, as a variable + fn old_state(&self, idx: usize) -> Self::Variable { + self.variable(KeccakColumn::Input(idx)) + } + + /// Returns the `idx`-th new state expanded quarter, as a variable + fn new_state(&self, idx: usize) -> Self::Variable { + self.variable(KeccakColumn::SpongeNewState(idx)) + } + + /// Returns the output of an absorb sponge, which is the XOR of the old state and the new state + fn xor_state(&self, idx: usize) -> Self::Variable { + self.variable(KeccakColumn::Output(idx)) + } + + /// Returns the last 32 terms that are added to the new block in an absorb sponge, as variables which should be zeros + fn sponge_zeros(&self) -> [Self::Variable; SPONGE_ZEROS_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::SpongeZeros(idx))) + } + + /// Returns the 400 terms that compose the shifts of the sponge, as variables + fn vec_sponge_shifts(&self) -> [Self::Variable; SPONGE_SHIFTS_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::SpongeShifts(idx))) + } + /// Returns the `idx`-th term of the shifts of the sponge, as a variable + fn sponge_shifts(&self, idx: usize) -> Self::Variable { + self.variable(KeccakColumn::SpongeShifts(idx)) + } + + /// Returns the 200 bytes of the sponge, as variables + fn sponge_bytes(&self) -> [Self::Variable; SPONGE_BYTES_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::SpongeBytes(idx))) + } + /// Returns the `idx`-th byte of the sponge, as a variable + fn sponge_byte(&self, idx: usize) -> Self::Variable { + self.variable(KeccakColumn::SpongeBytes(idx)) + } + + /// Returns the (y,x,q)-th input of the theta algorithm, as a variable + fn state_a(&self, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(THETA_STATE_A_LEN, 0, y, x, q); + self.variable(KeccakColumn::Input(idx)) + } + + /// Returns the 80 variables corresponding to ThetaShiftsC + fn vec_shifts_c(&self) -> [Self::Variable; THETA_SHIFTS_C_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::ThetaShiftsC(idx))) + } + /// Returns the (i,x,q)-th variable of ThetaShiftsC + fn shifts_c(&self, i: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(THETA_SHIFTS_C_LEN, i, 0, x, q); + self.variable(KeccakColumn::ThetaShiftsC(idx)) + } + + /// Returns the 20 variables corresponding to ThetaDenseC + fn vec_dense_c(&self) -> [Self::Variable; THETA_DENSE_C_LEN] { + array::from_fn(|idx| 
self.variable(KeccakColumn::ThetaDenseC(idx))) + } + /// Returns the (x,q)-th term of ThetaDenseC, as a variable + fn dense_c(&self, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(THETA_DENSE_C_LEN, 0, 0, x, q); + self.variable(KeccakColumn::ThetaDenseC(idx)) + } + + /// Returns the 5 variables corresponding to ThetaQuotientC + fn vec_quotient_c(&self) -> [Self::Variable; THETA_QUOTIENT_C_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::ThetaQuotientC(idx))) + } + /// Returns the (x)-th term of ThetaQuotientC, as a variable + fn quotient_c(&self, x: usize) -> Self::Variable { + self.variable(KeccakColumn::ThetaQuotientC(x)) + } + + /// Returns the 20 variables corresponding to ThetaRemainderC + fn vec_remainder_c(&self) -> [Self::Variable; THETA_REMAINDER_C_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::ThetaRemainderC(idx))) + } + /// Returns the (x,q)-th variable of ThetaRemainderC + fn remainder_c(&self, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(THETA_REMAINDER_C_LEN, 0, 0, x, q); + self.variable(KeccakColumn::ThetaRemainderC(idx)) + } + + /// Returns the 20 variables corresponding to ThetaDenseRotC + fn vec_dense_rot_c(&self) -> [Self::Variable; THETA_DENSE_ROT_C_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::ThetaDenseRotC(idx))) + } + /// Returns the (x,q)-th variable of ThetaDenseRotC + fn dense_rot_c(&self, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(THETA_DENSE_ROT_C_LEN, 0, 0, x, q); + self.variable(KeccakColumn::ThetaDenseRotC(idx)) + } + + /// Returns the 20 variables corresponding to ThetaExpandRotC + fn vec_expand_rot_c(&self) -> [Self::Variable; THETA_EXPAND_ROT_C_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::ThetaExpandRotC(idx))) + } + /// Returns the (x,q)-th variable of ThetaExpandRotC + fn expand_rot_c(&self, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(THETA_EXPAND_ROT_C_LEN, 0, 0, x, q); + self.variable(KeccakColumn::ThetaExpandRotC(idx)) + } + + /// Returns the 400 variables corresponding to PiRhoShiftsE + fn vec_shifts_e(&self) -> [Self::Variable; PIRHO_SHIFTS_E_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::PiRhoShiftsE(idx))) + } + /// Returns the (i,y,x,q)-th variable of PiRhoShiftsE + fn shifts_e(&self, i: usize, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(PIRHO_SHIFTS_E_LEN, i, y, x, q); + self.variable(KeccakColumn::PiRhoShiftsE(idx)) + } + + /// Returns the 100 variables corresponding to PiRhoDenseE + fn vec_dense_e(&self) -> [Self::Variable; PIRHO_DENSE_E_LEN] { + array::from_fn(|idx: usize| self.variable(KeccakColumn::PiRhoDenseE(idx))) + } + /// Returns the (y,x,q)-th variable of PiRhoDenseE + fn dense_e(&self, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(PIRHO_DENSE_E_LEN, 0, y, x, q); + self.variable(KeccakColumn::PiRhoDenseE(idx)) + } + + /// Returns the 100 variables corresponding to PiRhoQuotientE + fn vec_quotient_e(&self) -> [Self::Variable; PIRHO_QUOTIENT_E_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::PiRhoQuotientE(idx))) + } + /// Returns the (y,x,q)-th variable of PiRhoQuotientE + fn quotient_e(&self, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(PIRHO_QUOTIENT_E_LEN, 0, y, x, q); + self.variable(KeccakColumn::PiRhoQuotientE(idx)) + } + + /// Returns the 100 variables corresponding to PiRhoRemainderE + fn vec_remainder_e(&self) -> [Self::Variable; PIRHO_REMAINDER_E_LEN] { + array::from_fn(|idx| 
self.variable(KeccakColumn::PiRhoRemainderE(idx))) + } + /// Returns the (y,x,q)-th variable of PiRhoRemainderE + fn remainder_e(&self, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(PIRHO_REMAINDER_E_LEN, 0, y, x, q); + self.variable(KeccakColumn::PiRhoRemainderE(idx)) + } + + /// Returns the 100 variables corresponding to PiRhoDenseRotE + fn vec_dense_rot_e(&self) -> [Self::Variable; PIRHO_DENSE_ROT_E_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::PiRhoDenseRotE(idx))) + } + /// Returns the (y,x,q)-th variable of PiRhoDenseRotE + fn dense_rot_e(&self, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(PIRHO_DENSE_ROT_E_LEN, 0, y, x, q); + self.variable(KeccakColumn::PiRhoDenseRotE(idx)) + } + + /// Returns the 100 variables corresponding to PiRhoExpandRotE + fn vec_expand_rot_e(&self) -> [Self::Variable; PIRHO_EXPAND_ROT_E_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::PiRhoExpandRotE(idx))) + } + /// Returns the (y,x,q)-th variable of PiRhoExpandRotE + fn expand_rot_e(&self, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(PIRHO_EXPAND_ROT_E_LEN, 0, y, x, q); + self.variable(KeccakColumn::PiRhoExpandRotE(idx)) + } + + /// Returns the 400 variables corresponding to ChiShiftsB + fn vec_shifts_b(&self) -> [Self::Variable; CHI_SHIFTS_B_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::ChiShiftsB(idx))) + } + /// Returns the (i,y,x,q)-th variable of ChiShiftsB + fn shifts_b(&self, i: usize, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(CHI_SHIFTS_B_LEN, i, y, x, q); + self.variable(KeccakColumn::ChiShiftsB(idx)) + } + + /// Returns the 400 variables corresponding to ChiShiftsSum + fn vec_shifts_sum(&self) -> [Self::Variable; CHI_SHIFTS_SUM_LEN] { + array::from_fn(|idx| self.variable(KeccakColumn::ChiShiftsSum(idx))) + } + /// Returns the (i,y,x,q)-th variable of ChiShiftsSum + fn shifts_sum(&self, i: usize, y: usize, x: usize, q: usize) -> Self::Variable { + let idx = grid_index(CHI_SHIFTS_SUM_LEN, i, y, x, q); + self.variable(KeccakColumn::ChiShiftsSum(idx)) + } + + /// Returns the `idx`-th output of a round step as a variable + fn state_g(&self, idx: usize) -> Self::Variable { + self.variable(KeccakColumn::Output(idx)) + } + + /// Returns the hash index as a variable + fn hash_index(&self) -> Self::Variable { + self.variable(KeccakColumn::HashIndex) + } + /// Returns the block index as a variable + fn block_index(&self) -> Self::Variable { + self.variable(KeccakColumn::BlockIndex) + } + /// Returns the step index as a variable + fn step_index(&self) -> Self::Variable { + self.variable(KeccakColumn::StepIndex) + } + + /// Returns the 100 step input variables, which correspond to the: + /// - State A when the current step is a permutation round + /// - Old state when the current step is a non-root sponge + fn input(&self) -> [Self::Variable; STATE_LEN] { + array::from_fn::<_, STATE_LEN, _>(|idx| self.variable(KeccakColumn::Input(idx))) + } + /// Returns a slice of the input variables of the current step + /// including the current hash index and step index + fn input_of_step(&self) -> Vec { + let mut input_of_step = Vec::with_capacity(STATE_LEN + 2); + input_of_step.push(self.hash_index()); + input_of_step.push(self.step_index()); + input_of_step.extend_from_slice(&self.input()); + input_of_step + } + + /// Returns the 100 step output variables, which correspond to the: + /// - State G when the current step is a permutation round + /// - Xor state when the current 
step is an absorb sponge
+    fn output(&self) -> [Self::Variable; STATE_LEN] {
+        array::from_fn::<_, STATE_LEN, _>(|idx| self.variable(KeccakColumn::Output(idx)))
+    }
+    /// Returns a slice of the output variables of the current step (= input of next step)
+    /// including the current hash index and step index
+    fn output_of_step(&self) -> Vec<Self::Variable> {
+        let mut output_of_step = Vec::with_capacity(STATE_LEN + 2);
+        output_of_step.push(self.hash_index());
+        output_of_step.push(self.step_index() + Self::one());
+        output_of_step.extend_from_slice(&self.output());
+        output_of_step
+    }
+}
diff --git a/o1vm/src/interpreters/keccak/mod.rs b/o1vm/src/interpreters/keccak/mod.rs
new file mode 100644
index 0000000000..a05ab81d62
--- /dev/null
+++ b/o1vm/src/interpreters/keccak/mod.rs
@@ -0,0 +1,116 @@
+use crate::{
+    interpreters::keccak::column::{ColumnAlias as KeccakColumn, Steps::*, PAD_SUFFIX_LEN},
+    lookups::LookupTableIDs,
+};
+use ark_ff::Field;
+use kimchi::circuits::polynomials::keccak::constants::{
+    DIM, KECCAK_COLS, QUARTERS, RATE_IN_BYTES, STATE_LEN,
+};
+
+pub mod column;
+pub mod constraints;
+pub mod environment;
+pub mod helpers;
+pub mod interpreter;
+#[cfg(test)]
+pub mod tests;
+pub mod witness;
+
+pub use column::{Absorbs, Sponges, Steps};
+
+/// Desired output length of the hash in bits
+pub(crate) const HASH_BITLENGTH: usize = 256;
+/// Desired output length of the hash in bytes
+pub(crate) const HASH_BYTELENGTH: usize = HASH_BITLENGTH / 8;
+/// Length of each word in the Keccak state, in bits
+pub(crate) const WORD_LENGTH_IN_BITS: usize = 64;
+/// Number of columns required in the `curr` part of the witness
+pub(crate) const ZKVM_KECCAK_COLS_CURR: usize = KECCAK_COLS;
+/// Number of columns required in the `next` part of the witness, corresponding to the output length
+pub(crate) const ZKVM_KECCAK_COLS_NEXT: usize = STATE_LEN;
+/// Number of words that fit in the hash digest
+pub(crate) const WORDS_IN_HASH: usize = HASH_BITLENGTH / WORD_LENGTH_IN_BITS;
+
+/// Errors that can occur during the check of the witness
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum Error {
+    Constraint(Constraint),
+    Lookup(LookupTableIDs),
+}
+
+/// All the names for constraints involved in the Keccak circuit
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum Constraint {
+    BooleanityPadding(usize),
+    AbsorbZeroPad(usize),
+    AbsorbRootZero(usize),
+    AbsorbXor(usize),
+    AbsorbShifts(usize),
+    PadAtEnd,
+    PaddingSuffix(usize),
+    SqueezeShifts(usize),
+    ThetaWordC(usize),
+    ThetaRotatedC(usize),
+    ThetaQuotientC(usize),
+    ThetaShiftsC(usize, usize),
+    PiRhoWordE(usize, usize),
+    PiRhoRotatedE(usize, usize),
+    PiRhoShiftsE(usize, usize, usize),
+    ChiShiftsB(usize, usize, usize),
+    ChiShiftsSum(usize, usize, usize),
+    IotaStateG(usize),
+}
+
+/// Standardizes a Keccak step to a common opcode
+pub fn standardize(opcode: Steps) -> Steps {
+    // Note that steps of execution are obtained from the constraints environment.
+    // There, the round steps can be anything between 0 and 23 (for the 24 rounds
+    // of the permutation). Nonetheless, all of them contain the same set of
+    // constraints and lookups. Therefore, we want to treat them as the same step
+    // when it comes to splitting the circuit into multiple instances with shared
+    // behaviour. By default, we use `Round(0)`.
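+    // For instance (illustrative): `standardize(Round(17))` returns `Round(0)`,
+    // while sponge steps such as `Sponge(Squeeze)` are returned unchanged.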
+ if let Round(_) = opcode { + Round(0) + } else { + opcode + } +} + +// This function maps a 4D index into a 1D index depending on the length of the grid +fn grid_index(length: usize, i: usize, y: usize, x: usize, q: usize) -> usize { + match length { + 5 => x, + 20 => q + QUARTERS * x, + 80 => q + QUARTERS * (x + DIM * i), + 100 => q + QUARTERS * (x + DIM * y), + 400 => q + QUARTERS * (x + DIM * (y + DIM * i)), + _ => panic!("Invalid grid size"), + } +} + +/// This function returns a vector of field elements that represent the 5 padding suffixes. +/// The first one uses at most 12 bytes, and the rest use at most 31 bytes. +pub fn pad_blocks(pad_bytelength: usize) -> [F; PAD_SUFFIX_LEN] { + assert!(pad_bytelength > 0, "Padding length must be at least 1 byte"); + assert!( + pad_bytelength <= 136, + "Padding length must be at most 136 bytes", + ); + // Blocks to store padding. The first one uses at most 12 bytes, and the rest use at most 31 bytes. + let mut blocks = [F::zero(); PAD_SUFFIX_LEN]; + let mut pad = [F::zero(); RATE_IN_BYTES]; + pad[RATE_IN_BYTES - pad_bytelength] = F::one(); + pad[RATE_IN_BYTES - 1] += F::from(0x80u8); + blocks[0] = pad + .iter() + .take(12) + .fold(F::zero(), |acc, x| acc * F::from(256u32) + *x); + for (i, block) in blocks.iter_mut().enumerate().take(5).skip(1) { + // take 31 elements from pad, starting at 12 + (i - 1) * 31 and fold them into a single Fp + *block = pad + .iter() + .skip(12 + (i - 1) * 31) + .take(31) + .fold(F::zero(), |acc, x| acc * F::from(256u32) + *x); + } + blocks +} diff --git a/o1vm/src/interpreters/keccak/tests.rs b/o1vm/src/interpreters/keccak/tests.rs new file mode 100644 index 0000000000..6ce716a038 --- /dev/null +++ b/o1vm/src/interpreters/keccak/tests.rs @@ -0,0 +1,489 @@ +use crate::{ + interpreters::keccak::{ + column::{Absorbs::*, Sponges::*, Steps::*}, + environment::KeccakEnv, + interpreter::KeccakInterpreter, + Constraint::*, + Error, KeccakColumn, + }, + lookups::{FixedLookupTables, LookupTable, LookupTableIDs::*}, +}; + +use ark_ff::{One, Zero}; +use kimchi::{ + circuits::polynomials::keccak::Keccak, + o1_utils::{self, FieldHelpers, Two}, +}; +use rand::Rng; +use sha3::{Digest, Keccak256}; +use std::collections::HashMap; + +// FIXME: we should check with other fields too +use ark_bn254::Fr as Fp; + +#[test] +fn test_pad_blocks() { + let blocks_1 = crate::interpreters::keccak::pad_blocks::(1); + assert_eq!(blocks_1[0], Fp::from(0x00)); + assert_eq!(blocks_1[1], Fp::from(0x00)); + assert_eq!(blocks_1[2], Fp::from(0x00)); + assert_eq!(blocks_1[3], Fp::from(0x00)); + assert_eq!(blocks_1[4], Fp::from(0x81)); + + let blocks_136 = crate::interpreters::keccak::pad_blocks::(136); + assert_eq!(blocks_136[0], Fp::from(0x010000000000000000000000u128)); + assert_eq!(blocks_136[1], Fp::from(0x00)); + assert_eq!(blocks_136[2], Fp::from(0x00)); + assert_eq!(blocks_136[3], Fp::from(0x00)); + assert_eq!(blocks_136[4], Fp::from(0x80)); +} + +#[test] +fn test_is_in_table() { + let table_pad = LookupTable::table_pad(); + let table_round_constants = LookupTable::table_round_constants(); + let table_byte = LookupTable::table_byte(); + let table_range_check_16 = LookupTable::table_range_check_16(); + let table_sparse = LookupTable::table_sparse(); + let table_reset = LookupTable::table_reset(); + // PadLookup + assert!(LookupTable::is_in_table( + &table_pad, + vec![ + Fp::one(), // Length of padding + Fp::two_pow(1), // 2^length of padding + Fp::zero(), // Most significant chunk of padding suffix + Fp::zero(), + Fp::zero(), + Fp::zero(), + 
Fp::from(0x81) // Least significant chunk of padding suffix + ] + ) + .is_some()); + assert!(LookupTable::is_in_table( + &table_pad, + vec![ + Fp::from(136), // Length of padding + Fp::two_pow(136), // 2^length of padding + Fp::from(0x010000000000000000000000u128), // Most significant chunk of padding suffix + Fp::zero(), + Fp::zero(), + Fp::zero(), + Fp::from(0x80) // Least significant chunk of padding suffix + ] + ) + .is_some()); + assert!(LookupTable::is_in_table(&table_pad, vec![Fp::from(137u32)]).is_none()); + // RoundConstantsLookup + assert!(LookupTable::is_in_table( + &table_round_constants, + vec![ + Fp::zero(), // Round index + Fp::zero(), // Most significant quarter of round constant + Fp::zero(), + Fp::zero(), + Fp::one() // Least significant quarter of round constant + ] + ) + .is_some()); + assert!(LookupTable::is_in_table( + &table_round_constants, + vec![ + Fp::from(23), // Round index + Fp::from(Keccak::sparse(0x8000)[0]), // Most significant quarter of round constant + Fp::from(Keccak::sparse(0x0000)[0]), + Fp::from(Keccak::sparse(0x8000)[0]), + Fp::from(Keccak::sparse(0x8008)[0]), // Least significant quarter of round constant + ] + ) + .is_some()); + assert!(LookupTable::is_in_table(&table_round_constants, vec![Fp::from(24u32)]).is_none()); + // ByteLookup + assert!(LookupTable::is_in_table(&table_byte, vec![Fp::zero()]).is_some()); + assert!(LookupTable::is_in_table(&table_byte, vec![Fp::from(255u32)]).is_some()); + assert!(LookupTable::is_in_table(&table_byte, vec![Fp::from(256u32)]).is_none()); + // RangeCheck16Lookup + assert!(LookupTable::is_in_table(&table_range_check_16, vec![Fp::zero()]).is_some()); + assert!( + LookupTable::is_in_table(&table_range_check_16, vec![Fp::from((1 << 16) - 1)]).is_some() + ); + assert!(LookupTable::is_in_table(&table_range_check_16, vec![Fp::from(1 << 16)]).is_none()); + // SparseLookup + assert!(LookupTable::is_in_table(&table_sparse, vec![Fp::zero()]).is_some()); + assert!(LookupTable::is_in_table( + &table_sparse, + vec![Fp::from(Keccak::sparse((1 << 16) - 1)[3])] + ) + .is_some()); + assert!(LookupTable::is_in_table(&table_sparse, vec![Fp::two()]).is_none()); + // ResetLookup + assert!(LookupTable::is_in_table(&table_reset, vec![Fp::zero(), Fp::zero()]).is_some()); + assert!(LookupTable::is_in_table( + &table_reset, + vec![ + Fp::from((1 << 16) - 1), + Fp::from(Keccak::sparse(((1u128 << 64) - 1) as u64)[3]) + ] + ) + .is_some()); + assert!(LookupTable::is_in_table(&table_reset, vec![Fp::from(1 << 16)]).is_none()); +} + +#[test] +fn test_keccak_witness_satisfies_constraints() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Generate random bytelength and preimage for Keccak + let bytelength = rng.gen_range(1..1000); + let preimage: Vec = (0..bytelength).map(|_| rng.gen()).collect(); + // Use an external library to compute the hash + let mut hasher = Keccak256::new(); + hasher.update(&preimage); + let hash = hasher.finalize(); + + // Initialize the environment and run the interpreter + let mut keccak_env = KeccakEnv::::new(0, &preimage); + while keccak_env.step.is_some() { + let step = keccak_env.step.unwrap(); + keccak_env.step(); + // Simulate the constraints for each row + keccak_env.witness_env.constraints(step); + assert!(keccak_env.witness_env.errors.is_empty()); + } + // Extract the hash from the witness + let output = keccak_env.witness_env.sponge_bytes()[0..32] + .iter() + .map(|byte| byte.to_bytes()[0]) + .collect::>(); + + // Check that the hash matches + for (i, byte) in output.iter().enumerate() { + 
assert_eq!(*byte, hash[i]); + } +} + +#[test] +fn test_regression_number_of_lookups_and_constraints_and_degree() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Generate random bytelength and preimage for Keccak of 1, 2 or 3 blocks + // so that there can be both First, Middle, Last and Only absorbs + let bytelength = rng.gen_range(1..400); + let preimage: Vec = (0..bytelength).map(|_| rng.gen()).collect(); + + let mut keccak_env = KeccakEnv::::new(0, &preimage); + + // Execute the interpreter to obtain constraints for each step + while keccak_env.step.is_some() { + // Current step to be executed + let step = keccak_env.step.unwrap(); + + // Push constraints for the current step + keccak_env.constraints_env.constraints(step); + // Push lookups for the current step + keccak_env.constraints_env.lookups(step); + + // Checking relation constraints for each step selector + let mut constraint_degrees: HashMap = HashMap::new(); + keccak_env + .constraints_env + .constraints + .iter() + .for_each(|constraint| { + let degree = constraint.degree(1, 0); + let entry = constraint_degrees.entry(degree).or_insert(0); + *entry += 1; + }); + + // Check that the number of constraints is correct for that step type + // Check that the degrees of the constraints are correct + // Checking lookup constraints + + match step { + Sponge(Absorb(First)) => { + assert_eq!(keccak_env.constraints_env.lookups.len(), 537); + assert_eq!(keccak_env.constraints_env.constraints.len(), 332); + // We have 1 different degrees of constraints in Absorbs::First + assert_eq!(constraint_degrees.len(), 1); + // 332 degree-1 constraints + assert_eq!(constraint_degrees[&1], 332); + } + Sponge(Absorb(Middle)) => { + assert_eq!(keccak_env.constraints_env.lookups.len(), 538); + assert_eq!(keccak_env.constraints_env.constraints.len(), 232); + // We have 1 different degrees of constraints in Absorbs::Middle + assert_eq!(constraint_degrees.len(), 1); + // 232 degree-1 constraints + assert_eq!(constraint_degrees[&1], 232); + } + Sponge(Absorb(Last)) => { + assert_eq!(keccak_env.constraints_env.lookups.len(), 539); + assert_eq!(keccak_env.constraints_env.constraints.len(), 374); + // We have 2 different degrees of constraints in Squeeze + assert_eq!(constraint_degrees.len(), 2); + // 233 degree-1 constraints + assert_eq!(constraint_degrees[&1], 233); + // 136 degree-2 constraints + assert_eq!(constraint_degrees[&2], 141); + } + Sponge(Absorb(Only)) => { + assert_eq!(keccak_env.constraints_env.lookups.len(), 538); + assert_eq!(keccak_env.constraints_env.constraints.len(), 474); + // We have 2 different degrees of constraints in Squeeze + assert_eq!(constraint_degrees.len(), 2); + // 333 degree-1 constraints + assert_eq!(constraint_degrees[&1], 333); + // 136 degree-2 constraints + assert_eq!(constraint_degrees[&2], 141); + } + Sponge(Squeeze) => { + assert_eq!(keccak_env.constraints_env.lookups.len(), 602); + assert_eq!(keccak_env.constraints_env.constraints.len(), 16); + // We have 1 different degrees of constraints in Squeeze + assert_eq!(constraint_degrees.len(), 1); + // 16 degree-1 constraints + assert_eq!(constraint_degrees[&1], 16); + } + Round(_) => { + assert_eq!(keccak_env.constraints_env.lookups.len(), 1623); + assert_eq!(keccak_env.constraints_env.constraints.len(), 389); + // We have 2 different degrees of constraints in Round + assert_eq!(constraint_degrees.len(), 2); + // 384 degree-1 constraints + assert_eq!(constraint_degrees[&1], 384); + // 5 degree-2 constraints + assert_eq!(constraint_degrees[&2], 5); + } + } + 
// Execute the step, updating the witness
+        // (it does not need to happen before the constraints, since we are not
+        // checking the witness here).
+        // This also updates the step for the next iteration.
+        keccak_env.step();
+    }
+}
+
+#[test]
+fn test_keccak_witness_satisfies_lookups() {
+    let mut rng = o1_utils::tests::make_test_rng(None);
+
+    // Generate random preimage of 1 block for Keccak
+    let preimage: Vec<u8> = (0..100).map(|_| rng.gen()).collect();
+
+    // Initialize the environment and run the interpreter
+    let mut keccak_env = KeccakEnv::<Fp>::new(0, &preimage);
+    while keccak_env.step.is_some() {
+        let step = keccak_env.step.unwrap();
+        keccak_env.step();
+        keccak_env.witness_env.lookups(step);
+        assert!(keccak_env.witness_env.errors.is_empty());
+    }
+}
+
+#[test]
+fn test_keccak_fake_witness_wont_satisfy_constraints() {
+    let mut rng = o1_utils::tests::make_test_rng(None);
+
+    // Generate random preimage of 1 block for Keccak
+    let preimage: Vec<u8> = (0..100).map(|_| rng.gen()).collect();
+
+    // Initialize witness for
+    // - 1 absorb
+    // - 24 rounds
+    // - 1 squeeze
+    let n_steps = 26;
+    let mut witness_env = Vec::with_capacity(n_steps);
+
+    // Initialize the environment
+    let mut keccak_env = KeccakEnv::<Fp>::new(0, &preimage);
+
+    // Run the interpreter and keep track of the witness
+    while keccak_env.step.is_some() {
+        let step = keccak_env.step.unwrap();
+        keccak_env.step();
+        // Store a copy of the witness to be altered later
+        witness_env.push(keccak_env.witness_env.clone());
+        // Make sure that the constraints of that row hold
+        keccak_env.witness_env.constraints(step);
+        assert!(keccak_env.witness_env.errors.is_empty());
+    }
+    assert_eq!(witness_env.len(), n_steps);
+
+    // NEGATIVIZE THE WITNESS
+
+    // Break padding constraints
+    let step = Sponge(Absorb(Only));
+    assert_eq!(witness_env[0].is_pad(step), Fp::one());
+    // Padding can only occur in suffix[3] and suffix[4] because the length is 100 bytes
+    assert_eq!(witness_env[0].pad_suffix(0), Fp::zero());
+    assert_eq!(witness_env[0].pad_suffix(1), Fp::zero());
+    assert_eq!(witness_env[0].pad_suffix(2), Fp::zero());
+    // Check that the padding blocks are correct
+    assert_eq!(witness_env[0].block_in_padding(0), Fp::zero());
+    assert_eq!(witness_env[0].block_in_padding(1), Fp::zero());
+    assert_eq!(witness_env[0].block_in_padding(2), Fp::zero());
+    // Force claim pad in PadBytesFlags(0), involved in suffix(0)
+    assert_eq!(
+        witness_env[0].witness[KeccakColumn::PadBytesFlags(0)],
+        Fp::zero()
+    );
+    witness_env[0].witness[KeccakColumn::PadBytesFlags(0)] = Fp::from(1u32);
+    // Now that PadBytesFlags(0) is 1, block_in_padding(0) should be 0b10*
+    witness_env[0].constrain_padding(step);
+    // When byte(0) is different from 0, the padding suffix constraint also fails
+    if witness_env[0].sponge_bytes()[0] != Fp::zero() {
+        assert_eq!(
+            witness_env[0].errors,
+            vec![
+                Error::Constraint(PadAtEnd),
+                Error::Constraint(PaddingSuffix(0))
+            ]
+        );
+    } else {
+        assert_eq!(witness_env[0].errors, vec![Error::Constraint(PadAtEnd)]);
+    }
+
+    witness_env[0].errors.clear();
+
+    // Break booleanity constraints
+    witness_env[0].witness[KeccakColumn::PadBytesFlags(0)] = Fp::from(2u32);
+    witness_env[0].constrain_booleanity(step);
+    assert_eq!(
+        witness_env[0].errors,
+        vec![Error::Constraint(BooleanityPadding(0))]
+    );
+    witness_env[0].errors.clear();
+
+    // Break absorb constraints
+    witness_env[0].witness[KeccakColumn::Input(68)] += Fp::from(1u32);
+    witness_env[0].witness[KeccakColumn::SpongeNewState(68)] += Fp::from(1u32);
+    witness_env[0].witness[KeccakColumn::Output(68)] +=
Fp::from(1u32); + witness_env[0].constrain_absorb(step); + assert_eq!( + witness_env[0].errors, + vec![ + Error::Constraint(AbsorbZeroPad(0)), // 68th SpongeNewState is the 0th SpongeZeros + Error::Constraint(AbsorbRootZero(68)), + Error::Constraint(AbsorbXor(68)), + Error::Constraint(AbsorbShifts(68)), + ] + ); + witness_env[0].errors.clear(); + + // Break squeeze constraints + let step = Sponge(Squeeze); + witness_env[25].witness[KeccakColumn::Input(0)] += Fp::from(1u32); + witness_env[25].constrain_squeeze(step); + assert_eq!( + witness_env[25].errors, + vec![Error::Constraint(SqueezeShifts(0))] + ); + witness_env[25].errors.clear(); + + // Break theta constraints + let step = Round(0); + witness_env[1].witness[KeccakColumn::ThetaQuotientC(0)] += Fp::from(2u32); + witness_env[1].witness[KeccakColumn::ThetaShiftsC(0)] += Fp::from(1u32); + witness_env[1].constrain_theta(step); + assert_eq!( + witness_env[1].errors, + vec![ + Error::Constraint(ThetaWordC(0)), + Error::Constraint(ThetaRotatedC(0)), + Error::Constraint(ThetaQuotientC(0)), + Error::Constraint(ThetaShiftsC(0, 0)) + ] + ); + witness_env[1].errors.clear(); + witness_env[1].witness[KeccakColumn::ThetaQuotientC(0)] -= Fp::from(2u32); + witness_env[1].witness[KeccakColumn::ThetaShiftsC(0)] -= Fp::from(1u32); + let state_e = witness_env[1].constrain_theta(step); + assert!(witness_env[1].errors.is_empty()); + + // Break pi-rho constraints + witness_env[1].witness[KeccakColumn::PiRhoRemainderE(0)] += Fp::from(1u32); + witness_env[1].witness[KeccakColumn::PiRhoShiftsE(0)] += Fp::from(1u32); + witness_env[1].constrain_pirho(step, state_e.clone()); + assert_eq!( + witness_env[1].errors, + vec![ + Error::Constraint(PiRhoWordE(0, 0)), + Error::Constraint(PiRhoRotatedE(0, 0)), + Error::Constraint(PiRhoShiftsE(0, 0, 0)), + ] + ); + witness_env[1].errors.clear(); + witness_env[1].witness[KeccakColumn::PiRhoRemainderE(0)] -= Fp::from(1u32); + witness_env[1].witness[KeccakColumn::PiRhoShiftsE(0)] -= Fp::from(1u32); + let state_b = witness_env[1].constrain_pirho(step, state_e); + assert!(witness_env[1].errors.is_empty()); + + // Break chi constraints + witness_env[1].witness[KeccakColumn::ChiShiftsB(0)] += Fp::from(1u32); + witness_env[1].witness[KeccakColumn::ChiShiftsSum(0)] += Fp::from(1u32); + witness_env[1].constrain_chi(step, state_b.clone()); + assert_eq!( + witness_env[1].errors, + vec![ + Error::Constraint(ChiShiftsB(0, 0, 0)), + Error::Constraint(ChiShiftsSum(0, 0, 0)), + Error::Constraint(ChiShiftsSum(0, 3, 0)), + Error::Constraint(ChiShiftsSum(0, 4, 0)), + ] + ); + witness_env[1].errors.clear(); + witness_env[1].witness[KeccakColumn::ChiShiftsB(0)] -= Fp::from(1u32); + witness_env[1].witness[KeccakColumn::ChiShiftsSum(0)] -= Fp::from(1u32); + let state_f = witness_env[1].constrain_chi(step, state_b); + assert!(witness_env[1].errors.is_empty()); + + // Break iota constraints + witness_env[1].witness[KeccakColumn::Output(0)] += Fp::from(1u32); + witness_env[1].constrain_iota(step, state_f); + assert_eq!( + witness_env[1].errors, + vec![Error::Constraint(IotaStateG(0))] + ); + witness_env[1].errors.clear(); +} + +#[test] +fn test_keccak_multiplicities() { + let mut rng = o1_utils::tests::make_test_rng(None); + + // Generate random preimage of 1 block for Keccak, which will need a second full block for padding + let preimage: Vec = (0..136).map(|_| rng.gen()).collect(); + + // Initialize witness for + // - 1 root absorb + // - 24 rounds + // - 1 pad absorb + // - 24 rounds + // - 1 squeeze + let n_steps = 51; + let mut witness_env = 
Vec::with_capacity(n_steps);
+
+    // Run the interpreter and keep track of the witness
+    let mut keccak_env = KeccakEnv::<Fp>::new(0, &preimage);
+    while keccak_env.step.is_some() {
+        let step = keccak_env.step.unwrap();
+        keccak_env.step();
+        keccak_env.witness_env.lookups(step);
+        // Store a copy of the witness
+        witness_env.push(keccak_env.witness_env.clone());
+    }
+    assert_eq!(witness_env.len(), n_steps);
+
+    // Check multiplicities of the padding suffixes
+    assert_eq!(
+        witness_env[25].multiplicities.get_mut(&PadLookup).unwrap()[135],
+        1
+    );
+    // Check multiplicities of the round constants of Round 0
+    assert_eq!(
+        witness_env[26]
+            .multiplicities
+            .get_mut(&RoundConstantsLookup)
+            .unwrap()[0],
+        2
+    );
+}
diff --git a/o1vm/src/interpreters/keccak/witness.rs b/o1vm/src/interpreters/keccak/witness.rs
new file mode 100644
index 0000000000..e8ba6ba0fb
--- /dev/null
+++ b/o1vm/src/interpreters/keccak/witness.rs
@@ -0,0 +1,124 @@
+//! This file contains the witness for the Keccak hash function for the zkVM project.
+//! It assigns the witness values to the corresponding columns of KeccakWitness in the environment.
+//!
+//! The actual witness generation code makes use of the code which is already present in Kimchi,
+//! to avoid code duplication and reduce error-proneness.
+//!
+//! For a pseudo code implementation of Keccak-f, see
+//! <https://keccak.team/keccak_specs_summary.html>
+
+use std::collections::HashMap;
+
+use crate::{
+    interpreters::keccak::{
+        column::KeccakWitness,
+        helpers::{ArithHelpers, BoolHelpers, LogupHelpers},
+        interpreter::{Interpreter, KeccakInterpreter},
+        Constraint, Error, KeccakColumn,
+    },
+    lookups::{
+        FixedLookupTables, Lookup, LookupTable,
+        LookupTableIDs::{self, *},
+    },
+};
+
+use ark_ff::Field;
+use kimchi::o1_utils::Two;
+use kimchi_msm::LookupTableID;
+
+/// This struct contains all that needs to be kept track of during the execution of the Keccak step interpreter
+// TODO: the fixed tables information should be inferred from the general environment
+#[derive(Clone, Debug)]
+pub struct Env<F> {
+    /// The full state of the Keccak gate (witness)
+    pub witness: KeccakWitness<F>,
+    /// The fixed tables used in the Keccak gate
+    pub tables: HashMap<LookupTableIDs, LookupTable<F>>,
+    /// The multiplicities of each lookup entry. Should not be cleared between steps.
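+    /// (For instance, in `test_keccak_multiplicities`, a 136-byte preimage is
+    /// padded with a full 136-byte block, bumping the `PadLookup` entry at
+    /// index 135 exactly once.)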
+    pub multiplicities: HashMap<LookupTableIDs, Vec<u64>>,
+    /// If any, errors that occurred during the execution of the constraints, to help with debugging
+    pub(crate) errors: Vec<Error>,
+}
+
+impl<F: Field> Default for Env<F> {
+    fn default() -> Self {
+        Self {
+            witness: KeccakWitness::default(),
+            tables: {
+                let mut t = HashMap::new();
+                t.insert(PadLookup, LookupTable::table_pad());
+                t.insert(RoundConstantsLookup, LookupTable::table_round_constants());
+                t.insert(AtMost4Lookup, LookupTable::table_at_most_4());
+                t.insert(ByteLookup, LookupTable::table_byte());
+                t.insert(RangeCheck16Lookup, LookupTable::table_range_check_16());
+                t.insert(SparseLookup, LookupTable::table_sparse());
+                t.insert(ResetLookup, LookupTable::table_reset());
+                t
+            },
+            multiplicities: {
+                let mut m = HashMap::new();
+                m.insert(PadLookup, vec![0; PadLookup.length()]);
+                m.insert(RoundConstantsLookup, vec![0; RoundConstantsLookup.length()]);
+                m.insert(AtMost4Lookup, vec![0; AtMost4Lookup.length()]);
+                m.insert(ByteLookup, vec![0; ByteLookup.length()]);
+                m.insert(RangeCheck16Lookup, vec![0; RangeCheck16Lookup.length()]);
+                m.insert(SparseLookup, vec![0; SparseLookup.length()]);
+                m.insert(ResetLookup, vec![0; ResetLookup.length()]);
+                m
+            },
+            errors: vec![],
+        }
+    }
+}
+
+impl<F: Field> ArithHelpers<F> for Env<F> {
+    fn two_pow(x: u64) -> <Env<F> as Interpreter<F>>::Variable {
+        Self::constant_field(F::two_pow(x))
+    }
+}
+
+impl<F: Field> BoolHelpers<F> for Env<F> {}
+
+impl<F: Field> LogupHelpers<F> for Env<F> {}
+
+impl<F: Field> Interpreter<F> for Env<F> {
+    type Variable = F;
+
+    fn constant(x: u64) -> Self::Variable {
+        Self::constant_field(F::from(x))
+    }
+
+    fn constant_field(x: F) -> Self::Variable {
+        x
+    }
+
+    fn variable(&self, column: KeccakColumn) -> Self::Variable {
+        self.witness[column]
+    }
+
+    /// Checks the constraint `tag` by checking that the input `x` is zero
+    fn constrain(&mut self, tag: Constraint, if_true: Self::Variable, x: Self::Variable) {
+        if if_true == Self::Variable::one() && x != F::zero() {
+            self.errors.push(Error::Constraint(tag));
+        }
+    }
+
+    fn add_lookup(&mut self, if_true: Self::Variable, lookup: Lookup<Self::Variable>) {
+        // Keep track of multiplicities for fixed lookups
+        if if_true == Self::Variable::one() && lookup.table_id.is_fixed() {
+            // Only when reading. We ignore the other values.
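+            // Here, we take a `magnitude` of 1 to mean a single read of the
+            // table entry; other magnitudes (e.g. write-side lookups) do not
+            // contribute to the fixed-table multiplicities.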
+            if lookup.magnitude == Self::one() {
+                // Check that the lookup value is in the table
+                if let Some(idx) = LookupTable::is_in_table(
+                    self.tables.get_mut(&lookup.table_id).unwrap(),
+                    lookup.value,
+                ) {
+                    self.multiplicities.get_mut(&lookup.table_id).unwrap()[idx] += 1;
+                } else {
+                    self.errors.push(Error::Lookup(lookup.table_id));
+                }
+            }
+        }
+    }
+}
+
+impl<F: Field> KeccakInterpreter<F> for Env<F> {}
diff --git a/o1vm/src/interpreters/mips/column.rs b/o1vm/src/interpreters/mips/column.rs
new file mode 100644
index 0000000000..cd895198b7
--- /dev/null
+++ b/o1vm/src/interpreters/mips/column.rs
@@ -0,0 +1,206 @@
+use crate::interpreters::mips::Instruction::{self, IType, JType, RType};
+use kimchi_msm::{
+    columns::{Column, ColumnIndexer},
+    witness::Witness,
+};
+use std::ops::{Index, IndexMut};
+use strum::EnumCount;
+
+use super::{ITypeInstruction, JTypeInstruction, RTypeInstruction};
+
+pub(crate) const SCRATCH_SIZE_WITHOUT_KECCAK: usize = 45;
+/// The number of hashes performed so far in the block
+pub(crate) const MIPS_HASH_COUNTER_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK;
+/// The number of bytes of the preimage that have been read so far in this hash
+pub(crate) const MIPS_BYTE_COUNTER_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 1;
+/// A flag indicating whether the preimage has been fully read or not
+pub(crate) const MIPS_END_OF_PREIMAGE_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 2;
+/// The number of preimage bytes processed in this step
+pub(crate) const MIPS_NUM_BYTES_READ_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 3;
+/// The at most 4-byte chunk of the preimage that has been read in this step.
+/// Contains a field element of at most 4 bytes.
+pub(crate) const MIPS_PREIMAGE_CHUNK_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 4;
+/// The at most 4 bytes of the preimage that are currently being processed.
+/// Consists of 4 field elements of at most 1 byte each.
+pub(crate) const MIPS_PREIMAGE_BYTES_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 5;
+/// The at most 4 bytes of the length that are currently being processed
+pub(crate) const MIPS_LENGTH_BYTES_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 5 + 4;
+/// Flags indicating whether at least N bytes have been processed in this step
+pub(crate) const MIPS_HAS_N_BYTES_OFF: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 5 + 4 + 4;
+/// The maximum size of a chunk (4 bytes)
+pub(crate) const MIPS_CHUNK_BYTES_LEN: usize = 4;
+/// The location of the preimage key as a field element of 248 bits
+pub(crate) const MIPS_PREIMAGE_KEY: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 5 + 4 + 4 + 4;
+
+// MIPS + hash_counter + byte_counter + eof + num_bytes_read + chunk + bytes
+// + length + has_n_bytes + chunk_bytes + preimage
+pub const SCRATCH_SIZE: usize = SCRATCH_SIZE_WITHOUT_KECCAK + 5 + 4 + 4 + 4 + 1;
+
+/// Number of columns used by the MIPS interpreter to keep values to be
+/// inverted.
+pub const SCRATCH_SIZE_INVERSE: usize = 12;
+
+/// The number of columns used for relation witness in the MIPS circuit
+pub const N_MIPS_REL_COLS: usize = SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2;
+
+/// The number of witness columns used to store the instruction selectors.
+pub const N_MIPS_SEL_COLS: usize =
+    RTypeInstruction::COUNT + JTypeInstruction::COUNT + ITypeInstruction::COUNT;
+
+/// All the witness columns used in MIPS
+pub const N_MIPS_COLS: usize = N_MIPS_REL_COLS + N_MIPS_SEL_COLS;
+
+/// Abstract columns (or variables of our multi-variate polynomials) that will
+/// be used to describe our constraints.
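+/// An illustrative mapping (see `From<ColumnAlias> for usize` below):
+/// ```ignore
+/// assert_eq!(usize::from(ColumnAlias::ScratchState(2)), 2);
+/// assert_eq!(usize::from(ColumnAlias::ScratchStateInverse(3)), SCRATCH_SIZE + 3);
+/// ```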
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub enum ColumnAlias {
+    // Can be seen as the abstract indexed variable X_{i}
+    ScratchState(usize),
+    // A column whose value needs to be inverted in the final witness.
+    // We're keeping a separate column to perform a batch inversion at the end.
+    ScratchStateInverse(usize),
+    InstructionCounter,
+    Selector(usize),
+}
+
+/// The columns used by the MIPS circuit. The MIPS circuit is split into three
+/// main opcodes: RType, JType, IType. The columns are shared between different
+/// instruction types. (The total number of columns refers to the maximum number
+/// of columns used by each mode.)
+impl From<ColumnAlias> for usize {
+    fn from(alias: ColumnAlias) -> usize {
+        // Note that SCRATCH_SIZE + 1 is for the error
+        match alias {
+            ColumnAlias::ScratchState(i) => {
+                assert!(i < SCRATCH_SIZE);
+                i
+            }
+            ColumnAlias::ScratchStateInverse(i) => {
+                assert!(i < SCRATCH_SIZE_INVERSE);
+                SCRATCH_SIZE + i
+            }
+            ColumnAlias::InstructionCounter => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE,
+            ColumnAlias::Selector(s) => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + s,
+        }
+    }
+}
+
+/// Returns the index of the corresponding DynamicSelector column.
+impl From<Instruction> for usize {
+    fn from(instr: Instruction) -> usize {
+        match instr {
+            RType(rtype) => N_MIPS_REL_COLS + rtype as usize,
+            JType(jtype) => N_MIPS_REL_COLS + RTypeInstruction::COUNT + jtype as usize,
+            IType(itype) => {
+                N_MIPS_REL_COLS + RTypeInstruction::COUNT + JTypeInstruction::COUNT + itype as usize
+            }
+        }
+    }
+}
+
+/// Represents one line of the execution trace of the virtual machine.
+/// It contains:
+/// + [SCRATCH_SIZE] columns
+/// + 1 column to keep track of the instruction index
+/// + 1 column for the system error code
+/// + [N_MIPS_SEL_COLS] columns for the instruction selectors.
+/// The columns are, in order:
+/// - the 32 general purpose registers
+/// - the lo and hi registers used by some arithmetic instructions
+/// - the current instruction pointer
+/// - the next instruction pointer
+/// - the heap pointer
+/// - the preimage key, split in 8 consecutive columns representing 4 bytes
+///   of the 32 bytes long preimage key
+/// - the preimage offset, i.e. the number of bytes that have been read for the
+///   currently processing preimage
+/// - `[SCRATCH_SIZE] - 46` intermediate columns that can be used by the
+///   instruction set
+/// - the hash counter
+/// - the flag to indicate if the current instruction is a preimage syscall
+/// - the flag to indicate if the current instruction is reading a preimage
+/// - the number of bytes read so far for the current preimage
+/// - how many bytes are left to be read for the current preimage
+/// - the (at most) 4 bytes of the preimage key that are currently being
+///   processed
+/// - 4 helpers to check if at least n bytes were read in the current row
pub type MIPSWitness<T> = Witness<N_MIPS_COLS, T>;
+
+// IMPLEMENTATIONS FOR COLUMN ALIAS
+
+impl<T: Clone> Index<ColumnAlias> for MIPSWitness<T> {
+    type Output = T;
+
+    /// Map the column alias to the actual column index.
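+    /// For instance, `witness[ColumnAlias::InstructionCounter]` reads
+    /// `cols[SCRATCH_SIZE + SCRATCH_SIZE_INVERSE]`.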
+ fn index(&self, index: ColumnAlias) -> &Self::Output { + &self.cols[usize::from(index)] + } +} + +impl IndexMut for MIPSWitness { + fn index_mut(&mut self, index: ColumnAlias) -> &mut Self::Output { + &mut self.cols[usize::from(index)] + } +} + +impl ColumnIndexer for ColumnAlias { + const N_COL: usize = N_MIPS_COLS; + + fn to_column(self) -> Column { + match self { + Self::ScratchState(ss) => { + assert!( + ss < SCRATCH_SIZE, + "The maximum index is {}, got {}", + SCRATCH_SIZE, + ss + ); + Column::Relation(ss) + } + Self::ScratchStateInverse(ss) => { + assert!( + ss < SCRATCH_SIZE_INVERSE, + "The maximum index is {}, got {}", + SCRATCH_SIZE_INVERSE, + ss + ); + Column::Relation(SCRATCH_SIZE + ss) + } + Self::InstructionCounter => Column::Relation(SCRATCH_SIZE + SCRATCH_SIZE_INVERSE), + // TODO: what happens with error? It does not have a corresponding alias + Self::Selector(s) => { + assert!( + s < N_MIPS_SEL_COLS, + "The maximum index is {}, got {}", + N_MIPS_SEL_COLS, + s + ); + Column::DynamicSelector(s) + } + } + } +} + +// IMPLEMENTATIONS FOR SELECTOR + +impl Index for MIPSWitness { + type Output = T; + + /// Map the column alias to the actual column index. + fn index(&self, index: Instruction) -> &Self::Output { + &self.cols[usize::from(index)] + } +} + +impl IndexMut for MIPSWitness { + fn index_mut(&mut self, index: Instruction) -> &mut Self::Output { + &mut self.cols[usize::from(index)] + } +} + +impl ColumnIndexer for Instruction { + const N_COL: usize = N_MIPS_REL_COLS + N_MIPS_SEL_COLS; + fn to_column(self) -> Column { + Column::DynamicSelector(usize::from(self) - N_MIPS_REL_COLS) + } +} diff --git a/o1vm/src/interpreters/mips/constraints.rs b/o1vm/src/interpreters/mips/constraints.rs new file mode 100644 index 0000000000..279cb0a6c1 --- /dev/null +++ b/o1vm/src/interpreters/mips/constraints.rs @@ -0,0 +1,675 @@ +use crate::{ + interpreters::mips::{ + column::{ + ColumnAlias as MIPSColumn, MIPS_BYTE_COUNTER_OFF, MIPS_CHUNK_BYTES_LEN, + MIPS_END_OF_PREIMAGE_OFF, MIPS_HASH_COUNTER_OFF, MIPS_HAS_N_BYTES_OFF, + MIPS_LENGTH_BYTES_OFF, MIPS_NUM_BYTES_READ_OFF, MIPS_PREIMAGE_BYTES_OFF, + MIPS_PREIMAGE_CHUNK_OFF, MIPS_PREIMAGE_KEY, N_MIPS_REL_COLS, + }, + interpreter::InterpreterEnv, + Instruction, + }, + lookups::{Lookup, LookupTableIDs}, + E, +}; +use ark_ff::{Field, One}; +use kimchi::circuits::{ + expr::{ConstantTerm::Literal, Expr, ExprInner, Operations, Variable}, + gate::CurrOrNext, +}; +use kimchi_msm::columns::ColumnIndexer as _; +use std::array; + +use super::column::N_MIPS_SEL_COLS; + +/// The environment keeping the constraints between the different polynomials +pub struct Env { + scratch_state_idx: usize, + scratch_state_idx_inverse: usize, + /// A list of constraints, which are multi-variate polynomials over a field, + /// represented using the expression framework of `kimchi`. + constraints: Vec>, + lookups: Vec>>, + /// Selector (as expression) for the constraints of the environment. + selector: Option>, +} + +impl Default for Env { + fn default() -> Self { + Self { + scratch_state_idx: 0, + scratch_state_idx_inverse: 0, + constraints: Vec::new(), + lookups: Vec::new(), + selector: None, + } + } +} + +impl InterpreterEnv for Env { + /// In the concrete implementation for the constraints, the interpreter will + /// work over columns. The position in this case can be seen as a new + /// variable/input of our circuit. + type Position = MIPSColumn; + + // Use one of the available columns. It won't create a new column every time + // this function is called. 
The number of columns is defined upfront by + // crate::interpreters::mips::column::SCRATCH_SIZE. + fn alloc_scratch(&mut self) -> Self::Position { + // All columns are implemented using a simple index, and a name is given + // to the index. See crate::interpreters::mips::column::SCRATCH_SIZE for the maximum number of + // columns the circuit can use. + let scratch_idx = self.scratch_state_idx; + self.scratch_state_idx += 1; + MIPSColumn::ScratchState(scratch_idx) + } + + fn alloc_scratch_inverse(&mut self) -> Self::Position { + let scratch_idx = self.scratch_state_idx_inverse; + self.scratch_state_idx_inverse += 1; + MIPSColumn::ScratchStateInverse(scratch_idx) + } + + type Variable = E; + + fn variable(&self, column: Self::Position) -> Self::Variable { + Expr::Atom(ExprInner::Cell(Variable { + col: column.to_column(), + row: CurrOrNext::Curr, + })) + } + + fn activate_selector(&mut self, selector: Instruction) { + // Sanity check: we only want to activate once per instruction + assert!(self.selector.is_none(), "A selector has been already activated. You might need to reset the environment if you want to start a new instruction."); + let n = usize::from(selector) - N_MIPS_REL_COLS; + self.selector = Some(self.variable(MIPSColumn::Selector(n))) + } + + fn add_constraint(&mut self, assert_equals_zero: Self::Variable) { + self.constraints.push(assert_equals_zero) + } + + fn check_is_zero(_assert_equals_zero: &Self::Variable) { + // No-op, witness only + } + + fn check_equal(_x: &Self::Variable, _y: &Self::Variable) { + // No-op, witness only + } + + fn check_boolean(_x: &Self::Variable) { + // No-op, witness only + } + + fn add_lookup(&mut self, lookup: Lookup) { + self.lookups.push(lookup); + } + + fn instruction_counter(&self) -> Self::Variable { + self.variable(MIPSColumn::InstructionCounter) + } + + fn increase_instruction_counter(&mut self) { + // No-op, witness only + } + + unsafe fn fetch_register( + &mut self, + _idx: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + self.variable(output) + } + + unsafe fn push_register_if( + &mut self, + _idx: &Self::Variable, + _value: Self::Variable, + _if_is_true: &Self::Variable, + ) { + // No-op, witness only + } + + unsafe fn fetch_register_access( + &mut self, + _idx: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + self.variable(output) + } + + unsafe fn push_register_access_if( + &mut self, + _idx: &Self::Variable, + _value: Self::Variable, + _if_is_true: &Self::Variable, + ) { + // No-op, witness only + } + + unsafe fn fetch_memory( + &mut self, + _addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + self.variable(output) + } + + unsafe fn push_memory(&mut self, _addr: &Self::Variable, _value: Self::Variable) { + // No-op, witness only + } + + unsafe fn fetch_memory_access( + &mut self, + _addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + self.variable(output) + } + + unsafe fn push_memory_access(&mut self, _addr: &Self::Variable, _value: Self::Variable) { + // No-op, witness only + } + + fn constant(x: u32) -> Self::Variable { + Self::Variable::constant(Operations::from(Literal(Fp::from(x)))) + } + + unsafe fn bitmask( + &mut self, + _x: &Self::Variable, + _highest_bit: u32, + _lowest_bit: u32, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn shift_left( + &mut self, + _x: &Self::Variable, + _by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn 
shift_right( + &mut self, + _x: &Self::Variable, + _by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn shift_right_arithmetic( + &mut self, + _x: &Self::Variable, + _by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn test_zero( + &mut self, + _x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable { + let res = { + let pos = self.alloc_scratch(); + unsafe { self.test_zero(x, pos) } + }; + let x_inv_or_zero = { + let pos = self.alloc_scratch_inverse(); + self.variable(pos) + }; + // If x = 0, then res = 1 and x_inv_or_zero = 0 + // If x <> 0, then res = 0 and x_inv_or_zero = x^(-1) + self.add_constraint(x.clone() * x_inv_or_zero.clone() + res.clone() - Self::constant(1)); + self.add_constraint(x.clone() * res.clone()); + res + } + + fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable { + self.is_zero(&(x.clone() - y.clone())) + } + + unsafe fn test_less_than( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn test_less_than_signed( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn and_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn nor_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn or_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn xor_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn add_witness( + &mut self, + _y: &Self::Variable, + _x: &Self::Variable, + out_position: Self::Position, + overflow_position: Self::Position, + ) -> (Self::Variable, Self::Variable) { + ( + self.variable(out_position), + self.variable(overflow_position), + ) + } + + unsafe fn sub_witness( + &mut self, + _y: &Self::Variable, + _x: &Self::Variable, + out_position: Self::Position, + underflow_position: Self::Position, + ) -> (Self::Variable, Self::Variable) { + ( + self.variable(out_position), + self.variable(underflow_position), + ) + } + + unsafe fn mul_signed_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn mul_hi_lo_signed( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position_hi: Self::Position, + position_lo: Self::Position, + ) -> (Self::Variable, Self::Variable) { + (self.variable(position_hi), self.variable(position_lo)) + } + + unsafe fn mul_hi_lo( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position_hi: Self::Position, + position_lo: Self::Position, + ) -> (Self::Variable, Self::Variable) { + (self.variable(position_hi), self.variable(position_lo)) + } + + unsafe fn divmod_signed( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position_quotient: Self::Position, + position_remainder: Self::Position, + ) -> (Self::Variable, 
Self::Variable) { + ( + self.variable(position_quotient), + self.variable(position_remainder), + ) + } + + unsafe fn divmod( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position_quotient: Self::Position, + position_remainder: Self::Position, + ) -> (Self::Variable, Self::Variable) { + ( + self.variable(position_quotient), + self.variable(position_remainder), + ) + } + + unsafe fn count_leading_zeros( + &mut self, + _x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn count_leading_ones( + &mut self, + _x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable { + let res = self.variable(position); + self.constraints.push(x.clone() - res.clone()); + res + } + + fn set_halted(&mut self, _flag: Self::Variable) { + // TODO + } + + fn report_exit(&mut self, _exit_code: &Self::Variable) {} + + /// This function checks that the preimage is read correctly. + /// It adds 13 constraints, and 5 lookups for the communication channel. + /// In particular, at every step it writes the bytes of the preimage into + /// the channel (excluding the length bytes) and it reads the hash digest + /// from the channel when the preimage is fully read. + /// The output is the actual number of bytes that have been read. + fn request_preimage_write( + &mut self, + _addr: &Self::Variable, + len: &Self::Variable, + pos: Self::Position, + ) -> Self::Variable { + // How many hashes have been performed so far in the circuit + let hash_counter = self.variable(Self::Position::ScratchState(MIPS_HASH_COUNTER_OFF)); + + // How many bytes have been read from the preimage so far + let byte_counter = self.variable(Self::Position::ScratchState(MIPS_BYTE_COUNTER_OFF)); + + // Whether this is the last step of the preimage or not (boolean) + let end_of_preimage = self.variable(Self::Position::ScratchState(MIPS_END_OF_PREIMAGE_OFF)); + + // How many preimage bytes are being processed in this instruction + // FIXME: need to connect this to REGISTER_PREIMAGE_OFFSET or pos? + let num_preimage_bytes_read = + self.variable(Self::Position::ScratchState(MIPS_NUM_BYTES_READ_OFF)); + + // The chunk of at most 4 bytes that is being processed from the + // preimage in this instruction + let this_chunk = self.variable(Self::Position::ScratchState(MIPS_PREIMAGE_CHUNK_OFF)); + + // The preimage key composed of 248 bits + let preimage_key = self.variable(Self::Position::ScratchState(MIPS_PREIMAGE_KEY)); + + // The (at most) 4 bytes that are being processed from the preimage + let bytes: [_; MIPS_CHUNK_BYTES_LEN] = array::from_fn(|i| { + self.variable(Self::Position::ScratchState(MIPS_PREIMAGE_BYTES_OFF + i)) + }); + + // The (at most) 4 bytes that are being read from the bytelength + let length_bytes: [_; MIPS_CHUNK_BYTES_LEN] = array::from_fn(|i| { + self.variable(Self::Position::ScratchState(MIPS_LENGTH_BYTES_OFF + i)) + }); + + // Whether the preimage chunk read has at least n bytes (1, 2, 3, or 4). + // It will be all zero when the syscall reads the bytelength prefix. 
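+        // For example (illustrative): a step that reads 3 preimage bytes is
+        // expected to set these flags to [1, 1, 1, 0].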
+ let has_n_bytes: [_; MIPS_CHUNK_BYTES_LEN] = array::from_fn(|i| { + self.variable(Self::Position::ScratchState(MIPS_HAS_N_BYTES_OFF + i)) + }); + + // The actual number of bytes read in this instruction, will be 0 <= x <= len <= 4 + let actual_read_bytes = self.variable(pos); + + // EXTRA 13 CONSTRAINTS + + // 5 Booleanity constraints + { + for var in has_n_bytes.iter() { + self.assert_boolean(var.clone()); + } + self.assert_boolean(end_of_preimage.clone()); + } + + // + 4 constraints + { + // Expressions that are nonzero when the exact corresponding number + // of preimage bytes are read (case 0 bytes used when bytelength is read) + // TODO: embed any more complex logic to know how many bytes are read + // depending on the address and length as in the witness? + // FIXME: use the lines below when the issue with `equal` is solved + // that will bring the number of constraints from 23 to 31 + // (meaning the unit test needs to be manually adapted) + // let preimage_1 = self.equal(&num_preimage_bytes_read, &Expr::from(1)); + // let preimage_2 = self.equal(&num_preimage_bytes_read, &Expr::from(2)); + // let preimage_3 = self.equal(&num_preimage_bytes_read, &Expr::from(3)); + // let preimage_4 = self.equal(&num_preimage_bytes_read, &Expr::from(4)); + + let preimage_1 = (num_preimage_bytes_read.clone()) + * (num_preimage_bytes_read.clone() - Expr::from(2)) + * (num_preimage_bytes_read.clone() - Expr::from(3)) + * (num_preimage_bytes_read.clone() - Expr::from(4)); + let preimage_2 = (num_preimage_bytes_read.clone()) + * (num_preimage_bytes_read.clone() - Expr::from(1)) + * (num_preimage_bytes_read.clone() - Expr::from(3)) + * (num_preimage_bytes_read.clone() - Expr::from(4)); + let preimage_3 = (num_preimage_bytes_read.clone()) + * (num_preimage_bytes_read.clone() - Expr::from(1)) + * (num_preimage_bytes_read.clone() - Expr::from(2)) + * (num_preimage_bytes_read.clone() - Expr::from(4)); + let preimage_4 = (num_preimage_bytes_read.clone()) + * (num_preimage_bytes_read.clone() - Expr::from(1)) + * (num_preimage_bytes_read.clone() - Expr::from(2)) + * (num_preimage_bytes_read.clone() - Expr::from(3)); + + // Constrain the byte decomposition of the preimage chunk + // NOTE: these constraints also hold when 0 preimage bytes are read + { + // When only 1 preimage byte is read, the chunk equals byte[0] + self.add_constraint(preimage_1 * (this_chunk.clone() - bytes[0].clone())); + // When 2 bytes are read, the chunk is equal to the + // byte[0] * 2^8 + byte[1] + self.add_constraint( + preimage_2 + * (this_chunk.clone() + - (bytes[0].clone() * Expr::from(2u64.pow(8)) + bytes[1].clone())), + ); + // When 3 bytes are read, the chunk is equal to + // byte[0] * 2^16 + byte[1] * 2^8 + byte[2] + self.add_constraint( + preimage_3 + * (this_chunk.clone() + - (bytes[0].clone() * Expr::from(2u64.pow(16)) + + bytes[1].clone() * Expr::from(2u64.pow(8)) + + bytes[2].clone())), + ); + // When all 4 bytes are read, the chunk is equal to + // byte[0] * 2^24 + byte[1] * 2^16 + byte[2] * 2^8 + byte[3] + self.add_constraint( + preimage_4 + * (this_chunk.clone() + - (bytes[0].clone() * Expr::from(2u64.pow(24)) + + bytes[1].clone() * Expr::from(2u64.pow(16)) + + bytes[2].clone() * Expr::from(2u64.pow(8)) + + bytes[3].clone())), + ); + } + + // +4 constraints + // Constrain the bytes flags depending on the number of preimage + // bytes read in this row + { + // When at least has_1_byte, then any number of bytes can be + // read <=> Check that you can only read 1, 2, 3 or 4 bytes + self.add_constraint( + 
has_n_bytes[0].clone()
+                        * (num_preimage_bytes_read.clone() - Expr::from(1))
+                        * (num_preimage_bytes_read.clone() - Expr::from(2))
+                        * (num_preimage_bytes_read.clone() - Expr::from(3))
+                        * (num_preimage_bytes_read.clone() - Expr::from(4)),
+                );
+
+                // When at least has_2_byte, then any number of bytes can be
+                // read from the preimage except 1
+                self.add_constraint(
+                    has_n_bytes[1].clone()
+                        * (num_preimage_bytes_read.clone() - Expr::from(2))
+                        * (num_preimage_bytes_read.clone() - Expr::from(3))
+                        * (num_preimage_bytes_read.clone() - Expr::from(4)),
+                );
+                // When at least has_3_byte, then any number of bytes can be
+                // read from the preimage except 1 or 2
+                self.add_constraint(
+                    has_n_bytes[2].clone()
+                        * (num_preimage_bytes_read.clone() - Expr::from(3))
+                        * (num_preimage_bytes_read.clone() - Expr::from(4)),
+                );
+
+                // When has_4_byte, then only 4 preimage bytes can be read
+                self.add_constraint(
+                    has_n_bytes[3].clone() * (num_preimage_bytes_read.clone() - Expr::from(4)),
+                );
+            }
+        }
+
+        // FIXED LOOKUPS
+
+        // Byte checks with lookups: both preimage and length bytes are checked
+        // TODO: think of a way to merge these together to perform 4 lookups
+        // instead of 8 per row
+        // FIXME: understand if length bytes can ever be read together with
+        // preimage bytes. If not, then we can merge the lookups and just run
+        // 4 lookups per row for the byte checks. AKA: does the oracle always
+        // read the length bytes first and then the preimage bytes, with no
+        // overlapping?
+        for byte in bytes.iter() {
+            self.add_lookup(Lookup::read_one(
+                LookupTableIDs::ByteLookup,
+                vec![byte.clone()],
+            ));
+        }
+        for b in length_bytes.iter() {
+            self.add_lookup(Lookup::read_one(
+                LookupTableIDs::ByteLookup,
+                vec![b.clone()],
+            ));
+        }
+
+        // Check that 0 <= preimage read <= actual read <= len <= 4
+        self.lookup_2bits(len);
+        self.lookup_2bits(&actual_read_bytes);
+        self.lookup_2bits(&num_preimage_bytes_read);
+        self.lookup_2bits(&(len.clone() - actual_read_bytes.clone()));
+        self.lookup_2bits(&(actual_read_bytes.clone() - num_preimage_bytes_read.clone()));
+
+        // COMMUNICATION CHANNEL: Write preimage chunk (1, 2, 3, or 4 bytes)
+        for i in 0..MIPS_CHUNK_BYTES_LEN {
+            self.add_lookup(Lookup::write_if(
+                has_n_bytes[i].clone(),
+                LookupTableIDs::SyscallLookup,
+                vec![
+                    hash_counter.clone(),
+                    byte_counter.clone() + Expr::from(i as u64),
+                    bytes[i].clone(),
+                ],
+            ));
+        }
+
+        // COMMUNICATION CHANNEL: Read hash output
+        // If no more bytes are left to be read, then the end of the preimage
+        // is true.
+        // TODO: keep track of a counter to diminish the number of bytes at
+        // each step and check it is zero at the end?
+        self.add_lookup(Lookup::read_if(
+            end_of_preimage,
+            LookupTableIDs::SyscallLookup,
+            vec![hash_counter.clone(), preimage_key],
+        ));
+
+        // Return the actual number of bytes read as a variable, stored in `pos`
+        actual_read_bytes
+    }
+
+    fn request_hint_write(&mut self, _addr: &Self::Variable, _len: &Self::Variable) {
+        // No-op, witness only
+    }
+
+    fn reset(&mut self) {
+        self.scratch_state_idx = 0;
+        self.scratch_state_idx_inverse = 0;
+        self.constraints.clear();
+        self.lookups.clear();
+        self.selector = None;
+    }
+}
+
+impl<Fp: Field> Env<Fp> {
+    /// Return the constraints for the selector.
+    /// Each selector must be a boolean.
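+    /// In addition, exactly one selector must be activated per row: the last
+    /// constraint pushed below is `1 - (sum of all selectors)`.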
+ pub fn get_selector_constraints(&self) -> Vec> { + let one = ::Variable::one(); + let mut enforce_bool: Vec> = (0..N_MIPS_SEL_COLS) + .map(|i| { + let var = self.variable(MIPSColumn::Selector(i)); + (var.clone() - one.clone()) * var.clone() + }) + .collect(); + let enforce_one_activation = (0..N_MIPS_SEL_COLS).fold(E::::one(), |res, i| { + let var = self.variable(MIPSColumn::Selector(i)); + res - var.clone() + }); + + enforce_bool.push(enforce_one_activation); + enforce_bool + } + + pub fn get_selector(&self) -> E { + self.selector + .clone() + .unwrap_or_else(|| panic!("Selector is not set")) + } + + /// Return the constraints for the current instruction, without the selector + pub fn get_constraints(&self) -> Vec> { + self.constraints.clone() + } + + pub fn get_lookups(&self) -> Vec>> { + self.lookups.clone() + } +} diff --git a/o1vm/src/interpreters/mips/interpreter.rs b/o1vm/src/interpreters/mips/interpreter.rs new file mode 100644 index 0000000000..0ea55c67ea --- /dev/null +++ b/o1vm/src/interpreters/mips/interpreter.rs @@ -0,0 +1,2662 @@ +use crate::{ + cannon::PAGE_ADDRESS_SIZE, + interpreters::mips::registers::{ + REGISTER_CURRENT_IP, REGISTER_HEAP_POINTER, REGISTER_HI, REGISTER_LO, REGISTER_NEXT_IP, + REGISTER_PREIMAGE_KEY_END, REGISTER_PREIMAGE_OFFSET, + }, + lookups::{Lookup, LookupTableIDs}, +}; +use ark_ff::{One, Zero}; +use strum::{EnumCount, IntoEnumIterator}; +use strum_macros::{EnumCount, EnumIter}; + +pub const FD_STDIN: u32 = 0; +pub const FD_STDOUT: u32 = 1; +pub const FD_STDERR: u32 = 2; +pub const FD_HINT_READ: u32 = 3; +pub const FD_HINT_WRITE: u32 = 4; +pub const FD_PREIMAGE_READ: u32 = 5; +pub const FD_PREIMAGE_WRITE: u32 = 6; + +pub const SYSCALL_MMAP: u32 = 4090; +pub const SYSCALL_BRK: u32 = 4045; +pub const SYSCALL_CLONE: u32 = 4120; +pub const SYSCALL_EXIT_GROUP: u32 = 4246; +pub const SYSCALL_READ: u32 = 4003; +pub const SYSCALL_WRITE: u32 = 4004; +pub const SYSCALL_FCNTL: u32 = 4055; + +#[derive(Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Hash, Ord, PartialOrd)] +pub enum Instruction { + RType(RTypeInstruction), + JType(JTypeInstruction), + IType(ITypeInstruction), +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum RTypeInstruction { + #[default] + ShiftLeftLogical, // sll + ShiftRightLogical, // srl + ShiftRightArithmetic, // sra + ShiftLeftLogicalVariable, // sllv + ShiftRightLogicalVariable, // srlv + ShiftRightArithmeticVariable, // srav + JumpRegister, // jr + JumpAndLinkRegister, // jalr + SyscallMmap, // syscall (Mmap) + SyscallExitGroup, // syscall (ExitGroup) + SyscallReadHint, // syscall (Read 3) + SyscallReadPreimage, // syscall (Read 5) + SyscallReadOther, // syscall (Read ?) + SyscallWriteHint, // syscall (Write 4) + SyscallWritePreimage, // syscall (Write 6) + SyscallWriteOther, // syscall (Write ?) + SyscallFcntl, // syscall (Fcntl) + SyscallOther, // syscall (Brk, Clone, ?) 
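+    // The `Syscall*` variants above split the single MIPS `syscall` opcode by
+    // the Linux syscall number (see the `SYSCALL_*` constants above) and, for
+    // reads and writes, by the file descriptor (see the `FD_*` constants).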
+ MoveZero, // movz + MoveNonZero, // movn + Sync, // sync + MoveFromHi, // mfhi + MoveToHi, // mthi + MoveFromLo, // mflo + MoveToLo, // mtlo + Multiply, // mult + MultiplyUnsigned, // multu + Div, // div + DivUnsigned, // divu + Add, // add + AddUnsigned, // addu + Sub, // sub + SubUnsigned, // subu + And, // and + Or, // or + Xor, // xor + Nor, // nor + SetLessThan, // slt + SetLessThanUnsigned, // sltu + MultiplyToRegister, // mul + CountLeadingOnes, // clo + CountLeadingZeros, // clz +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum JTypeInstruction { + #[default] + Jump, // j + JumpAndLink, // jal +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum ITypeInstruction { + #[default] + BranchEq, // beq + BranchNeq, // bne + BranchLeqZero, // blez + BranchGtZero, // bgtz + BranchLtZero, // bltz + BranchGeqZero, // bgez + AddImmediate, // addi + AddImmediateUnsigned, // addiu + SetLessThanImmediate, // slti + SetLessThanImmediateUnsigned, // sltiu + AndImmediate, // andi + OrImmediate, // ori + XorImmediate, // xori + LoadUpperImmediate, // lui + Load8, // lb + Load16, // lh + Load32, // lw + Load8Unsigned, // lbu + Load16Unsigned, // lhu + LoadWordLeft, // lwl + LoadWordRight, // lwr + Store8, // sb + Store16, // sh + Store32, // sw + Store32Conditional, // sc + StoreWordLeft, // swl + StoreWordRight, // swr +} + +impl IntoIterator for Instruction { + type Item = Instruction; + type IntoIter = std::vec::IntoIter; + + /// Iterate over the instruction variants + fn into_iter(self) -> Self::IntoIter { + match self { + Instruction::RType(_) => { + let mut iter_contents = Vec::with_capacity(RTypeInstruction::COUNT); + for rtype in RTypeInstruction::iter() { + iter_contents.push(Instruction::RType(rtype)); + } + iter_contents.into_iter() + } + Instruction::JType(_) => { + let mut iter_contents = Vec::with_capacity(JTypeInstruction::COUNT); + for jtype in JTypeInstruction::iter() { + iter_contents.push(Instruction::JType(jtype)); + } + iter_contents.into_iter() + } + Instruction::IType(_) => { + let mut iter_contents = Vec::with_capacity(ITypeInstruction::COUNT); + for itype in ITypeInstruction::iter() { + iter_contents.push(Instruction::IType(itype)); + } + iter_contents.into_iter() + } + } + } +} + +pub trait InterpreterEnv { + /// A position can be seen as an indexed variable + type Position; + + /// Allocate a new abstract variable for the current step. + /// The variable can be used to store temporary values. + /// The variables are "freed" after each step/instruction. + /// The variable allocation can be seen as an allocation on a stack that is + /// popped after each step execution. + /// At the moment, [crate::interpreters::mips::column::SCRATCH_SIZE - 46] + /// elements can be allocated. If more temporary variables are required for + /// an instruction, increase the value + /// [crate::interpreters::mips::column::SCRATCH_SIZE] + fn alloc_scratch(&mut self) -> Self::Position; + + fn alloc_scratch_inverse(&mut self) -> Self::Position; + + type Variable: Clone + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::fmt::Debug + + Zero + + One; + + // Returns the variable in the current row corresponding to a given column alias. + fn variable(&self, column: Self::Position) -> Self::Variable; + + /// Add a constraint to the proof system, asserting that + /// `assert_equals_zero` is 0. 
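+    /// For example, asserting `x == y` amounts to `add_constraint(x - y)`, as
+    /// done by `assert_equal` below.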
+    fn add_constraint(&mut self, assert_equals_zero: Self::Variable);
+
+    /// Activate the selector for the given instruction.
+    fn activate_selector(&mut self, selector: Instruction);
+
+    /// Check that the witness value in `assert_equals_zero` is 0; otherwise abort.
+    fn check_is_zero(assert_equals_zero: &Self::Variable);
+
+    /// Assert that the value `assert_equals_zero` is 0, and add a constraint in the proof system.
+    fn assert_is_zero(&mut self, assert_equals_zero: Self::Variable) {
+        Self::check_is_zero(&assert_equals_zero);
+        self.add_constraint(assert_equals_zero);
+    }
+
+    /// Check that the witness values in `x` and `y` are equal; otherwise abort.
+    fn check_equal(x: &Self::Variable, y: &Self::Variable);
+
+    /// Assert that the values `x` and `y` are equal, and add a constraint in the proof system.
+    fn assert_equal(&mut self, x: Self::Variable, y: Self::Variable) {
+        // NB: We use a different function to give a better error message for debugging.
+        Self::check_equal(&x, &y);
+        self.add_constraint(x - y);
+    }
+
+    /// Check that the witness value `x` is a boolean (`0` or `1`); otherwise abort.
+    fn check_boolean(x: &Self::Variable);
+
+    /// Assert that the value `x` is boolean, and add a constraint in the proof system.
+    fn assert_boolean(&mut self, x: Self::Variable) {
+        Self::check_boolean(&x);
+        self.add_constraint(x.clone() * x.clone() - x);
+    }
+
+    fn add_lookup(&mut self, lookup: Lookup<Self::Variable>);
+
+    fn instruction_counter(&self) -> Self::Variable;
+
+    fn increase_instruction_counter(&mut self);
+
+    /// Fetch the value of the general purpose register with index `idx` and store it in local
+    /// position `output`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn fetch_register(
+        &mut self,
+        idx: &Self::Variable,
+        output: Self::Position,
+    ) -> Self::Variable;
+
+    /// Set the general purpose register with index `idx` to `value` if `if_is_true` is true.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn push_register_if(
+        &mut self,
+        idx: &Self::Variable,
+        value: Self::Variable,
+        if_is_true: &Self::Variable,
+    );
+
+    /// Set the general purpose register with index `idx` to `value`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn push_register(&mut self, idx: &Self::Variable, value: Self::Variable) {
+        self.push_register_if(idx, value, &Self::constant(1))
+    }
+
+    /// Fetch the last 'access index' for the general purpose register with index `idx`, and store
+    /// it in local position `output`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn fetch_register_access(
+        &mut self,
+        idx: &Self::Variable,
+        output: Self::Position,
+    ) -> Self::Variable;
+
+    /// Set the last 'access index' for the general purpose register with index `idx` to `value` if
+    /// `if_is_true` is true.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+ unsafe fn push_register_access_if( + &mut self, + idx: &Self::Variable, + value: Self::Variable, + if_is_true: &Self::Variable, + ); + + /// Set the last 'access index' for the general purpose register with index `idx` to `value`. + /// + /// # Safety + /// + /// No lookups or other constraints are added as part of this operation. The caller must + /// manually add the lookups for this operation. + unsafe fn push_register_access(&mut self, idx: &Self::Variable, value: Self::Variable) { + self.push_register_access_if(idx, value, &Self::constant(1)) + } + + /// Access the general purpose register with index `idx`, adding constraints asserting that the + /// old value was `old_value` and that the new value will be `new_value`, if `if_is_true` is + /// true. + /// + /// # Safety + /// + /// Callers of this function must manually update the registers if required, this function will + /// only update the access counter. + unsafe fn access_register_if( + &mut self, + idx: &Self::Variable, + old_value: &Self::Variable, + new_value: &Self::Variable, + if_is_true: &Self::Variable, + ) { + let last_accessed = { + let last_accessed_location = self.alloc_scratch(); + unsafe { self.fetch_register_access(idx, last_accessed_location) } + }; + let instruction_counter = self.instruction_counter(); + let elapsed_time = instruction_counter.clone() - last_accessed.clone(); + let new_accessed = { + // Here, we write as if the register had been written *at the start of the next + // instruction*. This ensures that we can't 'time travel' within this + // instruction, and claim to read the value that we're about to write! + instruction_counter + Self::constant(1) + // A register should allow multiple accesses to the same register within the same instruction. + // In order to allow this, we always increase the instruction counter by 1. + }; + unsafe { self.push_register_access_if(idx, new_accessed.clone(), if_is_true) }; + self.add_lookup(Lookup::write_if( + if_is_true.clone(), + LookupTableIDs::RegisterLookup, + vec![idx.clone(), last_accessed, old_value.clone()], + )); + self.add_lookup(Lookup::read_if( + if_is_true.clone(), + LookupTableIDs::RegisterLookup, + vec![idx.clone(), new_accessed, new_value.clone()], + )); + self.range_check64(&elapsed_time); + + // Update instruction counter after accessing a register. + self.increase_instruction_counter(); + } + + fn read_register(&mut self, idx: &Self::Variable) -> Self::Variable { + let value = { + let value_location = self.alloc_scratch(); + unsafe { self.fetch_register(idx, value_location) } + }; + unsafe { + self.access_register(idx, &value, &value); + }; + value + } + + /// Access the general purpose register with index `idx`, adding constraints asserting that the + /// old value was `old_value` and that the new value will be `new_value`. + /// + /// # Safety + /// + /// Callers of this function must manually update the registers if required, this function will + /// only update the access counter. + unsafe fn access_register( + &mut self, + idx: &Self::Variable, + old_value: &Self::Variable, + new_value: &Self::Variable, + ) { + self.access_register_if(idx, old_value, new_value, &Self::constant(1)) + } + + fn write_register_if( + &mut self, + idx: &Self::Variable, + new_value: Self::Variable, + if_is_true: &Self::Variable, + ) { + let old_value = { + let value_location = self.alloc_scratch(); + unsafe { self.fetch_register(idx, value_location) } + }; + // Ensure that we only write 0 to the 0 register. 
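+        // MIPS hardwires register 0 to the value 0: `actual_new_value` below is
+        // multiplied by `1 - is_zero(idx)`, so any write targeting register 0
+        // is turned into a write of 0, matching the ISA semantics.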
+ let actual_new_value = { + let idx_is_zero = self.is_zero(idx); + let pos = self.alloc_scratch(); + self.copy(&((Self::constant(1) - idx_is_zero) * new_value), pos) + }; + unsafe { + self.access_register_if(idx, &old_value, &actual_new_value, if_is_true); + }; + unsafe { + self.push_register_if(idx, actual_new_value, if_is_true); + }; + } + + fn write_register(&mut self, idx: &Self::Variable, new_value: Self::Variable) { + self.write_register_if(idx, new_value, &Self::constant(1)) + } + + /// Fetch the memory value at address `addr` and store it in local position `output`. + /// + /// # Safety + /// + /// No lookups or other constraints are added as part of this operation. The caller must + /// manually add the lookups for this memory operation. + unsafe fn fetch_memory( + &mut self, + addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable; + + /// Set the memory value at address `addr` to `value`. + /// + /// # Safety + /// + /// No lookups or other constraints are added as part of this operation. The caller must + /// manually add the lookups for this memory operation. + unsafe fn push_memory(&mut self, addr: &Self::Variable, value: Self::Variable); + + /// Fetch the last 'access index' that the memory at address `addr` was written at, and store + /// it in local position `output`. + /// + /// # Safety + /// + /// No lookups or other constraints are added as part of this operation. The caller must + /// manually add the lookups for this memory operation. + unsafe fn fetch_memory_access( + &mut self, + addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable; + + /// Set the last 'access index' for the memory at address `addr` to `value`. + /// + /// # Safety + /// + /// No lookups or other constraints are added as part of this operation. The caller must + /// manually add the lookups for this memory operation. + unsafe fn push_memory_access(&mut self, addr: &Self::Variable, value: Self::Variable); + + /// Access the memory address `addr`, adding constraints asserting that the old value was + /// `old_value` and that the new value will be `new_value`. + /// + /// # Safety + /// + /// Callers of this function must manually update the memory if required, this function will + /// only update the access counter. + unsafe fn access_memory( + &mut self, + addr: &Self::Variable, + old_value: &Self::Variable, + new_value: &Self::Variable, + ) { + let last_accessed = { + let last_accessed_location = self.alloc_scratch(); + unsafe { self.fetch_memory_access(addr, last_accessed_location) } + }; + let instruction_counter = self.instruction_counter(); + let elapsed_time = instruction_counter.clone() - last_accessed.clone(); + let new_accessed = { + // Here, we write as if the memory had been written *at the start of the next + // instruction*. This ensures that we can't 'time travel' within this + // instruction, and claim to read the value that we're about to write! + instruction_counter + Self::constant(1) + }; + unsafe { self.push_memory_access(addr, new_accessed.clone()) }; + self.add_lookup(Lookup::write_one( + LookupTableIDs::MemoryLookup, + vec![addr.clone(), last_accessed, old_value.clone()], + )); + self.add_lookup(Lookup::read_one( + LookupTableIDs::MemoryLookup, + vec![addr.clone(), new_accessed, new_value.clone()], + )); + self.range_check64(&elapsed_time); + + // Update instruction counter after accessing a memory address. 
+ self.increase_instruction_counter(); + } + + fn read_memory(&mut self, addr: &Self::Variable) -> Self::Variable { + let value = { + let value_location = self.alloc_scratch(); + unsafe { self.fetch_memory(addr, value_location) } + }; + unsafe { + self.access_memory(addr, &value, &value); + }; + value + } + + fn write_memory(&mut self, addr: &Self::Variable, new_value: Self::Variable) { + let old_value = { + let value_location = self.alloc_scratch(); + unsafe { self.fetch_memory(addr, value_location) } + }; + unsafe { + self.access_memory(addr, &old_value, &new_value); + }; + unsafe { + self.push_memory(addr, new_value); + }; + } + + /// Adds a lookup to the RangeCheck16Lookup table + fn lookup_16bits(&mut self, value: &Self::Variable) { + self.add_lookup(Lookup::read_one( + LookupTableIDs::RangeCheck16Lookup, + vec![value.clone()], + )); + } + + /// Range checks with 2 lookups to the RangeCheck16Lookup table that a value + /// is at most 2^`bits`-1 (bits <= 16). + fn range_check16(&mut self, value: &Self::Variable, bits: u32) { + assert!(bits <= 16); + // 0 <= value < 2^bits + // First, check lowerbound: 0 <= value < 2^16 + self.lookup_16bits(value); + // Second, check upperbound: value + 2^16 - 2^bits < 2^16 + self.lookup_16bits(&(value.clone() + Self::constant(1 << 16) - Self::constant(1 << bits))); + } + + /// Adds a lookup to the ByteLookup table + fn lookup_8bits(&mut self, value: &Self::Variable) { + self.add_lookup(Lookup::read_one( + LookupTableIDs::ByteLookup, + vec![value.clone()], + )); + } + + /// Range checks with 2 lookups to the ByteLookup table that a value + /// is at most 2^`bits`-1 (bits <= 8). + fn range_check8(&mut self, value: &Self::Variable, bits: u32) { + assert!(bits <= 8); + // 0 <= value < 2^bits + // First, check lowerbound: 0 <= value < 2^8 + self.lookup_8bits(value); + // Second, check upperbound: value + 2^8 - 2^bits < 2^8 + self.lookup_8bits(&(value.clone() + Self::constant(1 << 8) - Self::constant(1 << bits))); + } + + /// Adds a lookup to the AtMost4Lookup table + fn lookup_2bits(&mut self, value: &Self::Variable) { + self.add_lookup(Lookup::read_one( + LookupTableIDs::AtMost4Lookup, + vec![value.clone()], + )); + } + + /// Range checks with 1 lookup to the AtMost4Lookup table 0 <= value < 4 + fn range_check2(&mut self, value: &Self::Variable) { + self.lookup_2bits(value); + } + + fn range_check64(&mut self, _value: &Self::Variable) { + // TODO + } + + fn set_instruction_pointer(&mut self, ip: Self::Variable) { + let idx = Self::constant(REGISTER_CURRENT_IP as u32); + let new_accessed = self.instruction_counter() + Self::constant(1); + unsafe { + self.push_register_access(&idx, new_accessed.clone()); + } + unsafe { + self.push_register(&idx, ip.clone()); + } + self.add_lookup(Lookup::read_one( + LookupTableIDs::RegisterLookup, + vec![idx, new_accessed, ip], + )); + } + + fn get_instruction_pointer(&mut self) -> Self::Variable { + let idx = Self::constant(REGISTER_CURRENT_IP as u32); + let ip = { + let value_location = self.alloc_scratch(); + unsafe { self.fetch_register(&idx, value_location) } + }; + self.add_lookup(Lookup::write_one( + LookupTableIDs::RegisterLookup, + vec![idx, self.instruction_counter(), ip.clone()], + )); + ip + } + + fn set_next_instruction_pointer(&mut self, ip: Self::Variable) { + let idx = Self::constant(REGISTER_NEXT_IP as u32); + let new_accessed = self.instruction_counter() + Self::constant(1); + unsafe { + self.push_register_access(&idx, new_accessed.clone()); + } + unsafe { + self.push_register(&idx, ip.clone()); + } + 
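+        // Loosely speaking, the `read_one` row below records the register's new
+        // state (idx, new_accessed, ip) in the lookup argument; it is balanced by
+        // the matching `write` row emitted when this register is next accessed,
+        // mirroring the read/write pairing in `access_register_if`.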
self.add_lookup(Lookup::read_one(
+            LookupTableIDs::RegisterLookup,
+            vec![idx, new_accessed, ip],
+        ));
+    }
+
+    fn get_next_instruction_pointer(&mut self) -> Self::Variable {
+        let idx = Self::constant(REGISTER_NEXT_IP as u32);
+        let ip = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_register(&idx, value_location) }
+        };
+        self.add_lookup(Lookup::write_one(
+            LookupTableIDs::RegisterLookup,
+            vec![idx, self.instruction_counter(), ip.clone()],
+        ));
+        ip
+    }
+
+    fn constant(x: u32) -> Self::Variable;
+
+    /// Extract the bits from the variable `x` between `highest_bit` and `lowest_bit`, and store
+    /// the result in `position`.
+    /// `lowest_bit` becomes the least-significant bit of the resulting value.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and that the returned value fits in `highest_bit - lowest_bit`
+    /// bits.
+    ///
+    /// Do not call this function with highest_bit - lowest_bit >= 32.
+    // TODO: embed the range check in the function when highest_bit - lowest_bit <= 16?
+    unsafe fn bitmask(
+        &mut self,
+        x: &Self::Variable,
+        highest_bit: u32,
+        lowest_bit: u32,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Return the result of shifting `x` left by `by`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and the shift amount `by`.
+    unsafe fn shift_left(
+        &mut self,
+        x: &Self::Variable,
+        by: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Return the result of logically shifting `x` right by `by`, storing the result in
+    /// `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and the shift amount `by`.
+    unsafe fn shift_right(
+        &mut self,
+        x: &Self::Variable,
+        by: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Return the result of arithmetically shifting `x` right by `by`, storing the result in
+    /// `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and the shift amount `by`.
+    unsafe fn shift_right_arithmetic(
+        &mut self,
+        x: &Self::Variable,
+        by: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns 1 if `x` is 0, or 0 otherwise, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// `x`.
+    unsafe fn test_zero(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable;
+
+    fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable;
+
+    /// Returns 1 if `x` is equal to `y`, or 0 otherwise.
+    fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable;
+
+    /// Returns 1 if `x < y` as unsigned integers, or 0 otherwise, storing the result in
+    /// `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert that the value
+    /// correctly represents the relationship between `x` and `y`.
+    unsafe fn test_less_than(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns 1 if `x < y` as signed integers, or 0 otherwise, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert that the value
+    /// correctly represents the relationship between `x` and `y`.
+    unsafe fn test_less_than_signed(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x and y`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn and_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x or y`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn or_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x nor y`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn nor_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x xor y`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn xor_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x + y` and the overflow bit, storing the results in `out_position` and
+    /// `overflow_position` respectively.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned values; callers must manually add constraints to
+    /// ensure that they are correctly constructed.
+    unsafe fn add_witness(
+        &mut self,
+        y: &Self::Variable,
+        x: &Self::Variable,
+        out_position: Self::Position,
+        overflow_position: Self::Position,
+    ) -> (Self::Variable, Self::Variable);
+
+    /// Returns `x - y` and the underflow bit, storing the results in `out_position` and
+    /// `underflow_position` respectively.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned values; callers must manually add constraints to
+    /// ensure that they are correctly constructed.
+    unsafe fn sub_witness(
+        &mut self,
+        y: &Self::Variable,
+        x: &Self::Variable,
+        out_position: Self::Position,
+        underflow_position: Self::Position,
+    ) -> (Self::Variable, Self::Variable);
+
+    /// Returns `x * y`, where `x` and `y` are treated as integers, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn mul_signed_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `((x * y) >> 32, (x * y) & ((1 << 32) - 1))`, storing the results in `position_hi`
+    /// and `position_lo` respectively.
+ /// + /// # Safety + /// + /// There are no constraints on the returned values; callers must manually add constraints to + /// ensure that the pair of returned values correspond to the given values `x` and `y`, and + /// that they fall within the desired range. + unsafe fn mul_hi_lo_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position_hi: Self::Position, + position_lo: Self::Position, + ) -> (Self::Variable, Self::Variable); + + /// Returns `((x * y) >> 32, (x * y) & ((1 << 32) - 1))`, storing the results in `position_hi` + /// and `position_lo` respectively. + /// + /// # Safety + /// + /// There are no constraints on the returned values; callers must manually add constraints to + /// ensure that the pair of returned values correspond to the given values `x` and `y`, and + /// that they fall within the desired range. + unsafe fn mul_hi_lo( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position_hi: Self::Position, + position_lo: Self::Position, + ) -> (Self::Variable, Self::Variable); + + /// Returns `(x / y, x % y)`, storing the results in `position_quotient` and + /// `position_remainder` respectively. + /// + /// # Safety + /// + /// There are no constraints on the returned values; callers must manually add constraints to + /// ensure that the pair of returned values correspond to the given values `x` and `y`, and + /// that they fall within the desired range. + unsafe fn divmod_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position_quotient: Self::Position, + position_remainder: Self::Position, + ) -> (Self::Variable, Self::Variable); + + /// Returns `(x / y, x % y)`, storing the results in `position_quotient` and + /// `position_remainder` respectively. + /// + /// # Safety + /// + /// There are no constraints on the returned values; callers must manually add constraints to + /// ensure that the pair of returned values correspond to the given values `x` and `y`, and + /// that they fall within the desired range. + unsafe fn divmod( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position_quotient: Self::Position, + position_remainder: Self::Position, + ) -> (Self::Variable, Self::Variable); + + /// Returns the number of leading 0s in `x`, storing the result in `position`. + /// + /// # Safety + /// + /// There are no constraints on the returned value; callers must manually add constraints to + /// ensure that it is correctly constructed. + unsafe fn count_leading_zeros( + &mut self, + x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable; + + /// Returns the number of leading 1s in `x`, storing the result in `position`. + /// + /// # Safety + /// + /// There are no constraints on the returned value; callers must manually add constraints to + /// ensure that it is correctly constructed. + unsafe fn count_leading_ones( + &mut self, + x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable; + + fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable; + + /// Increases the heap pointer by `by_amount` if `if_is_true` is `1`, and returns the previous + /// value of the heap pointer. 
+    fn increase_heap_pointer(
+        &mut self,
+        by_amount: &Self::Variable,
+        if_is_true: &Self::Variable,
+    ) -> Self::Variable {
+        let idx = Self::constant(REGISTER_HEAP_POINTER as u32);
+        let old_ptr = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_register(&idx, value_location) }
+        };
+        let new_ptr = old_ptr.clone() + by_amount.clone();
+        unsafe {
+            self.access_register_if(&idx, &old_ptr, &new_ptr, if_is_true);
+        };
+        unsafe {
+            self.push_register_if(&idx, new_ptr, if_is_true);
+        };
+        old_ptr
+    }
+
+    fn set_halted(&mut self, flag: Self::Variable);
+
+    /// Interpret the variable `x` as a signed integer of `bitlength` bits, and
+    /// sign-extend it to 32 bits.
+    fn sign_extend(&mut self, x: &Self::Variable, bitlength: u32) -> Self::Variable {
+        // FIXME: Constrain `high_bit`
+        let high_bit = {
+            let pos = self.alloc_scratch();
+            unsafe { self.bitmask(x, bitlength, bitlength - 1, pos) }
+        };
+        high_bit * Self::constant(((1 << (32 - bitlength)) - 1) << bitlength) + x.clone()
+    }
+
+    fn report_exit(&mut self, exit_code: &Self::Variable);
+
+    /// Request `len` bytes from the preimage oracle and store them starting at
+    /// `addr`. Returns the number of bytes actually read, which is also stored
+    /// in position `pos`. The first 8 bytes of the preimage data are its
+    /// length, encoded as an unsigned 64-bit integer; the rest is the preimage
+    /// itself.
+    fn request_preimage_write(
+        &mut self,
+        addr: &Self::Variable,
+        len: &Self::Variable,
+        pos: Self::Position,
+    ) -> Self::Variable;
+
+    fn request_hint_write(&mut self, addr: &Self::Variable, len: &Self::Variable);
+
+    /// Reset the environment to handle the next instruction
+    fn reset(&mut self);
+}
+
+pub fn interpret_instruction<Env: InterpreterEnv>(env: &mut Env, instr: Instruction) {
+    env.activate_selector(instr);
+
+    match instr {
+        Instruction::RType(instr) => interpret_rtype(env, instr),
+        Instruction::JType(instr) => interpret_jtype(env, instr),
+        Instruction::IType(instr) => interpret_itype(env, instr),
+    }
+}
+
+pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RTypeInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let next_instruction_pointer = env.get_next_instruction_pointer();
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v0 * Env::constant(1 << 24))
+            + (v1 * Env::constant(1 << 16))
+            + (v2 * Env::constant(1 << 8))
+            + v3
+    };
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 26, pos) }
+    };
+    env.range_check8(&opcode, 6);
+
+    let rs = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 26, 21, pos) }
+    };
+    env.range_check8(&rs, 5);
+
+    let rt = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 21, 16, pos) }
+    };
+    env.range_check8(&rt, 5);
+
+    let rd = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 16, 11, pos) }
+    };
+    env.range_check8(&rd, 5);
+
+    let shamt = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 11, 6, pos) }
+    };
+    env.range_check8(&shamt, 5);
+
+    let funct = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 6, 0, pos) }
+    };
+    env.range_check8(&funct, 6);
+
+    // Check correctness of decomposition of instruction into parts
+    env.add_constraint(
+        instruction
+            -
(opcode * Env::constant(1 << 26) + + rs.clone() * Env::constant(1 << 21) + + rt.clone() * Env::constant(1 << 16) + + rd.clone() * Env::constant(1 << 11) + + shamt.clone() * Env::constant(1 << 6) + + funct), + ); + + match instr { + RTypeInstruction::ShiftLeftLogical => { + let rt = env.read_register(&rt); + // FIXME: Constrain this value + let shifted = unsafe { + let pos = env.alloc_scratch(); + env.shift_left(&rt, &shamt, pos) + }; + env.write_register(&rd, shifted); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::ShiftRightLogical => { + let rt = env.read_register(&rt); + // FIXME: Constrain this value + let shifted = unsafe { + let pos = env.alloc_scratch(); + env.shift_right(&rt, &shamt, pos) + }; + env.write_register(&rd, shifted); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::ShiftRightArithmetic => { + let rt = env.read_register(&rt); + // FIXME: Constrain this value + let shifted = unsafe { + let pos = env.alloc_scratch(); + env.shift_right_arithmetic(&rt, &shamt, pos) + }; + env.write_register(&rd, shifted); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::ShiftLeftLogicalVariable => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + // FIXME: Constrain this value + let shifted = unsafe { + let pos = env.alloc_scratch(); + env.shift_left(&rt, &rs, pos) + }; + env.write_register(&rd, shifted); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::ShiftRightLogicalVariable => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + // FIXME: Constrain this value + let shifted = unsafe { + let pos = env.alloc_scratch(); + env.shift_right(&rt, &rs, pos) + }; + env.write_register(&rd, shifted); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::ShiftRightArithmeticVariable => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + // FIXME: Constrain this value + let shifted = unsafe { + let pos = env.alloc_scratch(); + env.shift_right_arithmetic(&rt, &rs, pos) + }; + env.write_register(&rd, shifted); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::JumpRegister => { + let addr = env.read_register(&rs); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(addr); + } + RTypeInstruction::JumpAndLinkRegister => { + let addr = env.read_register(&rs); + env.write_register(&rd, instruction_pointer + Env::constant(8u32)); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(addr); + } + RTypeInstruction::SyscallMmap => { + let requested_alloc_size = env.read_register(&Env::constant(5)); + let size_in_pages = { + // FIXME: Requires a range check + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&requested_alloc_size, 32, PAGE_ADDRESS_SIZE, pos) } + }; + let requires_extra_page = { + let remainder = 
requested_alloc_size + - (size_in_pages.clone() * Env::constant(1 << PAGE_ADDRESS_SIZE)); + Env::constant(1) - env.is_zero(&remainder) + }; + let actual_alloc_size = + (size_in_pages + requires_extra_page) * Env::constant(1 << PAGE_ADDRESS_SIZE); + let address = env.read_register(&Env::constant(4)); + let address_is_zero = env.is_zero(&address); + let old_heap_ptr = env.increase_heap_pointer(&actual_alloc_size, &address_is_zero); + let return_position = { + let pos = env.alloc_scratch(); + env.copy( + &(address_is_zero.clone() * old_heap_ptr + + (Env::constant(1) - address_is_zero) * address), + pos, + ) + }; + env.write_register(&Env::constant(2), return_position); + env.write_register(&Env::constant(7), Env::constant(0)); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SyscallExitGroup => { + let exit_code = env.read_register(&Env::constant(4)); + env.report_exit(&exit_code); + env.set_halted(Env::constant(1)); + } + RTypeInstruction::SyscallReadHint => { + // We don't really write here, since the value is unused, per the cannon + // implementation. Just claim that we wrote the correct length. + let length = env.read_register(&Env::constant(6)); + env.write_register(&Env::constant(2), length); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SyscallReadPreimage => { + let addr = env.read_register(&Env::constant(5)); + let length = env.read_register(&Env::constant(6)); + let preimage_offset = + env.read_register(&Env::constant(REGISTER_PREIMAGE_OFFSET as u32)); + + let read_length = { + let pos = env.alloc_scratch(); + env.request_preimage_write(&addr, &length, pos) + }; + env.write_register( + &Env::constant(REGISTER_PREIMAGE_OFFSET as u32), + preimage_offset + read_length.clone(), + ); + env.write_register(&Env::constant(2), read_length); + env.write_register(&Env::constant(7), Env::constant(0)); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SyscallReadOther => { + let fd_id = env.read_register(&Env::constant(4)); + let mut check_equal = |expected_fd_id: u32| { + // FIXME: Requires constraints + let pos = env.alloc_scratch(); + unsafe { env.test_zero(&(fd_id.clone() - Env::constant(expected_fd_id)), pos) } + }; + let is_stdin = check_equal(FD_STDIN); + let is_preimage_read = check_equal(FD_PREIMAGE_READ); + let is_hint_read = check_equal(FD_HINT_READ); + + // FIXME: Should assert that `is_preimage_read` and `is_hint_read` cannot be true here. + let other_fd = Env::constant(1) - is_stdin - is_preimage_read - is_hint_read; + + // We're either reading stdin, in which case we get `(0, 0)` as desired, or we've hit a + // bad FD that we reject with EBADF. 
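+                // Per the usual syscall convention, v0 = 0xFFFFFFFF is -1 as a 32-bit
+                // two's complement value, and the accompanying errno EBADF is 0x9.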
+ let v0 = other_fd.clone() * Env::constant(0xFFFFFFFF); + let v1 = other_fd * Env::constant(0x9); // EBADF + + env.write_register(&Env::constant(2), v0); + env.write_register(&Env::constant(7), v1); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SyscallWriteHint => { + let addr = env.read_register(&Env::constant(5)); + let length = env.read_register(&Env::constant(6)); + // TODO: Message preimage oracle + env.request_hint_write(&addr, &length); + env.write_register(&Env::constant(2), length); + env.write_register(&Env::constant(7), Env::constant(0)); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SyscallWritePreimage => { + let addr = env.read_register(&Env::constant(5)); + let write_length = env.read_register(&Env::constant(6)); + + // Cannon assumes that the remaining `byte_length` represents how much remains to be + // read (i.e. all write calls send the full data in one syscall, and attempt to retry + // with the rest until there is a success). This also simplifies the implementation + // here, so we will follow suit. + let bytes_to_preserve_in_register = { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&write_length, 2, 0, pos) } + }; + env.range_check2(&bytes_to_preserve_in_register); + let register_idx = { + let registers_left_to_write_after_this = { + let pos = env.alloc_scratch(); + // The virtual register is 32 bits wide, so we can just read 6 bytes. If the + // register has an incorrect value, it will be unprovable and we'll fault. + unsafe { env.bitmask(&write_length, 6, 2, pos) } + }; + env.range_check8(®isters_left_to_write_after_this, 4); + Env::constant(REGISTER_PREIMAGE_KEY_END as u32) - registers_left_to_write_after_this + }; + + let [r0, r1, r2, r3] = { + let register_value = { + let initial_register_value = env.read_register(®ister_idx); + + // We should clear the register if our offset into the read will replace all of its + // bytes. + let should_clear_register = env.is_zero(&bytes_to_preserve_in_register); + + let pos = env.alloc_scratch(); + env.copy( + &((Env::constant(1) - should_clear_register) * initial_register_value), + pos, + ) + }; + [ + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(®ister_value, 32, 24, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(®ister_value, 24, 16, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(®ister_value, 16, 8, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(®ister_value, 8, 0, pos) } + }, + ] + }; + env.lookup_8bits(&r0); + env.lookup_8bits(&r1); + env.lookup_8bits(&r2); + env.lookup_8bits(&r3); + + // We choose our read address so that the bytes we read come aligned with the target + // bytes in the register, to avoid an expensive bitshift. + let read_address = addr.clone() - bytes_to_preserve_in_register.clone(); + + let m0 = env.read_memory(&read_address); + let m1 = env.read_memory(&(read_address.clone() + Env::constant(1))); + let m2 = env.read_memory(&(read_address.clone() + Env::constant(2))); + let m3 = env.read_memory(&(read_address.clone() + Env::constant(3))); + + // Now, for some complexity. From the perspective of the write operation, we should be + // reading the `4 - bytes_to_preserve_in_register`. 
However, to match cannon 1:1, we + // only want to read the bytes up to the end of the current word. + let [overwrite_0, overwrite_1, overwrite_2, overwrite_3] = { + let next_word_addr = { + let byte_subaddr = { + // FIXME: Requires a range check + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&addr, 2, 0, pos) } + }; + env.range_check2(&byte_subaddr); + addr.clone() + Env::constant(4) - byte_subaddr + }; + let overwrite_0 = { + // We always write the first byte if we're not preserving it, since it will + // have been read from `addr`. + env.equal(&bytes_to_preserve_in_register, &Env::constant(0)) + }; + let overwrite_1 = { + // We write the second byte if: + // we wrote the first byte + overwrite_0.clone() + // and this isn't the start of the next word (which implies `overwrite_0`), + - env.equal(&(read_address.clone() + Env::constant(1)), &next_word_addr) + // or this byte was read from `addr` + + env.equal(&bytes_to_preserve_in_register, &Env::constant(1)) + }; + let overwrite_2 = { + // We write the third byte if: + // we wrote the second byte + overwrite_1.clone() + // and this isn't the start of the next word (which implies `overwrite_1`), + - env.equal(&(read_address.clone() + Env::constant(2)), &next_word_addr) + // or this byte was read from `addr` + + env.equal(&bytes_to_preserve_in_register, &Env::constant(2)) + }; + let overwrite_3 = { + // We write the fourth byte if: + // we wrote the third byte + overwrite_2.clone() + // and this isn't the start of the next word (which implies `overwrite_2`), + - env.equal(&(read_address.clone() + Env::constant(3)), &next_word_addr) + // or this byte was read from `addr` + + env.equal(&bytes_to_preserve_in_register, &Env::constant(3)) + }; + [overwrite_0, overwrite_1, overwrite_2, overwrite_3] + }; + + let value = { + let value = ((overwrite_0.clone() * m0 + + (Env::constant(1) - overwrite_0.clone()) * r0) + * Env::constant(1 << 24)) + + ((overwrite_1.clone() * m1 + (Env::constant(1) - overwrite_1.clone()) * r1) + * Env::constant(1 << 16)) + + ((overwrite_2.clone() * m2 + (Env::constant(1) - overwrite_2.clone()) * r2) + * Env::constant(1 << 8)) + + (overwrite_3.clone() * m3 + (Env::constant(1) - overwrite_3.clone()) * r3); + let pos = env.alloc_scratch(); + env.copy(&value, pos) + }; + + // Update the preimage key. + env.write_register(®ister_idx, value); + // Reset the preimage offset. + env.write_register( + &Env::constant(REGISTER_PREIMAGE_OFFSET as u32), + Env::constant(0u32), + ); + // Return the number of bytes read. + env.write_register( + &Env::constant(2), + overwrite_0 + overwrite_1 + overwrite_2 + overwrite_3, + ); + // Set the error register to 0. + env.write_register(&Env::constant(7), Env::constant(0u32)); + + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + // REMOVEME: when all itype instructions are implemented. 
+        }
+        RTypeInstruction::SyscallWriteOther => {
+            let fd_id = env.read_register(&Env::constant(4));
+            let write_length = env.read_register(&Env::constant(6));
+            let mut check_equal = |expected_fd_id: u32| {
+                // FIXME: Requires constraints
+                let pos = env.alloc_scratch();
+                unsafe { env.test_zero(&(fd_id.clone() - Env::constant(expected_fd_id)), pos) }
+            };
+            let is_stdout = check_equal(FD_STDOUT);
+            let is_stderr = check_equal(FD_STDERR);
+            let is_preimage_write = check_equal(FD_PREIMAGE_WRITE);
+            let is_hint_write = check_equal(FD_HINT_WRITE);
+
+            // FIXME: Should assert that `is_preimage_write` and `is_hint_write` cannot be true
+            // here.
+            let known_fd = is_stdout + is_stderr + is_preimage_write + is_hint_write;
+            let other_fd = Env::constant(1) - known_fd.clone();
+
+            // We're either writing to a known fd, in which case we return the write length
+            // as desired, or we've hit a bad FD that we reject with EBADF.
+            let v0 = known_fd * write_length + other_fd.clone() * Env::constant(0xFFFFFFFF);
+            let v1 = other_fd * Env::constant(0x9); // EBADF
+
+            env.write_register(&Env::constant(2), v0);
+            env.write_register(&Env::constant(7), v1);
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RTypeInstruction::SyscallFcntl => {
+            let fd_id = env.read_register(&Env::constant(4));
+            let fd_cmd = env.read_register(&Env::constant(5));
+            let is_getfl = env.equal(&fd_cmd, &Env::constant(3));
+            let is_stdin = env.equal(&fd_id, &Env::constant(FD_STDIN));
+            let is_stdout = env.equal(&fd_id, &Env::constant(FD_STDOUT));
+            let is_stderr = env.equal(&fd_id, &Env::constant(FD_STDERR));
+            let is_hint_read = env.equal(&fd_id, &Env::constant(FD_HINT_READ));
+            let is_hint_write = env.equal(&fd_id, &Env::constant(FD_HINT_WRITE));
+            let is_preimage_read = env.equal(&fd_id, &Env::constant(FD_PREIMAGE_READ));
+            let is_preimage_write = env.equal(&fd_id, &Env::constant(FD_PREIMAGE_WRITE));
+
+            // These variables are 1 if the condition is true, and 0 otherwise.
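+            // Here fd_cmd == 3 is F_GETFL: readable fds report O_RDONLY (0) and
+            // writable fds report O_WRONLY (1); an unknown fd yields (-1, EBADF),
+            // and any command other than F_GETFL yields (-1, EINVAL = 0x16).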
+ let is_read = is_stdin + is_preimage_read + is_hint_read; + let is_write = is_stdout + is_stderr + is_preimage_write + is_hint_write; + + let v0 = is_getfl.clone() + * (is_write.clone() + + (Env::constant(1) - is_read.clone() - is_write.clone()) + * Env::constant(0xFFFFFFFF)) + + (Env::constant(1) - is_getfl.clone()) * Env::constant(0xFFFFFFFF); + let v1 = + is_getfl.clone() * (Env::constant(1) - is_read - is_write.clone()) + * Env::constant(0x9) /* EBADF */ + + (Env::constant(1) - is_getfl.clone()) * Env::constant(0x16) /* EINVAL */; + + env.write_register(&Env::constant(2), v0); + env.write_register(&Env::constant(7), v1); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SyscallOther => { + let syscall_num = env.read_register(&Env::constant(2)); + let is_sysbrk = env.equal(&syscall_num, &Env::constant(SYSCALL_BRK)); + let is_sysclone = env.equal(&syscall_num, &Env::constant(SYSCALL_CLONE)); + let v0 = { is_sysbrk * Env::constant(0x40000000) + is_sysclone }; + let v1 = Env::constant(0); + env.write_register(&Env::constant(2), v0); + env.write_register(&Env::constant(7), v1); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::MoveZero => { + let rt = env.read_register(&rt); + let is_zero = env.is_zero(&rt); + let rs = env.read_register(&rs); + env.write_register_if(&rd, rs, &is_zero); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::MoveNonZero => { + let rt = env.read_register(&rt); + let is_zero = Env::constant(1) - env.is_zero(&rt); + let rs = env.read_register(&rs); + env.write_register_if(&rd, rs, &is_zero); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::Sync => { + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::MoveFromHi => { + let hi = env.read_register(&Env::constant(REGISTER_HI as u32)); + env.write_register(&rd, hi); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::MoveToHi => { + let rs = env.read_register(&rs); + env.write_register(&Env::constant(REGISTER_HI as u32), rs); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::MoveFromLo => { + let lo = env.read_register(&Env::constant(REGISTER_LO as u32)); + env.write_register(&rd, lo); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::MoveToLo => { + let rs = env.read_register(&rs); + env.write_register(&Env::constant(REGISTER_LO as u32), rs); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::Multiply => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let (hi, lo) = { + // Fixme: constrain + let hi_pos = 
env.alloc_scratch(); + let lo_pos = env.alloc_scratch(); + unsafe { env.mul_hi_lo_signed(&rs, &rt, hi_pos, lo_pos) } + }; + env.write_register(&Env::constant(REGISTER_HI as u32), hi); + env.write_register(&Env::constant(REGISTER_LO as u32), lo); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::MultiplyUnsigned => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let (hi, lo) = { + // Fixme: constrain + let hi_pos = env.alloc_scratch(); + let lo_pos = env.alloc_scratch(); + unsafe { env.mul_hi_lo(&rs, &rt, hi_pos, lo_pos) } + }; + env.write_register(&Env::constant(REGISTER_HI as u32), hi); + env.write_register(&Env::constant(REGISTER_LO as u32), lo); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::Div => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let (quotient, remainder) = { + // Fixme: constrain + let quotient_pos = env.alloc_scratch(); + let remainder_pos = env.alloc_scratch(); + unsafe { env.divmod_signed(&rs, &rt, quotient_pos, remainder_pos) } + }; + env.write_register(&Env::constant(REGISTER_LO as u32), quotient); + env.write_register(&Env::constant(REGISTER_HI as u32), remainder); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::DivUnsigned => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let (quotient, remainder) = { + // Fixme: constrain + let quotient_pos = env.alloc_scratch(); + let remainder_pos = env.alloc_scratch(); + unsafe { env.divmod(&rs, &rt, quotient_pos, remainder_pos) } + }; + env.write_register(&Env::constant(REGISTER_LO as u32), quotient); + env.write_register(&Env::constant(REGISTER_HI as u32), remainder); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::Add => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&rs, &rt, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::AddUnsigned => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&rs, &rt, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::Sub => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.sub_witness(&rs, &rt, res_scratch, overflow_scratch) }; + // FIXME: Requires a range 
check + res + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SubUnsigned => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.sub_witness(&rs, &rt, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::And => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.and_witness(&rs, &rt, pos) } + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::Or => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.or_witness(&rs, &rt, pos) } + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::Xor => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.xor_witness(&rs, &rt, pos) } + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::Nor => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.nor_witness(&rs, &rt, pos) } + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SetLessThan => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.test_less_than_signed(&rs, &rt, pos) } + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::SetLessThanUnsigned => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.test_less_than(&rs, &rt, pos) } + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + RTypeInstruction::MultiplyToRegister => { + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.mul_signed_witness(&rs, &rt, pos) } + }; + env.write_register(&rd, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + 
env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RTypeInstruction::CountLeadingOnes => {
+            let rs = env.read_register(&rs);
+            let leading_ones = {
+                // FIXME: Constrain
+                let pos = env.alloc_scratch();
+                unsafe { env.count_leading_ones(&rs, pos) }
+            };
+            env.write_register(&rd, leading_ones);
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RTypeInstruction::CountLeadingZeros => {
+            let rs = env.read_register(&rs);
+            let leading_zeros = {
+                // FIXME: Constrain
+                let pos = env.alloc_scratch();
+                unsafe { env.count_leading_zeros(&rs, pos) }
+            };
+            env.write_register(&rd, leading_zeros);
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+    };
+}
+
+pub fn interpret_jtype<Env: InterpreterEnv>(env: &mut Env, instr: JTypeInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let next_instruction_pointer = env.get_next_instruction_pointer();
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v0 * Env::constant(1 << 24))
+            + (v1 * Env::constant(1 << 16))
+            + (v2 * Env::constant(1 << 8))
+            + v3
+    };
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 26, pos) }
+    };
+    env.range_check8(&opcode, 6);
+
+    let addr = {
+        // FIXME: Requires a range check (cannot use range_check16 here because 26 > 16)
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 26, 0, pos) }
+    };
+    let instruction_pointer_high_bits = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&next_instruction_pointer, 32, 28, pos) }
+    };
+    env.range_check8(&instruction_pointer_high_bits, 4);
+
+    // Check correctness of decomposition of instruction into parts
+    env.add_constraint(instruction - (opcode * Env::constant(1 << 26) + addr.clone()));
+
+    let target_addr =
+        (instruction_pointer_high_bits * Env::constant(1 << 28)) + (addr * Env::constant(1 << 2));
+    match instr {
+        JTypeInstruction::Jump => (),
+        JTypeInstruction::JumpAndLink => {
+            env.write_register(&Env::constant(31), instruction_pointer + Env::constant(8));
+        }
+    };
+    env.set_instruction_pointer(next_instruction_pointer);
+    env.set_next_instruction_pointer(target_addr);
+}
+
+pub fn interpret_itype<Env: InterpreterEnv>(env: &mut Env, instr: ITypeInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let next_instruction_pointer = env.get_next_instruction_pointer();
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v0 * Env::constant(1 << 24))
+            + (v1 * Env::constant(1 << 16))
+            + (v2 * Env::constant(1 << 8))
+            + v3
+    };
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 26, pos) }
+    };
+    env.range_check8(&opcode, 6);
+
+    let rs = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 26, 21, pos) }
+    };
+    env.range_check8(&rs, 5);
+
+    let rt = {
+        let pos = env.alloc_scratch();
+        unsafe {
env.bitmask(&instruction, 21, 16, pos) } + }; + env.range_check8(&rt, 5); + + let immediate = { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&instruction, 16, 0, pos) } + }; + env.lookup_16bits(&immediate); + + // Check correctness of decomposition of instruction into parts + env.add_constraint( + instruction + - (opcode * Env::constant(1 << 26) + + rs.clone() * Env::constant(1 << 21) + + rt.clone() * Env::constant(1 << 16) + + immediate.clone()), + ); + + match instr { + ITypeInstruction::BranchEq => { + let offset = env.sign_extend(&(immediate * Env::constant(1 << 2)), 18); + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let equals = env.equal(&rs, &rt); + let offset = (Env::constant(1) - equals.clone()) * Env::constant(4) + equals * offset; + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness( + &next_instruction_pointer, + &offset, + res_scratch, + overflow_scratch, + ) + }; + // FIXME: Requires a range check + res + }; + env.set_instruction_pointer(next_instruction_pointer); + env.set_next_instruction_pointer(addr); + } + ITypeInstruction::BranchNeq => { + let offset = env.sign_extend(&(immediate * Env::constant(1 << 2)), 18); + let rs = env.read_register(&rs); + let rt = env.read_register(&rt); + let equals = env.equal(&rs, &rt); + let offset = equals.clone() * Env::constant(4) + (Env::constant(1) - equals) * offset; + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness( + &next_instruction_pointer, + &offset, + res_scratch, + overflow_scratch, + ) + }; + // FIXME: Requires a range check + res + }; + env.set_instruction_pointer(next_instruction_pointer); + env.set_next_instruction_pointer(addr); + } + ITypeInstruction::BranchLeqZero => { + let offset = env.sign_extend(&(immediate * Env::constant(1 << 2)), 18); + let rs = env.read_register(&rs); + let less_than_or_equal_to = { + let greater_than_zero = { + // FIXME: Requires constraints + let pos = env.alloc_scratch(); + unsafe { env.test_less_than_signed(&Env::constant(0), &rs, pos) } + }; + Env::constant(1) - greater_than_zero + }; + let offset = (Env::constant(1) - less_than_or_equal_to.clone()) * Env::constant(4) + + less_than_or_equal_to * offset; + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness( + &next_instruction_pointer, + &offset, + res_scratch, + overflow_scratch, + ) + }; + // FIXME: Requires a range check + res + }; + env.set_instruction_pointer(next_instruction_pointer); + env.set_next_instruction_pointer(addr); + } + ITypeInstruction::BranchGtZero => { + let offset = env.sign_extend(&(immediate * Env::constant(1 << 2)), 18); + let rs = env.read_register(&rs); + let less_than = { + // FIXME: Requires constraints + let pos = env.alloc_scratch(); + unsafe { env.test_less_than_signed(&Env::constant(0), &rs, pos) } + }; + let offset = + (Env::constant(1) - less_than.clone()) * Env::constant(4) + less_than * offset; + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness( + &next_instruction_pointer, + &offset, + res_scratch, + overflow_scratch, + ) + }; + // FIXME: Requires a range check + res + }; + env.set_instruction_pointer(next_instruction_pointer); + 
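+            // Branches only redirect `next_instruction_pointer`: the instruction in
+            // the MIPS branch delay slot (at the old next IP, installed just above)
+            // still executes before the branch target is reached.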
env.set_next_instruction_pointer(addr); + } + ITypeInstruction::BranchLtZero => { + let offset = env.sign_extend(&(immediate * Env::constant(1 << 2)), 18); + let rs = env.read_register(&rs); + let less_than = { + // FIXME: Requires constraints + let pos = env.alloc_scratch(); + unsafe { env.test_less_than_signed(&rs, &Env::constant(0), pos) } + }; + let offset = + (Env::constant(1) - less_than.clone()) * Env::constant(4) + less_than * offset; + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness( + &next_instruction_pointer, + &offset, + res_scratch, + overflow_scratch, + ) + }; + // FIXME: Requires a range check + res + }; + env.set_instruction_pointer(next_instruction_pointer); + env.set_next_instruction_pointer(addr); + } + ITypeInstruction::BranchGeqZero => { + let offset = env.sign_extend(&(immediate * Env::constant(1 << 2)), 18); + let rs = env.read_register(&rs); + let less_than = { + // FIXME: Requires constraints + let pos = env.alloc_scratch(); + unsafe { env.test_less_than_signed(&rs, &Env::constant(0), pos) } + }; + let offset = + less_than.clone() * Env::constant(4) + (Env::constant(1) - less_than) * offset; + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness( + &next_instruction_pointer, + &offset, + res_scratch, + overflow_scratch, + ) + }; + // FIXME: Requires a range check + res + }; + env.set_instruction_pointer(next_instruction_pointer); + env.set_next_instruction_pointer(addr); + } + ITypeInstruction::AddImmediate => { + let register_rs = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let res = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness(&register_rs, &offset, res_scratch, overflow_scratch) + }; + // FIXME: Requires a range check + res + }; + env.write_register(&rt, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::AddImmediateUnsigned => { + let register_rs = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let res = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = unsafe { + env.add_witness(&register_rs, &offset, res_scratch, overflow_scratch) + }; + // FIXME: Requires a range check + res + }; + env.write_register(&rt, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::SetLessThanImmediate => { + let rs = env.read_register(&rs); + let immediate = env.sign_extend(&immediate, 16); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.test_less_than_signed(&rs, &immediate, pos) } + }; + env.write_register(&rt, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::SetLessThanImmediateUnsigned => { + let rs = env.read_register(&rs); + let immediate = env.sign_extend(&immediate, 16); + let res = { + // FIXME: Constrain + let pos = env.alloc_scratch(); + unsafe { env.test_less_than(&rs, &immediate, pos) } + }; + env.write_register(&rt, res); +
env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::AndImmediate => { + let rs = env.read_register(&rs); + let res = { + // FIXME: Constraint + let pos = env.alloc_scratch(); + unsafe { env.and_witness(&rs, &immediate, pos) } + }; + env.write_register(&rt, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::OrImmediate => { + let rs = env.read_register(&rs); + let res = { + // FIXME: Constraint + let pos = env.alloc_scratch(); + unsafe { env.or_witness(&rs, &immediate, pos) } + }; + env.write_register(&rt, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::XorImmediate => { + let rs = env.read_register(&rs); + let res = { + // FIXME: Constraint + let pos = env.alloc_scratch(); + unsafe { env.xor_witness(&rs, &immediate, pos) } + }; + env.write_register(&rt, res); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::LoadUpperImmediate => { + // lui $reg, [most significant 16 bits of immediate] + let immediate_value = immediate * Env::constant(1 << 16); + env.write_register(&rt, immediate_value); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Load8 => { + let base = env.read_register(&rs); + let dest = rt; + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + let v0 = env.read_memory(&addr); + let value = env.sign_extend(&v0, 8); + env.write_register(&dest, value); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Load16 => { + let base = env.read_register(&rs); + let dest = rt; + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + let v0 = env.read_memory(&addr); + let v1 = env.read_memory(&(addr.clone() + Env::constant(1))); + let value = (v0 * Env::constant(1 << 8)) + v1; + let value = env.sign_extend(&value, 16); + env.write_register(&dest, value); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Load32 => { + let base = env.read_register(&rs); + let dest = rt; + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + // We load 4 bytes, i.e. one word. 
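As a concrete illustration of the big-endian reassembly performed by the four reads below, a plain-Rust sketch (the helper `assemble_be_word` is hypothetical, not part of this diff):

fn assemble_be_word(bytes: [u8; 4]) -> u32 {
    // Fold MSB-first: bytes[0] ends up in bits 31..24, bytes[3] in bits 7..0.
    bytes.iter().fold(0u32, |acc, &b| (acc << 8) | b as u32)
}
// assemble_be_word([0x12, 0x34, 0x56, 0x78]) == 0x12345678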
+ let v0 = env.read_memory(&addr); + let v1 = env.read_memory(&(addr.clone() + Env::constant(1))); + let v2 = env.read_memory(&(addr.clone() + Env::constant(2))); + let v3 = env.read_memory(&(addr.clone() + Env::constant(3))); + let value = (v0 * Env::constant(1 << 24)) + + (v1 * Env::constant(1 << 16)) + + (v2 * Env::constant(1 << 8)) + + v3; + env.write_register(&dest, value); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Load8Unsigned => { + let base = env.read_register(&rs); + let dest = rt; + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + let v0 = env.read_memory(&addr); + let value = v0; + env.write_register(&dest, value); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Load16Unsigned => { + let base = env.read_register(&rs); + let dest = rt; + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + let v0 = env.read_memory(&addr); + let v1 = env.read_memory(&(addr.clone() + Env::constant(1))); + let value = v0 * Env::constant(1 << 8) + v1; + env.write_register(&dest, value); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::LoadWordLeft => { + let base = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + + let byte_subaddr = { + // FIXME: Requires a range check + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&addr, 2, 0, pos) } + }; + + let overwrite_3 = env.equal(&byte_subaddr, &Env::constant(0)); + let overwrite_2 = env.equal(&byte_subaddr, &Env::constant(1)) + overwrite_3.clone(); + let overwrite_1 = env.equal(&byte_subaddr, &Env::constant(2)) + overwrite_2.clone(); + let overwrite_0 = env.equal(&byte_subaddr, &Env::constant(3)) + overwrite_1.clone(); + + let m0 = env.read_memory(&addr); + let m1 = env.read_memory(&(addr.clone() + Env::constant(1))); + let m2 = env.read_memory(&(addr.clone() + Env::constant(2))); + let m3 = env.read_memory(&(addr.clone() + Env::constant(3))); + + let [r0, r1, r2, r3] = { + let initial_register_value = env.read_register(&rt); + [ + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 32, 24, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 24, 16, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 16, 8, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 8, 0, pos) } + }, + ] + }; + env.lookup_8bits(&r0); + env.lookup_8bits(&r1); + env.lookup_8bits(&r2); 
+ env.lookup_8bits(&r3); + + let value = { + let value = ((overwrite_0.clone() * m0 + (Env::constant(1) - overwrite_0) * r0) + * Env::constant(1 << 24)) + + ((overwrite_1.clone() * m1 + (Env::constant(1) - overwrite_1) * r1) + * Env::constant(1 << 16)) + + ((overwrite_2.clone() * m2 + (Env::constant(1) - overwrite_2) * r2) + * Env::constant(1 << 8)) + + (overwrite_3.clone() * m3 + (Env::constant(1) - overwrite_3) * r3); + let pos = env.alloc_scratch(); + env.copy(&value, pos) + }; + env.write_register(&rt, value); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::LoadWordRight => { + let base = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + + let byte_subaddr = { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&addr, 2, 0, pos) } + }; + env.range_check2(&byte_subaddr); + + let overwrite_0 = env.equal(&byte_subaddr, &Env::constant(3)); + let overwrite_1 = env.equal(&byte_subaddr, &Env::constant(2)) + overwrite_0.clone(); + let overwrite_2 = env.equal(&byte_subaddr, &Env::constant(1)) + overwrite_1.clone(); + let overwrite_3 = env.equal(&byte_subaddr, &Env::constant(0)) + overwrite_2.clone(); + + // The `-3` here feels odd, but simulates the `<< 24` in cannon, and matches the + // behavior defined in the spec. + // See e.g. 'MIPS IV Instruction Set' Rev 3.2, Table A-31 for reference. + let m0 = env.read_memory(&(addr.clone() - Env::constant(3))); + let m1 = env.read_memory(&(addr.clone() - Env::constant(2))); + let m2 = env.read_memory(&(addr.clone() - Env::constant(1))); + let m3 = env.read_memory(&addr); + + let [r0, r1, r2, r3] = { + let initial_register_value = env.read_register(&rt); + [ + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 32, 24, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 24, 16, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 16, 8, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 8, 0, pos) } + }, + ] + }; + env.lookup_8bits(&r0); + env.lookup_8bits(&r1); + env.lookup_8bits(&r2); + env.lookup_8bits(&r3); + + let value = { + let value = ((overwrite_0.clone() * m0 + (Env::constant(1) - overwrite_0) * r0) + * Env::constant(1 << 24)) + + ((overwrite_1.clone() * m1 + (Env::constant(1) - overwrite_1) * r1) + * Env::constant(1 << 16)) + + ((overwrite_2.clone() * m2 + (Env::constant(1) - overwrite_2) * r2) + * Env::constant(1 << 8)) + + (overwrite_3.clone() * m3 + (Env::constant(1) - overwrite_3) * r3); + let pos = env.alloc_scratch(); + env.copy(&value, pos) + }; + env.write_register(&rt, value); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Store8 => { + let base = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range 
check + res + }; + let value = env.read_register(&rt); + let v0 = { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 8, 0, pos) } + }; + env.lookup_8bits(&v0); + + env.write_memory(&addr, v0); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Store16 => { + let base = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + let value = env.read_register(&rt); + let [v0, v1] = { + [ + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 16, 8, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 8, 0, pos) } + }, + ] + }; + env.lookup_8bits(&v0); + env.lookup_8bits(&v1); + + env.write_memory(&addr, v0); + env.write_memory(&(addr.clone() + Env::constant(1)), v1); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Store32 => { + let base = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + let value = env.read_register(&rt); + let [v0, v1, v2, v3] = { + [ + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 32, 24, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 24, 16, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 16, 8, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 8, 0, pos) } + }, + ] + }; + env.lookup_8bits(&v0); + env.lookup_8bits(&v1); + env.lookup_8bits(&v2); + env.lookup_8bits(&v3); + + // Checking that v is the correct decomposition. 
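The block below enforces, on the field side, that the byte limbs recompose to `value`. An integer-level sketch of the same check (hypothetical helper, assuming plain u32 arithmetic instead of field elements):

fn is_correct_decomposition(value: u32, v: [u32; 4]) -> bool {
    // The residue value - (v0*2^24 + v1*2^16 + v2*2^8 + v3) must be zero,
    // which is exactly what the is_zero call below enforces.
    value == (v[0] << 24) + (v[1] << 16) + (v[2] << 8) + v[3]
}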
+ { + let res = value + - v0.clone() * Env::constant(1 << 24) + - v1.clone() * Env::constant(1 << 16) + - v2.clone() * Env::constant(1 << 8) + - v3.clone(); + env.is_zero(&res) + }; + env.write_memory(&addr, v0); + env.write_memory(&(addr.clone() + Env::constant(1)), v1); + env.write_memory(&(addr.clone() + Env::constant(2)), v2); + env.write_memory(&(addr.clone() + Env::constant(3)), v3); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::Store32Conditional => { + let base = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + let value = env.read_register(&rt); + let [v0, v1, v2, v3] = { + [ + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 32, 24, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 24, 16, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 16, 8, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&value, 8, 0, pos) } + }, + ] + }; + env.lookup_8bits(&v0); + env.lookup_8bits(&v1); + env.lookup_8bits(&v2); + env.lookup_8bits(&v3); + + env.write_memory(&addr, v0); + env.write_memory(&(addr.clone() + Env::constant(1)), v1); + env.write_memory(&(addr.clone() + Env::constant(2)), v2); + env.write_memory(&(addr.clone() + Env::constant(3)), v3); + // Write status flag. + env.write_register(&rt, Env::constant(1)); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::StoreWordLeft => { + let base = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + + let byte_subaddr = { + // FIXME: Requires a range check + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&addr, 2, 0, pos) } + }; + env.range_check2(&byte_subaddr); + + let overwrite_3 = env.equal(&byte_subaddr, &Env::constant(0)); + let overwrite_2 = env.equal(&byte_subaddr, &Env::constant(1)) + overwrite_3.clone(); + let overwrite_1 = env.equal(&byte_subaddr, &Env::constant(2)) + overwrite_2.clone(); + let overwrite_0 = env.equal(&byte_subaddr, &Env::constant(3)) + overwrite_1.clone(); + + let m0 = env.read_memory(&addr); + let m1 = env.read_memory(&(addr.clone() + Env::constant(1))); + let m2 = env.read_memory(&(addr.clone() + Env::constant(2))); + let m3 = env.read_memory(&(addr.clone() + Env::constant(3))); + + let [r0, r1, r2, r3] = { + let initial_register_value = env.read_register(&rt); + [ + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 32, 24, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 24, 16, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 16, 8, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 8, 0, pos) } + }, + ] + }; + env.lookup_8bits(&r0); + 
env.lookup_8bits(&r1); + env.lookup_8bits(&r2); + env.lookup_8bits(&r3); + + let v0 = { + let pos = env.alloc_scratch(); + env.copy( + &(overwrite_0.clone() * r0 + (Env::constant(1) - overwrite_0) * m0), + pos, + ) + }; + let v1 = { + let pos = env.alloc_scratch(); + env.copy( + &(overwrite_1.clone() * r1 + (Env::constant(1) - overwrite_1) * m1), + pos, + ) + }; + let v2 = { + let pos = env.alloc_scratch(); + env.copy( + &(overwrite_2.clone() * r2 + (Env::constant(1) - overwrite_2) * m2), + pos, + ) + }; + let v3 = { + let pos = env.alloc_scratch(); + env.copy( + &(overwrite_3.clone() * r3 + (Env::constant(1) - overwrite_3) * m3), + pos, + ) + }; + + env.write_memory(&addr, v0); + env.write_memory(&(addr.clone() + Env::constant(1)), v1); + env.write_memory(&(addr.clone() + Env::constant(2)), v2); + env.write_memory(&(addr.clone() + Env::constant(3)), v3); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + ITypeInstruction::StoreWordRight => { + let base = env.read_register(&rs); + let offset = env.sign_extend(&immediate, 16); + let addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&base, &offset, res_scratch, overflow_scratch) }; + // FIXME: Requires a range check + res + }; + + let byte_subaddr = { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&addr, 2, 0, pos) } + }; + env.range_check2(&byte_subaddr); + + let overwrite_0 = env.equal(&byte_subaddr, &Env::constant(3)); + let overwrite_1 = env.equal(&byte_subaddr, &Env::constant(2)) + overwrite_0.clone(); + let overwrite_2 = env.equal(&byte_subaddr, &Env::constant(1)) + overwrite_1.clone(); + let overwrite_3 = env.equal(&byte_subaddr, &Env::constant(0)) + overwrite_2.clone(); + + // The `-3` here feels odd, but simulates the `<< 24` in cannon, and matches the + // behavior defined in the spec. + // See e.g. 'MIPS IV Instruction Set' Rev 3.2, Table A-31 for reference. 
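To make the cumulative overwrite flags above concrete, a plain-integer sketch (the helper is hypothetical; byte_subaddr is addr & 3 as computed above):

fn swr_overwrite_flags(byte_subaddr: u32) -> [bool; 4] {
    // The flags accumulate: overwrite_1 = (sub == 2) + overwrite_0, etc.,
    // so register byte i is stored exactly when byte_subaddr >= 3 - i.
    [
        byte_subaddr == 3,
        byte_subaddr >= 2,
        byte_subaddr >= 1,
        true, // overwrite_3: byte_subaddr >= 0 always holds
    ]
}
// e.g. byte_subaddr == 1 stores only r2 and r3, at addr - 1 and addr.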
+ let m0 = env.read_memory(&(addr.clone() - Env::constant(3))); + let m1 = env.read_memory(&(addr.clone() - Env::constant(2))); + let m2 = env.read_memory(&(addr.clone() - Env::constant(1))); + let m3 = env.read_memory(&addr); + + let [r0, r1, r2, r3] = { + let initial_register_value = env.read_register(&rt); + [ + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 32, 24, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 24, 16, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 16, 8, pos) } + }, + { + let pos = env.alloc_scratch(); + unsafe { env.bitmask(&initial_register_value, 8, 0, pos) } + }, + ] + }; + env.lookup_8bits(&r0); + env.lookup_8bits(&r1); + env.lookup_8bits(&r2); + env.lookup_8bits(&r3); + + let v0 = { + let pos = env.alloc_scratch(); + env.copy( + &(overwrite_0.clone() * r0 + (Env::constant(1) - overwrite_0) * m0), + pos, + ) + }; + let v1 = { + let pos = env.alloc_scratch(); + env.copy( + &(overwrite_1.clone() * r1 + (Env::constant(1) - overwrite_1) * m1), + pos, + ) + }; + let v2 = { + let pos = env.alloc_scratch(); + env.copy( + &(overwrite_2.clone() * r2 + (Env::constant(1) - overwrite_2) * m2), + pos, + ) + }; + let v3 = { + let pos = env.alloc_scratch(); + env.copy( + &(overwrite_3.clone() * r3 + (Env::constant(1) - overwrite_3) * m3), + pos, + ) + }; + + env.write_memory(&(addr.clone() - Env::constant(3)), v0); + env.write_memory(&(addr.clone() - Env::constant(2)), v1); + env.write_memory(&(addr.clone() - Env::constant(1)), v2); + env.write_memory(&addr.clone(), v3); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + } +} + +pub mod debugging { + use serde::{Deserialize, Serialize}; + #[derive(Debug, Clone, Copy, Eq, PartialEq, Default, Serialize, Deserialize)] + pub struct InstructionParts { + pub op_code: u32, + pub rs: u32, + pub rt: u32, + pub rd: u32, + pub shamt: u32, + pub funct: u32, + } + + impl InstructionParts { + pub fn decode(instruction: u32) -> Self { + let op_code = (instruction >> 26) & ((1 << (32 - 26)) - 1); + let rs = (instruction >> 21) & ((1 << (26 - 21)) - 1); + let rt = (instruction >> 16) & ((1 << (21 - 16)) - 1); + let rd = (instruction >> 11) & ((1 << (16 - 11)) - 1); + let shamt = (instruction >> 6) & ((1 << (11 - 6)) - 1); + let funct = instruction & ((1 << 6) - 1); + InstructionParts { + op_code, + rs, + rt, + rd, + shamt, + funct, + } + } + + pub fn encode(self) -> u32 { + (self.op_code << 26) + | (self.rs << 21) + | (self.rt << 16) + | (self.rd << 11) + | (self.shamt << 6) + | self.funct + } + } +} diff --git a/o1vm/src/interpreters/mips/mod.rs b/o1vm/src/interpreters/mips/mod.rs new file mode 100644 index 0000000000..fbfc6ecdfd --- /dev/null +++ b/o1vm/src/interpreters/mips/mod.rs @@ -0,0 +1,29 @@ +//! This module implements a zero-knowledge virtual machine (zkVM) for the MIPS +//! architecture. +//! A zkVM is used by a prover to convince a verifier that the execution trace +//! (also called the `witness`) of a program execution is correct. In the case +//! of this zkVM, we will represent the execution trace by using a set of +//! columns whose values will represent the evaluations of polynomials over a +//! certain pre-defined domain. The correct execution will be proven using a +//! polynomial commitment protocol. The polynomials are described in the +//! 
structure [crate::interpreters::mips::column::ColumnAlias]. These +//! polynomials will be committed and evaluated at certain points following the +//! polynomial protocol, +//! and this will form the proof of the correct execution that the prover will +//! build and send to the verifier. The corresponding structure is +//! Proof. The prover will start by computing the +//! execution trace using the interpreter implemented in the module +//! [crate::interpreters::mips::interpreter], and the evaluations will be kept +//! in the structure ProofInputs. + +pub mod column; +pub mod constraints; +pub mod interpreter; +pub mod registers; +#[cfg(test)] +pub mod tests; +#[cfg(test)] +pub mod tests_helpers; +pub mod witness; + +pub use interpreter::{ITypeInstruction, Instruction, JTypeInstruction, RTypeInstruction}; diff --git a/o1vm/src/interpreters/mips/registers.rs b/o1vm/src/interpreters/mips/registers.rs new file mode 100644 index 0000000000..0dcfe9516b --- /dev/null +++ b/o1vm/src/interpreters/mips/registers.rs @@ -0,0 +1,92 @@ +use serde::{Deserialize, Serialize}; +use std::ops::{Index, IndexMut}; + +pub const REGISTER_HI: usize = 32; +pub const REGISTER_LO: usize = 33; +pub const REGISTER_CURRENT_IP: usize = 34; +pub const REGISTER_NEXT_IP: usize = 35; +pub const REGISTER_HEAP_POINTER: usize = 36; +pub const REGISTER_PREIMAGE_KEY_START: usize = 37; +pub const REGISTER_PREIMAGE_KEY_END: usize = REGISTER_PREIMAGE_KEY_START + 8 /* 37 + 8 = 45 */; +pub const REGISTER_PREIMAGE_OFFSET: usize = 45; + +pub const NUM_REGISTERS: usize = 46; + +/// This represents the internal state of the virtual machine. +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct Registers<T> { + pub general_purpose: [T; 32], + pub hi: T, + pub lo: T, + pub current_instruction_pointer: T, + pub next_instruction_pointer: T, + pub heap_pointer: T, + pub preimage_key: [T; 8], + pub preimage_offset: T, +} + +impl<T> Registers<T> { + pub fn iter(&self) -> impl Iterator<Item = &T> { + self.general_purpose + .iter() + .chain([ + &self.hi, + &self.lo, + &self.current_instruction_pointer, + &self.next_instruction_pointer, + &self.heap_pointer, + ]) + .chain(self.preimage_key.iter()) + .chain([&self.preimage_offset]) + } +} + +impl<T: Clone> Index<usize> for Registers<T> { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + if index < 32 { + &self.general_purpose[index] + } else if index == REGISTER_HI { + &self.hi + } else if index == REGISTER_LO { + &self.lo + } else if index == REGISTER_CURRENT_IP { + &self.current_instruction_pointer + } else if index == REGISTER_NEXT_IP { + &self.next_instruction_pointer + } else if index == REGISTER_HEAP_POINTER { + &self.heap_pointer + } else if (REGISTER_PREIMAGE_KEY_START..REGISTER_PREIMAGE_KEY_END).contains(&index) { + &self.preimage_key[index - REGISTER_PREIMAGE_KEY_START] + } else if index == REGISTER_PREIMAGE_OFFSET { + &self.preimage_offset + } else { + panic!("Index out of bounds"); + } + } +} + +impl<T: Clone> IndexMut<usize> for Registers<T> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + if index < 32 { + &mut self.general_purpose[index] + } else if index == REGISTER_HI { + &mut self.hi + } else if index == REGISTER_LO { + &mut self.lo + } else if index == REGISTER_CURRENT_IP { + &mut self.current_instruction_pointer + } else if index == REGISTER_NEXT_IP { + &mut self.next_instruction_pointer + } else if index == REGISTER_HEAP_POINTER { + &mut self.heap_pointer + } else if (REGISTER_PREIMAGE_KEY_START..REGISTER_PREIMAGE_KEY_END).contains(&index) { + &mut self.preimage_key[index - REGISTER_PREIMAGE_KEY_START] + } else if index == REGISTER_PREIMAGE_OFFSET { + &mut self.preimage_offset + } else { + panic!("Index out of bounds"); + } + } +} diff --git a/o1vm/src/interpreters/mips/tests.rs b/o1vm/src/interpreters/mips/tests.rs new file mode 100644 index 0000000000..1b212077d1 --- /dev/null +++ b/o1vm/src/interpreters/mips/tests.rs @@ -0,0 +1,351 @@ +// Here live the unit tests for the MIPS instructions +use crate::{ + interpreters::mips::{ + column::{N_MIPS_REL_COLS, N_MIPS_SEL_COLS}, + constraints, + interpreter::debugging::InstructionParts, + tests_helpers::*, + ITypeInstruction, JTypeInstruction, RTypeInstruction, + }, + preimage_oracle::PreImageOracleT, +}; +use kimchi::o1_utils; +use mina_curves::pasta::Fp; +use rand::Rng; +use strum::{EnumCount, IntoEnumIterator}; + +use super::Instruction; + +pub(crate) fn sign_extend(x: u32, bitlength: u32) -> u32 { + let high_bit = (x >> (bitlength - 1)) & 1; + high_bit * (((1 << (32 - bitlength)) - 1) << bitlength) + x +} + +pub(crate) fn bitmask(x: u32, highest_bit: u32, lowest_bit: u32) -> u32 { + let res = (x >> lowest_bit) as u64 & (2u64.pow(highest_bit - lowest_bit) - 1); + res as u32 +} + +#[test] +fn test_sext() { + assert_eq!(sign_extend(0b1001_0110, 16), 0b1001_0110); + assert_eq!( + sign_extend(0b1001_0110_0000_0000, 16), + 0b1111_1111_1111_1111_1001_0110_0000_0000 + ); +} + +#[test] +fn test_bitmask() { + assert_eq!(bitmask(0xaf, 8, 0), 0xaf); + assert_eq!(bitmask(0x3671e4cb, 32, 0), 0x3671e4cb); +} + +#[test] +fn test_on_disk_preimage_can_read_file() { + let mut rng = o1_utils::tests::make_test_rng(None); + let mut dummy_env = dummy_env(&mut rng); + let preimage_key_u8: [u8; 32] = [ + 0x02, 0x21, 0x07, 0x30, 0x78, 0x79, 0x25, 0x85, 0x77, 0x23, 0x0c, 0x5a, 0xa2, 0xf9, 0x05, + 0x67, 0xbd, 0xa4, 0x08, 0x77, 0xa7, 0xe8, 0x5d, 0xce, 0xb6, 0xff, 0x1f, 0x37, 0x48, 0x0f, + 0xef, 0x3d, + ]; + let preimage = dummy_env.preimage_oracle.get_preimage(preimage_key_u8); + let bytes = preimage.get(); + // Number of bytes inside the corresponding file (preimage) + assert_eq!(bytes.len(), 358); +} + +#[test] +fn test_all_instructions_have_a_usize_representation_smaller_than_the_number_of_selectors() { + Instruction::iter().for_each(|i| { + assert!( + usize::from(i) - N_MIPS_REL_COLS < N_MIPS_SEL_COLS, + "Instruction {:?} has a usize representation larger than the number of selectors ({})", + i, + usize::from(i) - N_MIPS_REL_COLS + ); + }); +} +mod rtype { + + use super::*; + use crate::interpreters::mips::{interpreter::interpret_rtype, RTypeInstruction}; + + #[test] + fn test_unit_syscall_read_preimage() { + let mut rng = o1_utils::tests::make_test_rng(None); + let mut dummy_env = dummy_env(&mut rng); + // Instruction: syscall (Read 5) + // Set preimage key + let preimage_key = [ + 0x02, 0x21, 0x07, 0x30, 0x78, 0x79, 0x25, 0x85, 0x77, 0x23, 0x0c, 0x5a, 0xa2, 0xf9, + 0x05, 0x67, 0xbd, 0xa4, 0x08, 0x77, 0xa7, 0xe8, 0x5d, 0xce, 0xb6, 0xff, 0x1f, 0x37, + 0x48, 0x0f, 0xef, 0x3d, + ]; + let chunks = preimage_key + .chunks(4) + .map(|chunk| { + ((chunk[0] as u32) << 24) + + ((chunk[1] as u32) << 16) + + ((chunk[2] as u32) << 8) + + (chunk[3] as u32) + }) + .collect::<Vec<u32>>(); + dummy_env.registers.preimage_key = std::array::from_fn(|i| chunks[i]); + + // The whole preimage + let preimage = dummy_env.preimage_oracle.get_preimage(preimage_key).get(); + + // Total number of bytes that need to be processed (includes length) + let total_length = 8 + preimage.len() as u32; + + // At first, offset is 0 + + // Set a random address for register 5 that
might not be aligned + let addr = rng.gen_range(100..200); + dummy_env.registers[5] = addr; + + // Read oracle until the totality of the preimage is reached + // NOTE: the termination condition is not + // `while dummy_env.preimage_bytes_read < preimage.len()` + // because the interpreter sets it back to 0 when the preimage + // is read fully and the Keccak process is triggered (meaning + // that condition would generate an infinite loop instead) + while dummy_env.registers.preimage_offset < total_length { + dummy_env.reset_scratch_state(); + dummy_env.reset_scratch_state_inverse(); + + // Set maximum number of bytes to read in this call + dummy_env.registers[6] = rng.gen_range(1..=4); + + interpret_rtype(&mut dummy_env, RTypeInstruction::SyscallReadPreimage); + + // Update the address to store the next bytes with the offset + dummy_env.registers[5] = addr + dummy_env.registers.preimage_offset; + } + + // Number of bytes inside the corresponding file (preimage) + // This should be the length read from the oracle in the first 8 bytes + assert_eq!(dummy_env.registers.preimage_offset, total_length); + + // Check the content of memory addresses corresponds to the oracle + + // The first 8 bytes are the length of the preimage + let length_byte = u64::to_be_bytes(preimage.len() as u64); + for (i, b) in length_byte.iter().enumerate() { + assert_eq!( + dummy_env.memory[0].1[i + addr as usize], + *b, + "{}-th length byte does not match", + i + ); + } + // Check that the preimage bytes are stored afterwards in the memory + for (i, b) in preimage.iter().enumerate() { + assert_eq!( + dummy_env.memory[0].1[i + addr as usize + 8], + *b, + "{}-th preimage byte does not match", + i + ); + } + } + + #[test] + fn test_unit_sub_instruction() { + let mut rng = o1_utils::tests::make_test_rng(None); + // We only care about instruction parts and instruction pointer + let mut dummy_env = dummy_env(&mut rng); + // FIXME: at the moment, we do not support writing and reading into the + // same register + // reg_dst <- reg_src - reg_tar + let reg_src = 1; + let reg_dst = 2; + let reg_tar = 3; + // Instruction: 0b00000000001000100001100000100010 sub $at, $at, $at + write_instruction( + &mut dummy_env, + InstructionParts { + op_code: 0b000000, + rs: reg_src as u32, // source register + rt: reg_tar as u32, // target register + rd: reg_dst as u32, // destination register + shamt: 0b00000, + funct: 0b100010, + }, + ); + let (exp_res, _underflow) = + dummy_env.registers[reg_src].overflowing_sub(dummy_env.registers[reg_tar]); + interpret_rtype(&mut dummy_env, RTypeInstruction::Sub); + assert_eq!(dummy_env.registers.general_purpose[reg_dst], exp_res); + } +} + +mod itype { + use super::*; + use crate::interpreters::mips::{interpreter::interpret_itype, ITypeInstruction}; + + #[test] + fn test_unit_addi_instruction() { + let mut rng = o1_utils::tests::make_test_rng(None); + // We only care about instruction parts and instruction pointer + let mut dummy_env = dummy_env(&mut rng); + // Instruction: 0b10001111101001000000000000000000 addi a1,sp,4 + write_instruction( + &mut dummy_env, + InstructionParts { + op_code: 0b000010, + rs: 0b11101, + rt: 0b00101, + rd: 0b00000, + shamt: 0b00000, + funct: 0b000100, + }, + ); + interpret_itype(&mut dummy_env, ITypeInstruction::AddImmediate); + assert_eq!( + dummy_env.registers.general_purpose[5], + dummy_env.registers.general_purpose[29] + 4 + ); + } + + #[test] + fn test_unit_addiu_instruction() { + let mut rng = o1_utils::tests::make_test_rng(None); + // We only care about instruction 
parts and instruction pointer + let mut dummy_env = dummy_env(&mut rng); + // FIXME: at the moment, we do not support writing and reading into the + // same register + let reg_src = 1; + let reg_dest = 2; + // Instruction: 0b00100100001000010110110011101000 + // addiu $at, $at, 27880 + write_instruction( + &mut dummy_env, + InstructionParts { + op_code: 0b001001, + rs: reg_src, // source register + rt: reg_dest, // destination register + // The rest is the immediate value + rd: 0b01101, + shamt: 0b10011, + funct: 0b101000, + }, + ); + let exp_res = dummy_env.registers[reg_src as usize] + 27880; + interpret_itype(&mut dummy_env, ITypeInstruction::AddImmediateUnsigned); + assert_eq!( + dummy_env.registers.general_purpose[reg_dest as usize], + exp_res + ); + } + + #[test] + fn test_unit_lui_instruction() { + let mut rng = o1_utils::tests::make_test_rng(None); + // We only care about instruction parts and instruction pointer + let mut dummy_env = dummy_env(&mut rng); + // Instruction: 0b00111100000000010000000000001010 + // lui at, 0xa + write_instruction( + &mut dummy_env, + InstructionParts { + op_code: 0b000010, + rs: 0b00000, + rt: 0b00001, + rd: 0b00000, + shamt: 0b00000, + funct: 0b001010, + }, + ); + interpret_itype(&mut dummy_env, ITypeInstruction::LoadUpperImmediate); + assert_eq!(dummy_env.registers.general_purpose[1], 0xa0000); + } + + #[test] + fn test_unit_load16_instruction() { + let mut rng = o1_utils::tests::make_test_rng(None); + // lh instruction + let mut dummy_env = dummy_env(&mut rng); + // Instruction: 0b100001 11101 00100 00000 00000 000000 lh $a0, 0(29) a0 = 4 + // Random address in SP Address has only one index + + let addr: u32 = rng.gen_range(0u32..100u32); + let aligned_addr: u32 = (addr / 4) * 4; + dummy_env.registers[29] = aligned_addr; + let mem = &dummy_env.memory[0]; + let mem = &mem.1; + let v0 = mem[aligned_addr as usize]; + let v1 = mem[(aligned_addr + 1) as usize]; + let v = ((v0 as u32) << 8) + (v1 as u32); + let high_bit = (v >> 15) & 1; + let exp_v = high_bit * (((1 << 16) - 1) << 16) + v; + write_instruction( + &mut dummy_env, + InstructionParts { + op_code: 0b100001, + rs: 0b11101, + rt: 0b00100, + rd: 0b00000, + shamt: 0b00000, + funct: 0b000000, + }, + ); + interpret_itype(&mut dummy_env, ITypeInstruction::Load16); + assert_eq!(dummy_env.registers.general_purpose[4], exp_v); + } + + #[test] + fn test_unit_load32_instruction() { + let mut rng = o1_utils::tests::make_test_rng(None); + // lw instruction + let mut dummy_env = dummy_env(&mut rng); + // Instruction: 0b10001111101001000000000000000000 lw $a0, 0(29) a0 = 4 + // Random address in SP Address has only one index + + let addr: u32 = rng.gen_range(0u32..100u32); + let aligned_addr: u32 = (addr / 4) * 4; + dummy_env.registers[29] = aligned_addr; + let mem = &dummy_env.memory[0]; + let mem = &mem.1; + let v0 = mem[aligned_addr as usize]; + let v1 = mem[(aligned_addr + 1) as usize]; + let v2 = mem[(aligned_addr + 2) as usize]; + let v3 = mem[(aligned_addr + 3) as usize]; + let exp_v = ((v0 as u32) << 24) + ((v1 as u32) << 16) + ((v2 as u32) << 8) + (v3 as u32); + write_instruction( + &mut dummy_env, + InstructionParts { + op_code: 0b100011, + rs: 0b11101, + rt: 0b00100, + rd: 0b00000, + shamt: 0b00000, + funct: 0b000000, + }, + ); + interpret_itype(&mut dummy_env, ITypeInstruction::Load32); + assert_eq!(dummy_env.registers.general_purpose[4], exp_v); + } +} + +#[test] +// Sanity check that we have as many selector as we have instructions +fn test_regression_selectors_for_instructions() { + let 
mips_con_env = constraints::Env::<Fp>::default(); + let constraints = mips_con_env.get_selector_constraints(); + assert_eq!( + // We subtract 1 as we have one boolean check per sel + // and 1 constraint to check that one and only one + // sel is activated + constraints.len() - 1, + // We could use N_MIPS_SEL_COLS, but sanity check in case this value is + // changed. + RTypeInstruction::COUNT + JTypeInstruction::COUNT + ITypeInstruction::COUNT + ); + // All instructions are degree 1 or 2. + constraints + .iter() + .for_each(|c| assert!(c.degree(1, 0) == 2 || c.degree(1, 0) == 1)); +} diff --git a/o1vm/src/interpreters/mips/tests_helpers.rs b/o1vm/src/interpreters/mips/tests_helpers.rs new file mode 100644 index 0000000000..1682196588 --- /dev/null +++ b/o1vm/src/interpreters/mips/tests_helpers.rs @@ -0,0 +1,128 @@ +use crate::{ + cannon::{Hint, Preimage, PAGE_ADDRESS_MASK, PAGE_ADDRESS_SIZE, PAGE_SIZE}, + interpreters::mips::{ + interpreter::{debugging::InstructionParts, InterpreterEnv}, + registers::Registers, + witness::{Env as WEnv, SyscallEnv}, + }, + preimage_oracle::PreImageOracleT, +}; +use rand::{CryptoRng, Rng, RngCore}; +use std::{fs, path::PathBuf}; + +// FIXME: we should parametrize the tests with different fields. +use ark_bn254::Fr as Fp; + +use super::column::{SCRATCH_SIZE, SCRATCH_SIZE_INVERSE}; + +const PAGE_INDEX_EXECUTABLE_MEMORY: u32 = 1; + +pub(crate) struct OnDiskPreImageOracle; + +impl PreImageOracleT for OnDiskPreImageOracle { + fn get_preimage(&mut self, key: [u8; 32]) -> Preimage { + let key_s = hex::encode(key); + let full_path = format!("resources/tests/0x{key_s}.txt"); + let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + d.push(full_path); + let contents = fs::read_to_string(d).expect("Should have been able to read the file"); + // Decode String (ASCII) as Vec<u8> (hexadecimal bytes) + let bytes = hex::decode(contents) + .expect("Should have been able to decode the file as hexadecimal bytes"); + Preimage::create(bytes) + } + + fn hint(&mut self, _hint: Hint) {} +} + +pub(crate) fn dummy_env<RNG>(rng: &mut RNG) -> WEnv<Fp, OnDiskPreImageOracle> +where + RNG: RngCore + CryptoRng, +{ + let dummy_preimage_oracle = OnDiskPreImageOracle; + let mut env = WEnv { + // Set it to 2 to run 1 instruction that accesses registers + instruction_counter: 2, + // Only 8kb of memory (two PAGE_ADDRESS_SIZE) + memory: vec![ + // Read/write memory + // Initializing with random data + ( + 0, + (0..PAGE_SIZE).map(|_| rng.gen_range(0u8..=255)).collect(), + ), + // Executable memory. Allocating 4 * 4kB + (PAGE_INDEX_EXECUTABLE_MEMORY, vec![0; PAGE_SIZE as usize]), + ( + PAGE_INDEX_EXECUTABLE_MEMORY + 1, + vec![0; PAGE_SIZE as usize], + ), + ( + PAGE_INDEX_EXECUTABLE_MEMORY + 2, + vec![0; PAGE_SIZE as usize], + ), + ( + PAGE_INDEX_EXECUTABLE_MEMORY + 3, + vec![0; PAGE_SIZE as usize], + ), + ], + last_memory_accesses: [0; 3], + memory_write_index: vec![ + // Read/write memory + (0, vec![0; PAGE_SIZE as usize]), + // Executable memory. Allocating 4 * 4kB + (PAGE_INDEX_EXECUTABLE_MEMORY, vec![0; PAGE_SIZE as usize]), + ( + PAGE_INDEX_EXECUTABLE_MEMORY + 1, + vec![0; PAGE_SIZE as usize], + ), + ( + PAGE_INDEX_EXECUTABLE_MEMORY + 2, + vec![0; PAGE_SIZE as usize], + ), + ( + PAGE_INDEX_EXECUTABLE_MEMORY + 3, + vec![0; PAGE_SIZE as usize], + ), + ], + last_memory_write_index_accesses: [0; 3], + registers: Registers::default(), + registers_write_index: Registers::default(), + scratch_state_idx: 0, + scratch_state_idx_inverse: 0, + scratch_state: [Fp::from(0); SCRATCH_SIZE], + scratch_state_inverse: [Fp::from(0); SCRATCH_SIZE_INVERSE], + selector: crate::interpreters::mips::column::N_MIPS_SEL_COLS, + halt: false, + // Keccak related + syscall_env: SyscallEnv::default(), + preimage: None, + preimage_oracle: dummy_preimage_oracle, + preimage_bytes_read: 0, + preimage_key: None, + keccak_env: None, + hash_counter: 0, + }; + // Initialize general purpose registers with random values + for reg in env.registers.general_purpose.iter_mut() { + *reg = rng.gen_range(0u32..=u32::MAX); + } + env.registers.current_instruction_pointer = PAGE_INDEX_EXECUTABLE_MEMORY * PAGE_SIZE; + env.registers.next_instruction_pointer = env.registers.current_instruction_pointer + 4; + env +} + +// Write the instruction to the location of the instruction pointer. +pub(crate) fn write_instruction( + env: &mut WEnv<Fp, OnDiskPreImageOracle>, + instruction_parts: InstructionParts, +) { + let instr = instruction_parts.encode(); + let instr_pointer: u32 = env.get_instruction_pointer().try_into().unwrap(); + let page = instr_pointer >> PAGE_ADDRESS_SIZE; + let page_address = (instr_pointer & PAGE_ADDRESS_MASK) as usize; + env.memory[page as usize].1[page_address] = ((instr >> 24) & 0xFF) as u8; + env.memory[page as usize].1[page_address + 1] = ((instr >> 16) & 0xFF) as u8; + env.memory[page as usize].1[page_address + 2] = ((instr >> 8) & 0xFF) as u8; + env.memory[page as usize].1[page_address + 3] = (instr & 0xFF) as u8; +} diff --git a/o1vm/src/interpreters/mips/witness.rs b/o1vm/src/interpreters/mips/witness.rs new file mode 100644 index 0000000000..991b946c8d --- /dev/null +++ b/o1vm/src/interpreters/mips/witness.rs @@ -0,0 +1,1298 @@ +use super::column::{N_MIPS_SEL_COLS, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE}; +use crate::{ + cannon::{ + Hint, Meta, Page, Start, State, StepFrequency, VmConfiguration, PAGE_ADDRESS_MASK, + PAGE_ADDRESS_SIZE, PAGE_SIZE, + }, + interpreters::{ + keccak::environment::KeccakEnv, + mips::{ + column::{ + ColumnAlias as Column, MIPS_BYTE_COUNTER_OFF, MIPS_CHUNK_BYTES_LEN, + MIPS_END_OF_PREIMAGE_OFF, MIPS_HASH_COUNTER_OFF, MIPS_HAS_N_BYTES_OFF, + MIPS_LENGTH_BYTES_OFF, MIPS_NUM_BYTES_READ_OFF, MIPS_PREIMAGE_BYTES_OFF, + MIPS_PREIMAGE_CHUNK_OFF, MIPS_PREIMAGE_KEY, + }, + interpreter::{ + self, ITypeInstruction, Instruction, InterpreterEnv, JTypeInstruction, + RTypeInstruction, + }, + registers::Registers, + }, + }, + lookups::Lookup, + preimage_oracle::PreImageOracleT, + utils::memory_size, +}; +use ark_ff::Field; +use core::panic; +use kimchi::o1_utils::Two; +use log::{debug, info}; +use std::{ + array, + fs::File, + io::{BufWriter, Write}, +}; + +// TODO: do we want to be more restrictive and refer to the number of accesses +// to the SAME register/memory address?
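Throughout this file, as in `write_instruction` above, a 32-bit address is split into a page index and an in-page offset. A sketch under the assumption of 4 kB pages, i.e. PAGE_ADDRESS_SIZE = 12 and PAGE_ADDRESS_MASK = 0xfff (the constants come from the cannon module):

fn split_addr(addr: u32) -> (u32, usize) {
    // High bits select the page, low 12 bits select the byte within it.
    let page = addr >> 12;
    let offset = (addr & 0xfff) as usize;
    (page, offset)
}
// split_addr(0x1004) == (1, 4): byte 4 of page 1.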
+ +/// Maximum number of register accesses per instruction (based on demo) +pub const MAX_NB_REG_ACC: u64 = 7; +/// Maximum number of memory accesses per instruction (based on demo) +pub const MAX_NB_MEM_ACC: u64 = 12; +/// Maximum number of memory or register accesses per instruction +pub const MAX_ACC: u64 = MAX_NB_REG_ACC + MAX_NB_MEM_ACC; + +pub const NUM_GLOBAL_LOOKUP_TERMS: usize = 1; +pub const NUM_DECODING_LOOKUP_TERMS: usize = 2; +pub const NUM_INSTRUCTION_LOOKUP_TERMS: usize = 5; +pub const NUM_LOOKUP_TERMS: usize = + NUM_GLOBAL_LOOKUP_TERMS + NUM_DECODING_LOOKUP_TERMS + NUM_INSTRUCTION_LOOKUP_TERMS; +// TODO: Delete and use a vector instead + +#[derive(Clone, Default)] +pub struct SyscallEnv { + pub last_hint: Option<Vec<u8>>, +} + +impl SyscallEnv { + pub fn create(state: &State) -> Self { + SyscallEnv { + last_hint: state.last_hint.clone(), + } + } +} + +/// This structure represents the environment the virtual machine state will use +/// to transition. This environment will be used by the interpreter. The virtual +/// machine has access to its internal state and some external memory. In +/// addition to that, it has access to the environment of the Keccak interpreter +/// that is used to verify the preimage requested during the execution. +pub struct Env<Fp, PreImageOracle: PreImageOracleT> { + pub instruction_counter: u64, + pub memory: Vec<(u32, Vec<u8>)>, + pub last_memory_accesses: [usize; 3], + pub memory_write_index: Vec<(u32, Vec<u64>)>, + pub last_memory_write_index_accesses: [usize; 3], + pub registers: Registers<u32>, + pub registers_write_index: Registers<u64>, + pub scratch_state_idx: usize, + pub scratch_state_idx_inverse: usize, + pub scratch_state: [Fp; SCRATCH_SIZE], + pub scratch_state_inverse: [Fp; SCRATCH_SIZE_INVERSE], + pub halt: bool, + pub syscall_env: SyscallEnv, + pub selector: usize, + pub preimage_oracle: PreImageOracle, + pub preimage: Option<Vec<u8>>, + pub preimage_bytes_read: u64, + pub preimage_key: Option<[u8; 32]>, + pub keccak_env: Option<KeccakEnv<Fp>>, + pub hash_counter: u64, +} + +fn fresh_scratch_state<Fp: Field, const N: usize>() -> [Fp; N] { + array::from_fn(|_| Fp::zero()) +} + +impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreImageOracle> { + type Position = Column; + + fn alloc_scratch(&mut self) -> Self::Position { + let scratch_idx = self.scratch_state_idx; + self.scratch_state_idx += 1; + Column::ScratchState(scratch_idx) + } + + fn alloc_scratch_inverse(&mut self) -> Self::Position { + let scratch_idx = self.scratch_state_idx_inverse; + self.scratch_state_idx_inverse += 1; + Column::ScratchStateInverse(scratch_idx) + } + + type Variable = u64; + + fn variable(&self, _column: Self::Position) -> Self::Variable { + todo!() + } + + fn add_constraint(&mut self, _assert_equals_zero: Self::Variable) { + // No-op for witness + // Do not assert that _assert_equals_zero is zero here! + // Some variables may have placeholders that do not faithfully + // represent the underlying values. + } + + fn activate_selector(&mut self, instruction: Instruction) { + self.selector = instruction.into(); + } + + fn check_is_zero(assert_equals_zero: &Self::Variable) { + assert_eq!(*assert_equals_zero, 0); + } + + fn check_equal(x: &Self::Variable, y: &Self::Variable) { + assert_eq!(*x, *y); + } + + fn check_boolean(x: &Self::Variable) { + if !(*x == 0 || *x == 1) { + panic!("The value {} is not a boolean", *x); + } + } + + fn add_lookup(&mut self, _lookup: Lookup<Self::Variable>) { + // No-op, constraints only + // TODO: keep track of multiplicities of fixed tables here as in Keccak?
+ } + + fn instruction_counter(&self) -> Self::Variable { + self.instruction_counter + } + + fn increase_instruction_counter(&mut self) { + self.instruction_counter += 1; + } + + unsafe fn fetch_register( + &mut self, + idx: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + let res = self.registers[*idx as usize] as u64; + self.write_column(output, res); + res + } + + unsafe fn push_register_if( + &mut self, + idx: &Self::Variable, + value: Self::Variable, + if_is_true: &Self::Variable, + ) { + let value: u32 = value.try_into().unwrap(); + if *if_is_true == 1 { + self.registers[*idx as usize] = value + } else if *if_is_true == 0 { + // No-op + } else { + panic!("Bad value for flag in push_register: {}", *if_is_true); + } + } + + unsafe fn fetch_register_access( + &mut self, + idx: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + let res = self.registers_write_index[*idx as usize]; + self.write_column(output, res); + res + } + + unsafe fn push_register_access_if( + &mut self, + idx: &Self::Variable, + value: Self::Variable, + if_is_true: &Self::Variable, + ) { + if *if_is_true == 1 { + self.registers_write_index[*idx as usize] = value + } else if *if_is_true == 0 { + // No-op + } else { + panic!("Bad value for flag in push_register: {}", *if_is_true); + } + } + + unsafe fn fetch_memory( + &mut self, + addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + let addr: u32 = (*addr).try_into().unwrap(); + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_page_idx = self.get_memory_page_index(page); + let value = self.memory[memory_page_idx].1[page_address]; + self.write_column(output, value.into()); + value.into() + } + + unsafe fn push_memory(&mut self, addr: &Self::Variable, value: Self::Variable) { + let addr: u32 = (*addr).try_into().unwrap(); + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_page_idx = self.get_memory_page_index(page); + self.memory[memory_page_idx].1[page_address] = + value.try_into().expect("push_memory values fit in a u8"); + } + + unsafe fn fetch_memory_access( + &mut self, + addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + let addr: u32 = (*addr).try_into().unwrap(); + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_write_index_page_idx = self.get_memory_access_page_index(page); + let value = self.memory_write_index[memory_write_index_page_idx].1[page_address]; + self.write_column(output, value); + value + } + + unsafe fn push_memory_access(&mut self, addr: &Self::Variable, value: Self::Variable) { + let addr = *addr as u32; + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_write_index_page_idx = self.get_memory_access_page_index(page); + self.memory_write_index[memory_write_index_page_idx].1[page_address] = value; + } + + fn constant(x: u32) -> Self::Variable { + x as u64 + } + + unsafe fn bitmask( + &mut self, + x: &Self::Variable, + highest_bit: u32, + lowest_bit: u32, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let res = (x >> lowest_bit) & ((1 << (highest_bit - lowest_bit)) - 1); + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn shift_left( + &mut self, + x: &Self::Variable, + by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: 
u32 = (*x).try_into().unwrap(); + let by: u32 = (*by).try_into().unwrap(); + let res = x << by; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn shift_right( + &mut self, + x: &Self::Variable, + by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let by: u32 = (*by).try_into().unwrap(); + let res = x >> by; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn shift_right_arithmetic( + &mut self, + x: &Self::Variable, + by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let by: u32 = (*by).try_into().unwrap(); + let res = ((x as i32) >> by) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn test_zero(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable { + let res = if *x == 0 { 1 } else { 0 }; + self.write_column(position, res); + res + } + + fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable { + // write the result + let res = { + let pos = self.alloc_scratch(); + unsafe { self.test_zero(x, pos) } + }; + // write the non deterministic advice inv_or_zero + let pos = self.alloc_scratch_inverse(); + if *x == 0 { + self.write_field_column(pos, Fp::zero()); + } else { + self.write_field_column(pos, Fp::from(*x)); + }; + // return the result + res + } + + fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable { + // We replicate is_zero(x-y), but working on field elt, + // to avoid subtraction overflow in the witness interpreter for u32 + let to_zero_test = Fp::from(*x) - Fp::from(*y); + let res = { + let pos = self.alloc_scratch(); + let is_zero: u64 = if to_zero_test == Fp::zero() { 1 } else { 0 }; + self.write_column(pos, is_zero); + is_zero + }; + let pos = self.alloc_scratch_inverse(); + if to_zero_test == Fp::zero() { + self.write_field_column(pos, Fp::zero()); + } else { + self.write_field_column(pos, to_zero_test); + }; + res + } + + unsafe fn test_less_than( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = if x < y { 1 } else { 0 }; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn test_less_than_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = if (x as i32) < (y as i32) { 1 } else { 0 }; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn and_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = x & y; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn nor_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = !(x | y); + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn or_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = 
(*y).try_into().unwrap(); + let res = x | y; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn xor_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = x ^ y; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn add_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + out_position: Self::Position, + overflow_position: Self::Position, + ) -> (Self::Variable, Self::Variable) { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + // https://doc.rust-lang.org/std/primitive.u32.html#method.overflowing_add + let res = x.overflowing_add(y); + let (res_, overflow) = (res.0 as u64, res.1 as u64); + self.write_column(out_position, res_); + self.write_column(overflow_position, overflow); + (res_, overflow) + } + + unsafe fn sub_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + out_position: Self::Position, + underflow_position: Self::Position, + ) -> (Self::Variable, Self::Variable) { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + // https://doc.rust-lang.org/std/primitive.u32.html#method.overflowing_sub + let res = x.overflowing_sub(y); + let (res_, underflow) = (res.0 as u64, res.1 as u64); + self.write_column(out_position, res_); + self.write_column(underflow_position, underflow); + (res_, underflow) + } + + unsafe fn mul_signed_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = ((x as i32) * (y as i32)) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn mul_hi_lo_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position_hi: Self::Position, + position_lo: Self::Position, + ) -> (Self::Variable, Self::Variable) { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let mul = (((x as i32) as i64) * ((y as i32) as i64)) as u64; + let hi = (mul >> 32) as u32; + let lo = (mul & ((1 << 32) - 1)) as u32; + let hi = hi as u64; + let lo = lo as u64; + self.write_column(position_hi, hi); + self.write_column(position_lo, lo); + (hi, lo) + } + + unsafe fn mul_hi_lo( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position_hi: Self::Position, + position_lo: Self::Position, + ) -> (Self::Variable, Self::Variable) { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let mul = (x as u64) * (y as u64); + let hi = (mul >> 32) as u32; + let lo = (mul & ((1 << 32) - 1)) as u32; + let hi = hi as u64; + let lo = lo as u64; + self.write_column(position_hi, hi); + self.write_column(position_lo, lo); + (hi, lo) + } + + unsafe fn divmod_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position_quotient: Self::Position, + position_remainder: Self::Position, + ) -> (Self::Variable, Self::Variable) { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let q = ((x as i32) / (y as i32)) as u32; + let r = ((x as i32) % (y as i32)) as u32; + let q = q as u64; + let r = r as u64; + self.write_column(position_quotient, q); + self.write_column(position_remainder, r); + (q, r) + } + + unsafe fn divmod( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + 
position_quotient: Self::Position, + position_remainder: Self::Position, + ) -> (Self::Variable, Self::Variable) { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let q = x / y; + let r = x % y; + let q = q as u64; + let r = r as u64; + self.write_column(position_quotient, q); + self.write_column(position_remainder, r); + (q, r) + } + + unsafe fn count_leading_zeros( + &mut self, + x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let res = x.leading_zeros(); + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn count_leading_ones( + &mut self, + x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let res = x.leading_ones(); + let res = res as u64; + self.write_column(position, res); + res + } + + fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable { + self.write_column(position, *x); + *x + } + + fn set_halted(&mut self, flag: Self::Variable) { + if flag == 0 { + self.halt = false + } else if flag == 1 { + self.halt = true + } else { + panic!("Bad value for flag in set_halted: {}", flag); + } + } + + fn report_exit(&mut self, exit_code: &Self::Variable) { + println!( + "Exited with code {} at step {}", + *exit_code, + self.normalized_instruction_counter() + ); + } + + fn request_preimage_write( + &mut self, + addr: &Self::Variable, + len: &Self::Variable, + pos: Self::Position, + ) -> Self::Variable { + // The beginning of the syscall + if self.registers.preimage_offset == 0 { + let mut preimage_key = [0u8; 32]; + for i in 0..8 { + let bytes = u32::to_be_bytes(self.registers.preimage_key[i]); + for j in 0..4 { + preimage_key[4 * i + j] = bytes[j] + } + } + let preimage = self.preimage_oracle.get_preimage(preimage_key).get(); + self.preimage = Some(preimage.clone()); + self.preimage_key = Some(preimage_key); + } + + const LENGTH_SIZE: usize = 8; + + let preimage = self + .preimage + .as_ref() + .expect("to have a preimage if we're requesting it at a non-zero offset"); + let preimage_len = preimage.len(); + let preimage_offset = self.registers.preimage_offset as u64; + + let max_read_len = + std::cmp::min(preimage_offset + len, (preimage_len + LENGTH_SIZE) as u64) + - preimage_offset; + + // We read at most 4 bytes, ensuring that we respect word alignment. + // Here, if the address is not aligned, the first call will read < 4 + // but the next calls will be 4 bytes (because the actual address would + // be updated with the offset) until reaching the end of the preimage + // (where the last call could be less than 4 bytes). 
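The computation just below caps each read at a word boundary. As a standalone illustration of that rule (the `read_len` helper is a hypothetical name, not a function in this codebase):

```rust
// The first read realigns an unaligned address to a word boundary; all
// later reads are full 4-byte words until the preimage runs out.
fn read_len(addr: u64, max_read_len: u64) -> u64 {
    std::cmp::min(max_read_len, 4 - (addr & 3))
}

fn main() {
    assert_eq!(read_len(0x1000, 100), 4); // aligned: full word
    assert_eq!(read_len(0x1001, 100), 3); // unaligned: realigning short read
    assert_eq!(read_len(0x1000, 2), 2); // capped by the bytes remaining
}
```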
+ let actual_read_len = std::cmp::min(max_read_len, 4 - (addr & 3)); + + // This variable will contain the number of bytes read that belong to + // the actual preimage + let mut preimage_read_len = 0; + let mut chunk = 0; + for i in 0..actual_read_len { + let idx = (preimage_offset + i) as usize; + // The first 8 bytes of the read preimage are the preimage length, + // followed by the body of the preimage + if idx < LENGTH_SIZE { + // Compute the byte index read from the length + let len_i = idx % MIPS_CHUNK_BYTES_LEN; + + let length_byte = u64::to_be_bytes(preimage_len as u64)[idx]; + + // Write the individual byte of the length to the witness + self.write_column( + Column::ScratchState(MIPS_LENGTH_BYTES_OFF + len_i), + length_byte as u64, + ); + + // TODO: Probably, the scratch state of MIPS_LENGTH_BYTES_OFF + // is redundant with lines below + unsafe { + self.push_memory(&(*addr + i), length_byte as u64); + self.push_memory_access(&(*addr + i), self.next_instruction_counter()); + } + } else { + // Compute the byte index in the chunk of at most 4 bytes read + // from the preimage + let byte_i = (idx - LENGTH_SIZE) % MIPS_CHUNK_BYTES_LEN; + + // This should really be handled by the keccak oracle. + let preimage_byte = self.preimage.as_ref().unwrap()[idx - LENGTH_SIZE]; + + // Write the individual byte of the preimage to the witness + self.write_column( + Column::ScratchState(MIPS_PREIMAGE_BYTES_OFF + byte_i), + preimage_byte as u64, + ); + + // Update the chunk of at most 4 bytes read from the preimage + chunk = chunk << 8 | preimage_byte as u64; + + // At most, it will be actual_read_len when the length is not + // read in this call + preimage_read_len += 1; + + // TODO: Probably, the scratch state of MIPS_PREIMAGE_BYTES_OFF + // is redundant with lines below + unsafe { + self.push_memory(&(*addr + i), preimage_byte as u64); + self.push_memory_access(&(*addr + i), self.next_instruction_counter()); + } + } + } + // Update the chunk of at most 4 bytes read from the preimage + // FIXME: this is not linked to the registers content in any way. + // Is there anywhere else where the bytes are stored in the + // scratch state? + self.write_column(Column::ScratchState(MIPS_PREIMAGE_CHUNK_OFF), chunk); + + // Update the number of bytes read from the oracle in this step (can + // include length-prefix and preimage bytes) + self.write_column(pos, actual_read_len); + + // Number of preimage bytes processed in this instruction + self.write_column( + Column::ScratchState(MIPS_NUM_BYTES_READ_OFF), + preimage_read_len, + ); + + // Update the flags indicating how many preimage bytes (at least) were + // read in this step + for i in 0..MIPS_CHUNK_BYTES_LEN { + if preimage_read_len > i as u64 { + // This flag is only set when some preimage bytes were + // actually read.
+ self.write_column(Column::ScratchState(MIPS_HAS_N_BYTES_OFF + i), 1); + } + } + + // Update the total number of preimage bytes read so far + self.preimage_bytes_read += preimage_read_len; + self.write_column( + Column::ScratchState(MIPS_BYTE_COUNTER_OFF), + self.preimage_bytes_read, + ); + + // If we've read the entire preimage, trigger the Keccak workflow + if self.preimage_bytes_read == preimage_len as u64 { + self.write_column(Column::ScratchState(MIPS_END_OF_PREIMAGE_OFF), 1); + + // Store the preimage key in the witness, excluding the MSB, as 248 bits, + // so it can be used for the communication channel between the MIPS and + // Keccak circuits + let bytes31 = (1..32).fold(Fp::zero(), |acc, i| { + acc * Fp::two_pow(8) + Fp::from(self.preimage_key.unwrap()[i]) + }); + self.write_field_column(Self::Position::ScratchState(MIPS_PREIMAGE_KEY), bytes31); + + debug!("Preimage has been read entirely, triggering Keccak process"); + self.keccak_env = Some(KeccakEnv::<Fp>::new( + self.hash_counter, + self.preimage.as_ref().unwrap(), + )); + + // COMMUNICATION CHANNEL: only on the constraint side + + // Update the hash counter column + self.write_column( + Column::ScratchState(MIPS_HASH_COUNTER_OFF), + self.hash_counter, + ); + // The number of preimage bytes left to be read should be zero at this + // point + + // Reset the environment + self.preimage_bytes_read = 0; + self.preimage_key = None; + self.hash_counter += 1; + + // Resetting the PreimageCounter column will be done in the next call + } + actual_read_len + } + + fn request_hint_write(&mut self, addr: &Self::Variable, len: &Self::Variable) { + let mut last_hint = match std::mem::take(&mut self.syscall_env.last_hint) { + Some(mut last_hint) => { + last_hint.reserve(*len as usize); + last_hint + } + None => Vec::with_capacity(*len as usize), + };
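As an aside on the preimage-key packing in `request_preimage_write` above, the 248-bit folding can be checked in isolation. A minimal sketch, assuming any `ark_ff` prime field in place of the circuit's `Fp`; `pack_key_248` is an illustrative name, not a function from this diff:

```rust
use ark_ff::{PrimeField, Zero};

// Fold bytes 1..32 of the 32-byte key big-endian into one field element:
// acc <- acc * 2^8 + byte. Byte 0 (the MSB) is dropped, so the result is
// 31 * 8 = 248 bits and always fits in the field.
fn pack_key_248<F: PrimeField>(key: &[u8; 32]) -> F {
    key[1..]
        .iter()
        .fold(F::zero(), |acc, &b| acc * F::from(256u64) + F::from(b as u64))
}
```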
+ + // This should really be handled by the keccak oracle. + for i in 0..*len { + // Push memory access + unsafe { self.push_memory_access(&(*addr + i), self.next_instruction_counter()) }; + // Fetch the value without allocating witness columns + let value = { + let addr: u32 = (*addr).try_into().unwrap(); + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_page_idx = self.get_memory_page_index(page); + self.memory[memory_page_idx].1[page_address] + }; + last_hint.push(value); + } + + let len = last_hint.len(); + let mut idx = 0; + + while idx + 4 <= len { + let hint_len = u32::from_be_bytes(last_hint[idx..idx + 4].try_into().unwrap()) as usize; + idx += 4; + if idx + hint_len <= len { + let hint = last_hint[idx..idx + hint_len].to_vec(); + idx += hint_len; + self.preimage_oracle.hint(Hint::create(hint)); + } + } + + let remaining = last_hint[idx..len].to_vec(); + + self.syscall_env.last_hint = Some(remaining); + } + + fn reset(&mut self) { + self.scratch_state_idx = 0; + self.scratch_state = fresh_scratch_state(); + self.selector = N_MIPS_SEL_COLS; + } +} + +impl<Fp: Field> Env<Fp> { + pub fn create(page_size: usize, state: State, preimage_oracle: PreImageOracle) -> Self { + let initial_instruction_pointer = state.pc; + let next_instruction_pointer = state.next_pc; + + let selector = N_MIPS_SEL_COLS; + + let syscall_env = SyscallEnv::create(&state); + + let mut initial_memory: Vec<(u32, Vec<u8>)> = state + .memory + .into_iter() + // Check that the conversion from page data is correct + .map(|page| (page.index, page.data)) + .collect(); + + for (_address, initial_memory) in initial_memory.iter_mut() { + initial_memory.extend((0..(page_size - initial_memory.len())).map(|_| 0u8)); + assert_eq!(initial_memory.len(), page_size); + } + + let memory_offsets = initial_memory + .iter() + .map(|(offset, _)| *offset) + .collect::<Vec<u32>>(); + + let initial_registers = { + let preimage_key = { + let mut preimage_key = [0u32; 8]; + for (i, preimage_key_word) in preimage_key.iter_mut().enumerate() { + *preimage_key_word = u32::from_be_bytes( + state.preimage_key[i * 4..(i + 1) * 4].try_into().unwrap(), + ) + } + preimage_key + }; + Registers { + lo: state.lo, + hi: state.hi, + general_purpose: state.registers, + current_instruction_pointer: initial_instruction_pointer, + next_instruction_pointer, + heap_pointer: state.heap, + preimage_key, + preimage_offset: state.preimage_offset, + } + }; + + Env { + instruction_counter: state.step, + memory: initial_memory.clone(), + last_memory_accesses: [0usize; 3], + memory_write_index: memory_offsets + .iter() + .map(|offset| (*offset, vec![0u64; page_size])) + .collect(), + last_memory_write_index_accesses: [0usize; 3], + registers: initial_registers.clone(), + registers_write_index: Registers::default(), + scratch_state_idx: 0, + scratch_state_idx_inverse: 0, + scratch_state: fresh_scratch_state(), + scratch_state_inverse: fresh_scratch_state(), + halt: state.exited, + syscall_env, + selector, + preimage_oracle, + preimage: state.preimage, + preimage_bytes_read: 0, + preimage_key: None, + keccak_env: None, + hash_counter: 0, + } + } + + pub fn reset_scratch_state(&mut self) { + self.scratch_state_idx = 0; + self.scratch_state = fresh_scratch_state(); + self.selector = N_MIPS_SEL_COLS; + } + + pub fn reset_scratch_state_inverse(&mut self) { + self.scratch_state_idx_inverse = 0; + self.scratch_state_inverse = fresh_scratch_state(); + } + + pub fn write_column(&mut self, column: Column, value: u64) { + self.write_field_column(column, value.into()) + } + + pub fn write_field_column(&mut
self, column: Column, value: Fp) { + match column { + Column::ScratchState(idx) => self.scratch_state[idx] = value, + Column::ScratchStateInverse(idx) => self.scratch_state_inverse[idx] = value, + Column::InstructionCounter => panic!("Cannot overwrite the column {:?}", column), + Column::Selector(s) => self.selector = s, + } + } + + pub fn update_last_memory_access(&mut self, i: usize) { + let [i_0, i_1, _] = self.last_memory_accesses; + self.last_memory_accesses = [i, i_0, i_1] + } + + pub fn get_memory_page_index(&mut self, page: u32) -> usize { + for &i in self.last_memory_accesses.iter() { + if self.memory_write_index[i].0 == page { + return i; + } + } + for (i, (page_index, _memory)) in self.memory.iter_mut().enumerate() { + if *page_index == page { + self.update_last_memory_access(i); + return i; + } + } + + // Memory not found; dynamically allocate + let memory = vec![0u8; PAGE_SIZE as usize]; + self.memory.push((page, memory)); + let i = self.memory.len() - 1; + self.update_last_memory_access(i); + i + } + + pub fn update_last_memory_write_index_access(&mut self, i: usize) { + let [i_0, i_1, _] = self.last_memory_write_index_accesses; + self.last_memory_write_index_accesses = [i, i_0, i_1] + } + + pub fn get_memory_access_page_index(&mut self, page: u32) -> usize { + for &i in self.last_memory_write_index_accesses.iter() { + if self.memory_write_index[i].0 == page { + return i; + } + } + for (i, (page_index, _memory_write_index)) in self.memory_write_index.iter_mut().enumerate() + { + if *page_index == page { + self.update_last_memory_write_index_access(i); + return i; + } + } + + // Memory not found; dynamically allocate + let memory_write_index = vec![0u64; PAGE_SIZE as usize]; + self.memory_write_index.push((page, memory_write_index)); + let i = self.memory_write_index.len() - 1; + self.update_last_memory_write_index_access(i); + i + } + + pub fn get_memory_direct(&mut self, addr: u32) -> u8 { + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_idx = self.get_memory_page_index(page); + self.memory[memory_idx].1[page_address] + } + + pub fn decode_instruction(&mut self) -> (Instruction, u32) { + let instruction = + ((self.get_memory_direct(self.registers.current_instruction_pointer) as u32) << 24) + | ((self.get_memory_direct(self.registers.current_instruction_pointer + 1) as u32) + << 16) + | ((self.get_memory_direct(self.registers.current_instruction_pointer + 2) as u32) + << 8) + | (self.get_memory_direct(self.registers.current_instruction_pointer + 3) as u32); + let opcode = { + match instruction >> 26 { + 0x00 => match instruction & 0x3F { + 0x00 => Instruction::RType(RTypeInstruction::ShiftLeftLogical), + 0x02 => Instruction::RType(RTypeInstruction::ShiftRightLogical), + 0x03 => Instruction::RType(RTypeInstruction::ShiftRightArithmetic), + 0x04 => Instruction::RType(RTypeInstruction::ShiftLeftLogicalVariable), + 0x06 => Instruction::RType(RTypeInstruction::ShiftRightLogicalVariable), + 0x07 => Instruction::RType(RTypeInstruction::ShiftRightArithmeticVariable), + 0x08 => Instruction::RType(RTypeInstruction::JumpRegister), + 0x09 => Instruction::RType(RTypeInstruction::JumpAndLinkRegister), + 0x0a => Instruction::RType(RTypeInstruction::MoveZero), + 0x0b => Instruction::RType(RTypeInstruction::MoveNonZero), + 0x0c => match self.registers.general_purpose[2] { + 4090 => Instruction::RType(RTypeInstruction::SyscallMmap), + 4045 => { + // sysBrk + Instruction::RType(RTypeInstruction::SyscallOther) + } + 4120 => { + // 
sysClone + Instruction::RType(RTypeInstruction::SyscallOther) + } + 4246 => Instruction::RType(RTypeInstruction::SyscallExitGroup), + 4003 => match self.registers.general_purpose[4] { + interpreter::FD_HINT_READ => { + Instruction::RType(RTypeInstruction::SyscallReadHint) + } + interpreter::FD_PREIMAGE_READ => { + Instruction::RType(RTypeInstruction::SyscallReadPreimage) + } + _ => Instruction::RType(RTypeInstruction::SyscallReadOther), + }, + 4004 => match self.registers.general_purpose[4] { + interpreter::FD_PREIMAGE_WRITE => { + Instruction::RType(RTypeInstruction::SyscallWritePreimage) + } + interpreter::FD_HINT_WRITE => { + Instruction::RType(RTypeInstruction::SyscallWriteHint) + } + _ => Instruction::RType(RTypeInstruction::SyscallWriteOther), + }, + 4055 => Instruction::RType(RTypeInstruction::SyscallFcntl), + _ => { + // NB: This has well-defined behavior. Don't panic! + Instruction::RType(RTypeInstruction::SyscallOther) + } + }, + 0x0f => Instruction::RType(RTypeInstruction::Sync), + 0x10 => Instruction::RType(RTypeInstruction::MoveFromHi), + 0x11 => Instruction::RType(RTypeInstruction::MoveToHi), + 0x12 => Instruction::RType(RTypeInstruction::MoveFromLo), + 0x13 => Instruction::RType(RTypeInstruction::MoveToLo), + 0x18 => Instruction::RType(RTypeInstruction::Multiply), + 0x19 => Instruction::RType(RTypeInstruction::MultiplyUnsigned), + 0x1a => Instruction::RType(RTypeInstruction::Div), + 0x1b => Instruction::RType(RTypeInstruction::DivUnsigned), + 0x20 => Instruction::RType(RTypeInstruction::Add), + 0x21 => Instruction::RType(RTypeInstruction::AddUnsigned), + 0x22 => Instruction::RType(RTypeInstruction::Sub), + 0x23 => Instruction::RType(RTypeInstruction::SubUnsigned), + 0x24 => Instruction::RType(RTypeInstruction::And), + 0x25 => Instruction::RType(RTypeInstruction::Or), + 0x26 => Instruction::RType(RTypeInstruction::Xor), + 0x27 => Instruction::RType(RTypeInstruction::Nor), + 0x2a => Instruction::RType(RTypeInstruction::SetLessThan), + 0x2b => Instruction::RType(RTypeInstruction::SetLessThanUnsigned), + _ => { + panic!("Unhandled instruction {:#X}", instruction) + } + }, + 0x01 => { + // RegImm instructions + match (instruction >> 16) & 0x1F { + 0x0 => Instruction::IType(ITypeInstruction::BranchLtZero), + 0x1 => Instruction::IType(ITypeInstruction::BranchGeqZero), + _ => panic!("Unhandled instruction {:#X}", instruction), + } + } + 0x02 => Instruction::JType(JTypeInstruction::Jump), + 0x03 => Instruction::JType(JTypeInstruction::JumpAndLink), + 0x04 => Instruction::IType(ITypeInstruction::BranchEq), + 0x05 => Instruction::IType(ITypeInstruction::BranchNeq), + 0x06 => Instruction::IType(ITypeInstruction::BranchLeqZero), + 0x07 => Instruction::IType(ITypeInstruction::BranchGtZero), + 0x08 => Instruction::IType(ITypeInstruction::AddImmediate), + 0x09 => Instruction::IType(ITypeInstruction::AddImmediateUnsigned), + 0x0A => Instruction::IType(ITypeInstruction::SetLessThanImmediate), + 0x0B => Instruction::IType(ITypeInstruction::SetLessThanImmediateUnsigned), + 0x0C => Instruction::IType(ITypeInstruction::AndImmediate), + 0x0D => Instruction::IType(ITypeInstruction::OrImmediate), + 0x0E => Instruction::IType(ITypeInstruction::XorImmediate), + 0x0F => Instruction::IType(ITypeInstruction::LoadUpperImmediate), + 0x1C => match instruction & 0x3F { + 0x02 => Instruction::RType(RTypeInstruction::MultiplyToRegister), + 0x20 => Instruction::RType(RTypeInstruction::CountLeadingZeros), + 0x21 => Instruction::RType(RTypeInstruction::CountLeadingOnes), + _ => panic!("Unhandled instruction 
{:#X}", instruction), + }, + 0x20 => Instruction::IType(ITypeInstruction::Load8), + 0x21 => Instruction::IType(ITypeInstruction::Load16), + 0x22 => Instruction::IType(ITypeInstruction::LoadWordLeft), + 0x23 => Instruction::IType(ITypeInstruction::Load32), + 0x24 => Instruction::IType(ITypeInstruction::Load8Unsigned), + 0x25 => Instruction::IType(ITypeInstruction::Load16Unsigned), + 0x26 => Instruction::IType(ITypeInstruction::LoadWordRight), + 0x28 => Instruction::IType(ITypeInstruction::Store8), + 0x29 => Instruction::IType(ITypeInstruction::Store16), + 0x2a => Instruction::IType(ITypeInstruction::StoreWordLeft), + 0x2b => Instruction::IType(ITypeInstruction::Store32), + 0x2e => Instruction::IType(ITypeInstruction::StoreWordRight), + 0x30 => { + // Note: This is ll (LoadLinked), but we're only simulating + // a single processor. + Instruction::IType(ITypeInstruction::Load32) + } + 0x38 => { + // Note: This is sc (StoreConditional), but we're only + // simulating a single processor. + Instruction::IType(ITypeInstruction::Store32Conditional) + } + _ => { + panic!("Unhandled instruction {:#X}", instruction) + } + } + }; + (opcode, instruction) + } + + /// The actual number of instructions executed results from dividing the + /// instruction counter by MAX_ACC (floor). + /// + /// NOTE: actually, in practice it will be less than that, as there is no + /// single instruction that performs all of them. + pub fn normalized_instruction_counter(&self) -> u64 { + self.instruction_counter / MAX_ACC + } + + /// Computes what is the non-normalized next instruction counter, which + /// accounts for the maximum number of register and memory accesses per + /// instruction. + /// + /// Because MAX_NB_REG_ACC = 7 and MAX_NB_MEM_ACC = 12, at most the same + /// instruction will increase the instruction counter by MAX_ACC = 19. + /// + /// Then, in order to update the instruction counter, we need to add 1 to + /// the real instruction counter and multiply it by MAX_ACC to have a unique + /// representation of each step (which is helpful for debugging). + pub fn next_instruction_counter(&self) -> u64 { + (self.normalized_instruction_counter() + 1) * MAX_ACC + } + + /// Execute a single step of the MIPS program. + /// Returns the instruction that was executed. 
+ pub fn step( + &mut self, + config: &VmConfiguration, + metadata: &Meta, + start: &Start, + ) -> Instruction { + self.reset_scratch_state(); + self.reset_scratch_state_inverse(); + let (opcode, _instruction) = self.decode_instruction(); + + self.pp_info(&config.info_at, metadata, start); + self.snapshot_state_at(&config.snapshot_state_at); + + // Force a stop at the given iteration + if self.should_trigger_at(&config.stop_at) { + self.halt = true; + println!( + "Halted as requested at step={} instruction={:?}", + self.normalized_instruction_counter(), + opcode + ); + return opcode; + } + + interpreter::interpret_instruction(self, opcode); + + self.instruction_counter = self.next_instruction_counter(); + + // Integer division by MAX_ACC to obtain the actual instruction count + if self.halt { + println!( + "Halted at step={} instruction={:?}", + self.normalized_instruction_counter(), + opcode + ); + } + opcode + } + + fn should_trigger_at(&self, at: &StepFrequency) -> bool { + let m: u64 = self.normalized_instruction_counter(); + match at { + StepFrequency::Never => false, + StepFrequency::Always => true, + StepFrequency::Exactly(n) => *n == m, + StepFrequency::Every(n) => m % *n == 0, + StepFrequency::Range(lo, hi_opt) => { + m >= *lo && (hi_opt.is_none() || m < hi_opt.unwrap()) + } + } + } + + // Compute memory usage + fn memory_usage(&self) -> String { + let total = self.memory.len() * PAGE_SIZE as usize; + memory_size(total) + } + + fn page_address(&self) -> (u32, usize) { + let address = self.registers.current_instruction_pointer; + let page = address >> PAGE_ADDRESS_SIZE; + let page_address = (address & PAGE_ADDRESS_MASK) as usize; + (page, page_address) + } + + fn get_opcode(&mut self) -> Option<u32> { + let (page_id, page_address) = self.page_address(); + for (page_index, memory) in self.memory.iter() { + if page_id == *page_index { + let memory_slice: [u8; 4] = memory[page_address..page_address + 4] + .try_into() + .expect("Couldn't read 4 bytes at given address"); + return Some(u32::from_be_bytes(memory_slice)); + } + } + None + } + + fn snapshot_state_at(&mut self, at: &StepFrequency) { + if self.should_trigger_at(at) { + let filename = format!( + "snapshot-state-{}.json", + self.normalized_instruction_counter() + ); + let file = File::create(filename.clone()).expect("Unable to create the snapshot file"); + let mut writer = BufWriter::new(file); + let mut preimage_key = [0u8; 32]; + for i in 0..8 { + let bytes = u32::to_be_bytes(self.registers.preimage_key[i]); + for j in 0..4 { + preimage_key[4 * i + j] = bytes[j] + } + } + let memory = self + .memory + .clone() + .into_iter() + .map(|(idx, data)| Page { index: idx, data }) + .collect(); + let s: State = State { + pc: self.registers.current_instruction_pointer, + next_pc: self.registers.next_instruction_pointer, + step: self.normalized_instruction_counter(), + registers: self.registers.general_purpose, + lo: self.registers.lo, + hi: self.registers.hi, + heap: self.registers.heap_pointer, + // FIXME: it should be the exit code.
We do not keep it in the + // witness at the moment + exit: if self.halt { 1 } else { 0 }, + last_hint: self.syscall_env.last_hint.clone(), + exited: self.halt, + preimage_offset: self.registers.preimage_offset, + preimage_key, + memory, + preimage: self.preimage.clone(), + }; + let _ = serde_json::to_writer(&mut writer, &s); + info!( + "Snapshot state in {}, step {}", + filename, + self.normalized_instruction_counter() + ); + writer.flush().expect("Failed to flush the writer") + } + } + + fn pp_info(&mut self, at: &StepFrequency, meta: &Meta, start: &Start) { + if self.should_trigger_at(at) { + let elapsed = start.time.elapsed(); + // Compute the step number removing the MAX_ACC factor + let step = self.normalized_instruction_counter(); + let pc = self.registers.current_instruction_pointer; + + // Get the 32-bit opcode + let insn = self.get_opcode().unwrap(); + + // Approximate instructions per second + let how_many_steps = step as usize - start.step; + + let ips = how_many_steps as f64 / elapsed.as_secs() as f64; + + let pages = self.memory.len(); + + let mem = self.memory_usage(); + let name = meta + .find_address_symbol(pc) + .unwrap_or_else(|| "n/a".to_string()); + + info!( + "processing step={} pc={:010x} insn={:010x} ips={:.2} pages={} mem={} name={}", + step, pc, insn, ips, pages, mem, name + ); + } + } +} diff --git a/o1vm/src/interpreters/mod.rs b/o1vm/src/interpreters/mod.rs new file mode 100644 index 0000000000..426192f514 --- /dev/null +++ b/o1vm/src/interpreters/mod.rs @@ -0,0 +1,10 @@ +/// An interpreter for an optimised version of Keccak. +pub mod keccak; + +/// An interpreter for the MIPS instruction set. +pub mod mips; + +/// An interpreter for the RISC-V 32IM instruction set, following the specification +/// on +/// [riscv.org](https://riscv.org/wp-content/uploads/2019/12/riscv-spec-20191213.pdf).
+pub mod riscv32im; diff --git a/o1vm/src/interpreters/riscv32im/column.rs b/o1vm/src/interpreters/riscv32im/column.rs new file mode 100644 index 0000000000..871db4a546 --- /dev/null +++ b/o1vm/src/interpreters/riscv32im/column.rs @@ -0,0 +1,104 @@ +use super::{ + interpreter::{ + IInstruction, + Instruction::{self, IType, MType, RType, SBType, SType, SyscallType, UJType, UType}, + RInstruction, SBInstruction, SInstruction, SyscallInstruction, UInstruction, UJInstruction, + }, + INSTRUCTION_SET_SIZE, SCRATCH_SIZE, +}; +use kimchi::circuits::{ + berkeley_columns::BerkeleyChallengeTerm, + expr::{ConstantExpr, Expr}, +}; +use strum::EnumCount; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Column { + ScratchState(usize), + InstructionCounter, + Selector(usize), +} + +impl From<Column> for usize { + fn from(col: Column) -> usize { + match col { + Column::ScratchState(i) => { + assert!(i < SCRATCH_SIZE); + i + } + Column::InstructionCounter => SCRATCH_SIZE, + Column::Selector(s) => { + assert!( + s < INSTRUCTION_SET_SIZE, + "There are only {INSTRUCTION_SET_SIZE} selectors" + ); + SCRATCH_SIZE + 1 + s + } + } + } +} + +impl From<Instruction> for usize { + fn from(instr: Instruction) -> usize { + match instr { + RType(rtype) => SCRATCH_SIZE + 1 + rtype as usize, + IType(itype) => SCRATCH_SIZE + 1 + RInstruction::COUNT + itype as usize, + SType(stype) => { + SCRATCH_SIZE + 1 + RInstruction::COUNT + IInstruction::COUNT + stype as usize + } + SBType(sbtype) => { + SCRATCH_SIZE + + 1 + + RInstruction::COUNT + + IInstruction::COUNT + + SInstruction::COUNT + + sbtype as usize + } + UType(utype) => { + SCRATCH_SIZE + + 1 + + RInstruction::COUNT + + IInstruction::COUNT + + SInstruction::COUNT + + SBInstruction::COUNT + + utype as usize + } + UJType(ujtype) => { + SCRATCH_SIZE + + 1 + + RInstruction::COUNT + + IInstruction::COUNT + + SInstruction::COUNT + + SBInstruction::COUNT + + UInstruction::COUNT + + ujtype as usize + } + SyscallType(syscalltype) => { + SCRATCH_SIZE + + 1 + + RInstruction::COUNT + + IInstruction::COUNT + + SInstruction::COUNT + + SBInstruction::COUNT + + UInstruction::COUNT + + UJInstruction::COUNT + + syscalltype as usize + } + MType(mtype) => { + SCRATCH_SIZE + + 1 + + RInstruction::COUNT + + IInstruction::COUNT + + SInstruction::COUNT + + SBInstruction::COUNT + + UInstruction::COUNT + + UJInstruction::COUNT + + SyscallInstruction::COUNT + + mtype as usize + } + } + } +} + +// FIXME: use other challenges, not Berkeley. +pub type E<F> = Expr<ConstantExpr<F, BerkeleyChallengeTerm>, Column>; diff --git a/o1vm/src/interpreters/riscv32im/constraints.rs b/o1vm/src/interpreters/riscv32im/constraints.rs new file mode 100644 index 0000000000..ba1633e1a7 --- /dev/null +++ b/o1vm/src/interpreters/riscv32im/constraints.rs @@ -0,0 +1,471 @@ +use super::{ + column::{Column, E}, + interpreter::{Instruction, InterpreterEnv}, + INSTRUCTION_SET_SIZE, +}; +use crate::{ + interpreters::riscv32im::{constraints::ConstantTerm::Literal, SCRATCH_SIZE}, + lookups::Lookup, +}; +use ark_ff::{Field, One}; +use kimchi::circuits::{ + expr::{ConstantTerm, Expr, ExprInner, Operations, Variable}, + gate::CurrOrNext, +}; + +pub struct Env<Fp: Field> { + pub scratch_state_idx: usize, + pub lookups: Vec<Lookup<E<Fp>>>, + pub constraints: Vec<E<Fp>>, + pub selector: Option<E<Fp>>, +} + +impl<Fp: Field> Default for Env<Fp> { + fn default() -> Self { + Self { + scratch_state_idx: 0, + constraints: Vec::new(), + lookups: Vec::new(), + selector: None, + } + } +} + +impl<Fp: Field> InterpreterEnv for Env<Fp> { + /// In the concrete implementation for the constraints, the interpreter will
The position in this case can be seen as a new + /// variable/input of our circuit. + type Position = Column; + + // Use one of the available columns. It won't create a new column every time + // this function is called. The number of columns is defined upfront by + // crate::interpreters::riscv32im::SCRATCH_SIZE. + fn alloc_scratch(&mut self) -> Self::Position { + // All columns are implemented using a simple index, and a name is given + // to the index. See crate::interpreters::riscv32im::SCRATCH_SIZE for the + // maximum number of columns the circuit can use. + let scratch_idx = self.scratch_state_idx; + self.scratch_state_idx += 1; + Column::ScratchState(scratch_idx) + } + + type Variable = E<Fp>; + + fn variable(&self, column: Self::Position) -> Self::Variable { + Expr::Atom(ExprInner::Cell(Variable { + col: column, + row: CurrOrNext::Curr, + })) + } + + fn activate_selector(&mut self, selector: Instruction) { + // Sanity check: we only want to activate once per instruction + assert!(self.selector.is_none(), "A selector has already been activated. You might need to reset the environment if you want to start a new instruction."); + let n = usize::from(selector) - SCRATCH_SIZE - 1; + self.selector = Some(self.variable(Column::Selector(n))) + } + + fn add_constraint(&mut self, assert_equals_zero: Self::Variable) { + self.constraints.push(assert_equals_zero) + } + + fn check_is_zero(_assert_equals_zero: &Self::Variable) { + // No-op, witness only + } + + fn check_equal(_x: &Self::Variable, _y: &Self::Variable) { + // No-op, witness only + } + + fn check_boolean(_x: &Self::Variable) { + // No-op, witness only + } + + fn add_lookup(&mut self, lookup: Lookup<Self::Variable>) { + self.lookups.push(lookup); + } + + fn instruction_counter(&self) -> Self::Variable { + self.variable(Column::InstructionCounter) + } + + fn increase_instruction_counter(&mut self) { + // No-op, witness only + } + + unsafe fn fetch_register( + &mut self, + _idx: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + self.variable(output) + } + + unsafe fn push_register_if( + &mut self, + _idx: &Self::Variable, + _value: Self::Variable, + _if_is_true: &Self::Variable, + ) { + // No-op, witness only + } + + unsafe fn fetch_register_access( + &mut self, + _idx: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + self.variable(output) + } + + unsafe fn push_register_access_if( + &mut self, + _idx: &Self::Variable, + _value: Self::Variable, + _if_is_true: &Self::Variable, + ) { + // No-op, witness only + } + + unsafe fn fetch_memory( + &mut self, + _addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + self.variable(output) + } + + unsafe fn push_memory(&mut self, _addr: &Self::Variable, _value: Self::Variable) { + // No-op, witness only + } + + unsafe fn fetch_memory_access( + &mut self, + _addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + self.variable(output) + } + + unsafe fn push_memory_access(&mut self, _addr: &Self::Variable, _value: Self::Variable) { + // No-op, witness only + } + + fn constant(x: u32) -> Self::Variable { + Self::Variable::constant(Operations::from(Literal(Fp::from(x)))) + } + + unsafe fn bitmask( + &mut self, + _x: &Self::Variable, + _highest_bit: u32, + _lowest_bit: u32, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn shift_left( + &mut self, + _x: &Self::Variable, + _by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn shift_right( + &mut self, + _x:
&Self::Variable, + _by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn shift_right_arithmetic( + &mut self, + _x: &Self::Variable, + _by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn test_zero( + &mut self, + _x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable { + let res = { + let pos = self.alloc_scratch(); + unsafe { self.test_zero(x, pos) } + }; + let x_inv_or_zero = { + let pos = self.alloc_scratch(); + unsafe { self.inverse_or_zero(x, pos) } + }; + // If x = 0, then res = 1 and x_inv_or_zero = 0 + // If x <> 0, then res = 0 and x_inv_or_zero = x^(-1) + self.add_constraint(x.clone() * x_inv_or_zero.clone() + res.clone() - Self::constant(1)); + self.add_constraint(x.clone() * res.clone()); + res + } + + unsafe fn inverse_or_zero( + &mut self, + _x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable { + self.is_zero(&(x.clone() - y.clone())) + } + + unsafe fn test_less_than( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn test_less_than_signed( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn and_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn nor_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn or_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn xor_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn add_witness( + &mut self, + _y: &Self::Variable, + _x: &Self::Variable, + out_position: Self::Position, + overflow_position: Self::Position, + ) -> (Self::Variable, Self::Variable) { + ( + self.variable(out_position), + self.variable(overflow_position), + ) + } + + unsafe fn sub_witness( + &mut self, + _y: &Self::Variable, + _x: &Self::Variable, + out_position: Self::Position, + underflow_position: Self::Position, + ) -> (Self::Variable, Self::Variable) { + ( + self.variable(out_position), + self.variable(underflow_position), + ) + } + + unsafe fn mul_signed_witness( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn mul_hi_signed( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn mul_lo_signed( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn mul_hi( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn mul_hi_signed_unsigned( + &mut self, + _x: 
&Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn div_signed( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn mod_signed( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn div( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn mod_unsigned( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn mul_lo( + &mut self, + _x: &Self::Variable, + _y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn count_leading_zeros( + &mut self, + _x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + unsafe fn count_leading_ones( + &mut self, + _x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + self.variable(position) + } + + fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable { + let res = self.variable(position); + self.constraints.push(x.clone() - res.clone()); + res + } + + fn set_halted(&mut self, _flag: Self::Variable) { + // TODO + } + + fn report_exit(&mut self, _exit_code: &Self::Variable) {} + + fn reset(&mut self) { + self.scratch_state_idx = 0; + self.constraints.clear(); + self.lookups.clear(); + self.selector = None; + } +} + +impl<Fp: Field> Env<Fp> { + /// Return the constraints for the selector. + /// Each selector must be a boolean. + pub fn get_selector_constraints(&self) -> Vec<E<Fp>> { + let one = <Self as InterpreterEnv>::Variable::one(); + let mut enforce_bool: Vec<E<Fp>> = (0..INSTRUCTION_SET_SIZE) + .map(|i| { + let var = self.variable(Column::Selector(i)); + (var.clone() - one.clone()) * var.clone() + }) + .collect(); + let enforce_one_activation = (0..INSTRUCTION_SET_SIZE).fold(E::<Fp>::one(), |res, i| { + let var = self.variable(Column::Selector(i)); + res - var.clone() + }); + + enforce_bool.push(enforce_one_activation); + enforce_bool + } + + pub fn get_selector(&self) -> E<Fp> { + self.selector + .clone() + .unwrap_or_else(|| panic!("Selector is not set")) + } + + /// Return the constraints for the current instruction, without the selector + pub fn get_constraints(&self) -> Vec<E<Fp>> { + self.constraints.clone() + } + + pub fn get_lookups(&self) -> Vec<Lookup<E<Fp>>> { + self.lookups.clone() + } +} diff --git a/o1vm/src/interpreters/riscv32im/interpreter.rs b/o1vm/src/interpreters/riscv32im/interpreter.rs new file mode 100644 index 0000000000..0bc94d1c7e --- /dev/null +++ b/o1vm/src/interpreters/riscv32im/interpreter.rs @@ -0,0 +1,2355 @@ +//! This module implements an interpreter for the RISCV32 IM instruction set +//! architecture. +//! +//! The implementation mostly follows (and copies) code from the MIPS interpreter +//! available [here](../mips/interpreter.rs). +//! +//! ## Credits +//! +//! We would like to thank the authors of the following documentation: +//! - the [riscv-isadoc](https://github.com/msyksphinz-self/riscv-isadoc) +//! documentation by msyksphinz-self ([CC BY +//! 4.0](https://creativecommons.org/licenses/by/4.0/)) +//! - the course materials of [CS 3410: Computer System Organization and +//! Programming](https://www.cs.cornell.edu/courses/cs3410/2024fa/home.html) at +//! Cornell University. +//!
+//! The format and description of each instruction is taken from these sources, +//! and copied in this file for offline reference. +//! If you are the author of the above documentations and would like to add or +//! modify the credits, please open a pull request. +//! +//! For each instruction, we provide the format, description, and the +//! semantic in pseudo-code of the instruction. +//! When `signed` is mentioned in the pseudo-code, it means that the +//! operation is performed as a signed operation (i.e. signed(v) where `v` is a +//! 32 bits value means that `v` must be interpreted as a i32 value in Rust, the +//! most significant bit being the sign - 1 for negative, 0 for positive). +//! By default, unsigned operations are performed. + +use super::registers::{REGISTER_CURRENT_IP, REGISTER_HEAP_POINTER, REGISTER_NEXT_IP}; +use crate::lookups::{Lookup, LookupTableIDs}; +use ark_ff::{One, Zero}; +use strum::{EnumCount, IntoEnumIterator}; +use strum_macros::{EnumCount, EnumIter}; + +#[derive(Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Hash, Ord, PartialOrd)] +pub enum Instruction { + RType(RInstruction), + IType(IInstruction), + SType(SInstruction), + SBType(SBInstruction), + UType(UInstruction), + UJType(UJInstruction), + SyscallType(SyscallInstruction), + MType(MInstruction), +} + +// See +// https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf +// for the order +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum RInstruction { + #[default] + /// Format: `add rd, rs1, rs2` + /// + /// Description: Adds the registers rs1 and rs2 and stores the result in rd. + /// Arithmetic overflow is ignored and the result is simply the low 32 + /// bits of the result. + Add, // add + /// Format: `sub rd, rs1, rs2` + /// + /// Description: Subs the register rs2 from rs1 and stores the result in rd. + /// Arithmetic overflow is ignored and the result is simply the low 32 + /// bits of the result. + Sub, // sub + /// Format: `sll rd, rs1, rs2` + /// + /// Description: Performs logical left shift on the value in register rs1 by + /// the shift amount held in the lower 5 bits of register rs2. + ShiftLeftLogical, // sll + /// Format: `slt rd, rs1, rs2` + /// + /// Description: Place the value 1 in register rd if register rs1 is less + /// than register rs2 when both are treated as signed numbers, else 0 is + /// written to rd. + SetLessThan, // slt + /// Format: `sltu rd, rs1, rs2` + /// + /// Description: Place the value 1 in register rd if register rs1 is less + /// than register rs2 when both are treated as unsigned numbers, else 0 is + /// written to rd. 
+ SetLessThanUnsigned, // sltu + /// Format: `xor rd, rs1, rs2` + /// + /// Description: Performs bitwise XOR on registers rs1 and rs2 and place the + /// result in rd + Xor, // xor + /// Format: `srl rd, rs1, rs2` + /// + /// Description: Logical right shift on the value in register rs1 by the + /// shift amount held in the lower 5 bits of register rs2 + ShiftRightLogical, // srl + /// Format: `sra rd, rs1, rs2` + /// + /// Description: Performs arithmetic right shift on the value in register + /// rs1 by the shift amount held in the lower 5 bits of register rs2 + ShiftRightArithmetic, // sra + /// Format: `or rd, rs1, rs2` + /// + /// Description: Performs bitwise OR on registers rs1 and rs2 and place the + /// result in rd + Or, // or + /// Format: `and rd, rs1, rs2` + /// + /// Description: Performs bitwise AND on registers rs1 and rs2 and place the + /// result in rd + And, // and + /// Format: `fence` + /// + /// Description: Used to order device I/O and memory accesses as viewed by + /// other RISC-V harts and external devices or coprocessors. + /// Any combination of device input (I), device output (O), memory reads + /// (R), and memory writes (W) may be ordered with respect to any + /// combination of the same. Informally, no other RISC-V hart or external + /// device can observe any operation in the successor set following a FENCE + /// before any operation in the predecessor set preceding the FENCE. + Fence, // fence + /// Format: `fence.i` + /// + /// Description: Provides explicit synchronization between writes to + /// instruction memory and instruction fetches on the same hart. + FenceI, // fence.i +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum IInstruction { + #[default] + /// Format: `lb rd, offset(rs1)` + /// + /// Description: Loads a 8-bit value from memory and sign-extends this to + /// 32 bits before storing it in register rd. + LoadByte, // lb + /// Format: `lh rd, offset(rs1)` + /// + /// Description: Loads a 16-bit value from memory and sign-extends this to + /// 32 bits before storing it in register rd. + LoadHalf, // lh + /// Format: `lw rd, offset(rs1)` + /// + /// Description: Loads a 32-bit value from memory and sign-extends this to + /// 32 bits before storing it in register rd. + LoadWord, // lw + /// Format: `lbu rd, offset(rs1)` + /// + /// Description: Loads a 8-bit value from memory and zero-extends this to + /// 32 bits before storing it in register rd. + LoadByteUnsigned, // lbu + /// Format: `lhu rd, offset(rs1)` + /// + /// Description: Loads a 16-bit value from memory and zero-extends this to + /// 32 bits before storing it in register rd. 
+ LoadHalfUnsigned, // lhu + + /// Format: `slli rd, rs1, shamt` + /// + /// Description: Performs logical left shift on the value in register rs1 by + /// the shift amount held in the lower 5 bits of the immediate + ShiftLeftLogicalImmediate, // slli + /// Format: `srli rd, rs1, shamt` + /// + /// Description: Performs logical right shift on the value in register rs1 + /// by the shift amount held in the lower 5 bits of the immediate + ShiftRightLogicalImmediate, // srli + /// Format: `srai rd, rs1, shamt` + /// + /// Description: Performs arithmetic right shift on the value in register + /// rs1 by the shift amount held in the lower 5 bits of the immediate + ShiftRightArithmeticImmediate, // srai + /// Format: `slti rd, rs1, imm` + /// + /// Description: Place the value 1 in register rd if register rs1 is less + /// than the signextended immediate when both are treated as signed numbers, + /// else 0 is written to rd. + SetLessThanImmediate, // slti + /// Format: `sltiu rd, rs1, imm` + /// + /// Description: Place the value 1 in register rd if register rs1 is less + /// than the immediate when both are treated as unsigned numbers, else 0 is + /// written to rd. + SetLessThanImmediateUnsigned, // sltiu + + /// Format: `addi rd, rs1, imm` + /// + /// Description: Adds the sign-extended 12-bit immediate to register rs1. + /// Arithmetic overflow is ignored and the result is simply the low 32 + /// bits of the result. ADDI rd, rs1, 0 is used to implement the MV rd, rs1 + /// assembler pseudo-instruction. + AddImmediate, // addi + /// Format: `xori rd, rs1, imm` + /// + /// Description: Performs bitwise XOR on register rs1 and the sign-extended + /// 12-bit immediate and place the result in rd Note, “XORI rd, rs1, -1” + /// performs a bitwise logical inversion of register rs1(assembler + /// pseudo-instruction NOT rd, rs) + XorImmediate, // xori + /// Format: `ori rd, rs1, imm` + /// + /// Description: Performs bitwise OR on register rs1 and the sign-extended + /// 12-bit immediate and place the result in rd + OrImmediate, // ori + /// Format: `andi rd, rs1, imm` + /// + /// Description: Performs bitwise AND on register rs1 and the sign-extended + /// 12-bit immediate and place the result in rd + AndImmediate, // andi + + /// Format: `jalr rd, rs1, imm` + /// + /// Description: Jump to address and place return address in rd. + JumpAndLinkRegister, // jalr +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum SInstruction { + #[default] + /// Format: `sb rs2, offset(rs1)` + /// + /// Description: Store 8-bit, values from the low bits of register rs2 to + /// memory. + StoreByte, // sb + /// Format: `sh rs2, offset(rs1)` + /// + /// Description: Store 16-bit, values from the low bits of register rs2 to + /// memory. + StoreHalf, // sh + /// Format: `sw rs2, offset(rs1)` + /// + /// Description: Store 32-bit, values from the low bits of register rs2 to + /// memory. + StoreWord, // sw +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum SBInstruction { + #[default] + /// Format: `beq rs1, rs2, offset` + /// + /// Description: Take the branch if registers rs1 and rs2 are equal. + BranchEq, // beq + /// Format: `bne rs1, rs2, offset` + /// + /// Description: Take the branch if registers rs1 and rs2 are not equal. 
+ BranchNeq, // bne + /// Format: `blt rs1, rs2, offset` + /// + /// Description: Take the branch if registers rs1 is less than rs2, using + /// signed comparison. + BranchLessThan, // blt + /// Format: `bge rs1, rs2, offset` + /// + /// Description: Take the branch if registers rs1 is greater than or equal + /// to rs2, using signed comparison. + BranchGreaterThanEqual, // bge + /// Format: `bltu rs1, rs2, offset` + /// + /// Description: Take the branch if registers rs1 is less than rs2, using + /// unsigned comparison. + BranchLessThanUnsigned, // bltu + /// Format: `bgeu rs1, rs2, offset` + /// + /// Description: Take the branch if registers rs1 is greater than or equal + /// to rs2, using unsigned comparison. + BranchGreaterThanEqualUnsigned, // bgeu +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum UInstruction { + #[default] + /// Format: `lui rd,imm` + /// + /// Description: Build 32-bit constants and uses the U-type format. LUI + /// places the U-immediate value in the top 20 bits of the destination + /// register rd, filling in the lowest 12 bits with zeros. + LoadUpperImmediate, // lui + /// Format: `auipc rd,imm` + /// + /// Description: Build pc-relative addresses and uses the U-type format. + /// AUIPC (Add upper immediate to PC) forms a 32-bit offset from the 20-bit + /// U-immediate, filling in the lowest 12 bits with zeros, adds this offset + /// to the pc, then places the result in register rd. + AddUpperImmediate, // auipc +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum UJInstruction { + #[default] + /// Format: `jal rd,imm` + /// + /// Description: Jump to address and place return address in rd. + JumpAndLink, // jal +} + +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum SyscallInstruction { + #[default] + SyscallSuccess, +} + +/// M extension instructions +/// Following +#[derive( + Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd, +)] +pub enum MInstruction { + /// Format: `mul rd, rs1, rs2` + /// + /// Description: performs an 32-bit 32-bit multiplication of signed rs1 + /// by signed rs2 and places the lower 32 bits in the destination register. + /// Implementation: `x[rd] = x[rs1] * x[rs2]` + #[default] + Mul, // mul + /// Format: `mulh rd, rs1, rs2` + /// + /// Description: performs an 32-bit 32-bit multiplication of signed rs1 by + /// signed rs2 and places the upper 32 bits in the destination register. + /// Implementation: `x[rd] = (x[rs1] * x[rs2]) >> 32` + Mulh, // mulh + /// Format: `mulhsu rd, rs1, rs2` + /// + /// Description: performs an 32-bit 32-bit multiplication of signed rs1 by + /// unsigned rs2 and places the upper 32 bits in the destination register. + /// Implementation: `x[rd] = (x[rs1] * x[rs2]) >> 32` + Mulhsu, // mulhsu + /// Format: `mulhu rd, rs1, rs2` + /// + /// Description: performs an 32-bit 32-bit multiplication of unsigned rs1 by + /// unsigned rs2 and places the upper 32 bits in the destination register. 
+ /// Implementation: `x[rd] = (x[rs1] * x[rs2]) >> 32` + Mulhu, // mulhu + /// Format: `div rd, rs1, rs2` + /// + /// Description: performs a 32-bit by 32-bit signed integer division of + /// rs1 by rs2, rounding towards zero. + /// Implementation: `x[rd] = x[rs1] /s x[rs2]` + Div, // div + /// Format: `divu rd, rs1, rs2` + /// + /// Description: performs a 32-bit by 32-bit unsigned integer division of + /// rs1 by rs2, rounding towards zero. + /// Implementation: `x[rd] = x[rs1] /u x[rs2]` + Divu, // divu + /// Format: `rem rd, rs1, rs2` + /// + /// Description: performs a 32-bit by 32-bit signed integer remainder of + /// rs1 by rs2. + /// Implementation: `x[rd] = x[rs1] %s x[rs2]` + Rem, // rem + /// Format: `remu rd, rs1, rs2` + /// + /// Description: performs a 32-bit by 32-bit unsigned integer remainder of + /// rs1 by rs2. + /// Implementation: `x[rd] = x[rs1] %u x[rs2]` + Remu, // remu +} + +impl IntoIterator for Instruction { + type Item = Instruction; + type IntoIter = std::vec::IntoIter<Instruction>; + + fn into_iter(self) -> Self::IntoIter { + match self { + Instruction::RType(_) => { + let mut iter_contents = Vec::with_capacity(RInstruction::COUNT); + for rtype in RInstruction::iter() { + iter_contents.push(Instruction::RType(rtype)); + } + iter_contents.into_iter() + } + Instruction::IType(_) => { + let mut iter_contents = Vec::with_capacity(IInstruction::COUNT); + for itype in IInstruction::iter() { + iter_contents.push(Instruction::IType(itype)); + } + iter_contents.into_iter() + } + Instruction::SType(_) => { + let mut iter_contents = Vec::with_capacity(SInstruction::COUNT); + for stype in SInstruction::iter() { + iter_contents.push(Instruction::SType(stype)); + } + iter_contents.into_iter() + } + Instruction::SBType(_) => { + let mut iter_contents = Vec::with_capacity(SBInstruction::COUNT); + for sbtype in SBInstruction::iter() { + iter_contents.push(Instruction::SBType(sbtype)); + } + iter_contents.into_iter() + } + Instruction::UType(_) => { + let mut iter_contents = Vec::with_capacity(UInstruction::COUNT); + for utype in UInstruction::iter() { + iter_contents.push(Instruction::UType(utype)); + } + iter_contents.into_iter() + } + Instruction::UJType(_) => { + let mut iter_contents = Vec::with_capacity(UJInstruction::COUNT); + for ujtype in UJInstruction::iter() { + iter_contents.push(Instruction::UJType(ujtype)); + } + iter_contents.into_iter() + } + Instruction::SyscallType(_) => { + let mut iter_contents = Vec::with_capacity(SyscallInstruction::COUNT); + for syscall in SyscallInstruction::iter() { + iter_contents.push(Instruction::SyscallType(syscall)); + } + iter_contents.into_iter() + } + Instruction::MType(_) => { + let mut iter_contents = Vec::with_capacity(MInstruction::COUNT); + for mtype in MInstruction::iter() { + iter_contents.push(Instruction::MType(mtype)); + } + iter_contents.into_iter() + } + } + } +} + +impl std::fmt::Display for Instruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Instruction::RType(rtype) => write!(f, "{}", rtype), + Instruction::IType(itype) => write!(f, "{}", itype), + Instruction::SType(stype) => write!(f, "{}", stype), + Instruction::SBType(sbtype) => write!(f, "{}", sbtype), + Instruction::UType(utype) => write!(f, "{}", utype), + Instruction::UJType(ujtype) => write!(f, "{}", ujtype), + Instruction::SyscallType(_syscall) => write!(f, "ecall"), + Instruction::MType(mtype) => write!(f, "{}", mtype), + } + } +} + +impl std::fmt::Display for RInstruction { + fn fmt(&self, f: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RInstruction::Add => write!(f, "add"), + RInstruction::Sub => write!(f, "sub"), + RInstruction::ShiftLeftLogical => write!(f, "sll"), + RInstruction::SetLessThan => write!(f, "slt"), + RInstruction::SetLessThanUnsigned => write!(f, "sltu"), + RInstruction::Xor => write!(f, "xor"), + RInstruction::ShiftRightLogical => write!(f, "srl"), + RInstruction::ShiftRightArithmetic => write!(f, "sra"), + RInstruction::Or => write!(f, "or"), + RInstruction::And => write!(f, "and"), + RInstruction::Fence => write!(f, "fence"), + RInstruction::FenceI => write!(f, "fence.i"), + } + } +} + +impl std::fmt::Display for IInstruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IInstruction::LoadByte => write!(f, "lb"), + IInstruction::LoadHalf => write!(f, "lh"), + IInstruction::LoadWord => write!(f, "lw"), + IInstruction::LoadByteUnsigned => write!(f, "lbu"), + IInstruction::LoadHalfUnsigned => write!(f, "lhu"), + IInstruction::ShiftLeftLogicalImmediate => write!(f, "slli"), + IInstruction::ShiftRightLogicalImmediate => write!(f, "srli"), + IInstruction::ShiftRightArithmeticImmediate => write!(f, "srai"), + IInstruction::SetLessThanImmediate => write!(f, "slti"), + IInstruction::SetLessThanImmediateUnsigned => write!(f, "sltiu"), + IInstruction::AddImmediate => write!(f, "addi"), + IInstruction::XorImmediate => write!(f, "xori"), + IInstruction::OrImmediate => write!(f, "ori"), + IInstruction::AndImmediate => write!(f, "andi"), + IInstruction::JumpAndLinkRegister => write!(f, "jalr"), + } + } +} + +impl std::fmt::Display for SInstruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SInstruction::StoreByte => write!(f, "sb"), + SInstruction::StoreHalf => write!(f, "sh"), + SInstruction::StoreWord => write!(f, "sw"), + } + } +} + +impl std::fmt::Display for SBInstruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SBInstruction::BranchEq => write!(f, "beq"), + SBInstruction::BranchNeq => write!(f, "bne"), + SBInstruction::BranchLessThan => write!(f, "blt"), + SBInstruction::BranchGreaterThanEqual => write!(f, "bge"), + SBInstruction::BranchLessThanUnsigned => write!(f, "bltu"), + SBInstruction::BranchGreaterThanEqualUnsigned => write!(f, "bgeu"), + } + } +} + +impl std::fmt::Display for UInstruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UInstruction::LoadUpperImmediate => write!(f, "lui"), + UInstruction::AddUpperImmediate => write!(f, "auipc"), + } + } +} + +impl std::fmt::Display for UJInstruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UJInstruction::JumpAndLink => write!(f, "jal"), + } + } +} + +impl std::fmt::Display for MInstruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MInstruction::Mul => write!(f, "mul"), + MInstruction::Mulh => write!(f, "mulh"), + MInstruction::Mulhsu => write!(f, "mulhsu"), + MInstruction::Mulhu => write!(f, "mulhu"), + MInstruction::Div => write!(f, "div"), + MInstruction::Divu => write!(f, "divu"), + MInstruction::Rem => write!(f, "rem"), + MInstruction::Remu => write!(f, "remu"), + } + } +} + +pub trait InterpreterEnv { + /// A position can be seen as an indexed variable + type Position; + + /// Allocate a new abstract variable for the current step. + /// The variable can be used to store temporary values. 
+    /// The variables are "freed" after each step/instruction.
+    /// The variable allocation can be seen as an allocation on a stack that is
+    /// popped after each step execution.
+    /// At the moment, [crate::interpreters::riscv32im::SCRATCH_SIZE]
+    /// elements can be allocated. If more temporary variables are required for
+    /// an instruction, increase the value
+    /// [crate::interpreters::riscv32im::SCRATCH_SIZE].
+    fn alloc_scratch(&mut self) -> Self::Position;
+
+    type Variable: Clone
+        + std::ops::Add<Self::Variable, Output = Self::Variable>
+        + std::ops::Sub<Self::Variable, Output = Self::Variable>
+        + std::ops::Mul<Self::Variable, Output = Self::Variable>
+        + std::fmt::Debug
+        + Zero
+        + One;
+
+    /// Returns the variable in the current row corresponding to a given column alias.
+    fn variable(&self, column: Self::Position) -> Self::Variable;
+
+    /// Add a constraint to the proof system, asserting that
+    /// `assert_equals_zero` is 0.
+    fn add_constraint(&mut self, assert_equals_zero: Self::Variable);
+
+    /// Activate the selector for the given instruction.
+    fn activate_selector(&mut self, selector: Instruction);
+
+    /// Check that the witness value in `assert_equals_zero` is 0; otherwise abort.
+    fn check_is_zero(assert_equals_zero: &Self::Variable);
+
+    /// Assert that the value `assert_equals_zero` is 0, and add a constraint in the proof system.
+    fn assert_is_zero(&mut self, assert_equals_zero: Self::Variable) {
+        Self::check_is_zero(&assert_equals_zero);
+        self.add_constraint(assert_equals_zero);
+    }
+
+    /// Check that the witness values in `x` and `y` are equal; otherwise abort.
+    fn check_equal(x: &Self::Variable, y: &Self::Variable);
+
+    /// Assert that the values `x` and `y` are equal, and add a constraint in the proof system.
+    fn assert_equal(&mut self, x: Self::Variable, y: Self::Variable) {
+        // NB: We use a different function to give a better error message for debugging.
+        Self::check_equal(&x, &y);
+        self.add_constraint(x - y);
+    }
+
+    /// Check that the witness value `x` is a boolean (`0` or `1`); otherwise abort.
+    fn check_boolean(x: &Self::Variable);
+
+    /// Assert that the value `x` is boolean, and add a constraint in the proof system.
+    fn assert_boolean(&mut self, x: Self::Variable) {
+        Self::check_boolean(&x);
+        self.add_constraint(x.clone() * x.clone() - x);
+    }
+
+    /// Add the given lookup to the proof system.
+    fn add_lookup(&mut self, lookup: Lookup<Self::Variable>);
+
+    /// Return the current instruction counter.
+    fn instruction_counter(&self) -> Self::Variable;
+
+    /// Increase the instruction counter.
+    fn increase_instruction_counter(&mut self);
+
+    /// Fetch the value of the general purpose register with index `idx` and store it in local
+    /// position `output`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn fetch_register(
+        &mut self,
+        idx: &Self::Variable,
+        output: Self::Position,
+    ) -> Self::Variable;
+
+    /// Set the general purpose register with index `idx` to `value` if `if_is_true` is true.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn push_register_if(
+        &mut self,
+        idx: &Self::Variable,
+        value: Self::Variable,
+        if_is_true: &Self::Variable,
+    );
+
+    /// Set the general purpose register with index `idx` to `value`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
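+    ///
+    /// This is a convenience wrapper: as the default implementation below shows, it simply
+    /// calls [`Self::push_register_if`] with the condition fixed to the constant 1.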
+    unsafe fn push_register(&mut self, idx: &Self::Variable, value: Self::Variable) {
+        self.push_register_if(idx, value, &Self::constant(1))
+    }
+
+    /// Fetch the last 'access index' for the general purpose register with index `idx`, and store
+    /// it in local position `output`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn fetch_register_access(
+        &mut self,
+        idx: &Self::Variable,
+        output: Self::Position,
+    ) -> Self::Variable;
+
+    /// Set the last 'access index' for the general purpose register with index `idx` to `value` if
+    /// `if_is_true` is true.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn push_register_access_if(
+        &mut self,
+        idx: &Self::Variable,
+        value: Self::Variable,
+        if_is_true: &Self::Variable,
+    );
+
+    /// Set the last 'access index' for the general purpose register with index `idx` to `value`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this operation.
+    unsafe fn push_register_access(&mut self, idx: &Self::Variable, value: Self::Variable) {
+        self.push_register_access_if(idx, value, &Self::constant(1))
+    }
+
+    /// Access the general purpose register with index `idx`, adding constraints asserting that the
+    /// old value was `old_value` and that the new value will be `new_value`, if `if_is_true` is
+    /// true.
+    ///
+    /// # Safety
+    ///
+    /// Callers of this function must manually update the registers if required; this function
+    /// only updates the access counter.
+    unsafe fn access_register_if(
+        &mut self,
+        idx: &Self::Variable,
+        old_value: &Self::Variable,
+        new_value: &Self::Variable,
+        if_is_true: &Self::Variable,
+    ) {
+        let last_accessed = {
+            let last_accessed_location = self.alloc_scratch();
+            unsafe { self.fetch_register_access(idx, last_accessed_location) }
+        };
+        let instruction_counter = self.instruction_counter();
+        let elapsed_time = instruction_counter.clone() - last_accessed.clone();
+        let new_accessed = {
+            // Here, we write as if the register had been written *at the start of the next
+            // instruction*. This ensures that we can't 'time travel' within this
+            // instruction, and claim to read the value that we're about to write!
+            instruction_counter + Self::constant(1)
+            // The same register may be accessed several times within the same instruction.
+            // To allow this, we always increase the instruction counter by 1 after an access.
+        };
+        unsafe { self.push_register_access_if(idx, new_accessed.clone(), if_is_true) };
+        self.add_lookup(Lookup::write_if(
+            if_is_true.clone(),
+            LookupTableIDs::RegisterLookup,
+            vec![idx.clone(), last_accessed, old_value.clone()],
+        ));
+        self.add_lookup(Lookup::read_if(
+            if_is_true.clone(),
+            LookupTableIDs::RegisterLookup,
+            vec![idx.clone(), new_accessed, new_value.clone()],
+        ));
+        self.range_check64(&elapsed_time);
+
+        // Update instruction counter after accessing a register.
+        self.increase_instruction_counter();
+    }
+
+    fn read_register(&mut self, idx: &Self::Variable) -> Self::Variable {
+        let value = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_register(idx, value_location) }
+        };
+        unsafe {
+            self.access_register(idx, &value, &value);
+        };
+        value
+    }
+
+    /// Access the general purpose register with index `idx`, adding constraints asserting that the
+    /// old value was `old_value` and that the new value will be `new_value`.
+    ///
+    /// # Safety
+    ///
+    /// Callers of this function must manually update the registers if required; this function
+    /// only updates the access counter.
+    unsafe fn access_register(
+        &mut self,
+        idx: &Self::Variable,
+        old_value: &Self::Variable,
+        new_value: &Self::Variable,
+    ) {
+        self.access_register_if(idx, old_value, new_value, &Self::constant(1))
+    }
+
+    fn write_register_if(
+        &mut self,
+        idx: &Self::Variable,
+        new_value: Self::Variable,
+        if_is_true: &Self::Variable,
+    ) {
+        let old_value = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_register(idx, value_location) }
+        };
+        // Ensure that we only write 0 to the 0 register.
+        let actual_new_value = {
+            let idx_is_zero = self.is_zero(idx);
+            let pos = self.alloc_scratch();
+            self.copy(&((Self::constant(1) - idx_is_zero) * new_value), pos)
+        };
+        unsafe {
+            self.access_register_if(idx, &old_value, &actual_new_value, if_is_true);
+        };
+        unsafe {
+            self.push_register_if(idx, actual_new_value, if_is_true);
+        };
+    }
+
+    fn write_register(&mut self, idx: &Self::Variable, new_value: Self::Variable) {
+        self.write_register_if(idx, new_value, &Self::constant(1))
+    }
+
+    /// Fetch the memory value at address `addr` and store it in local position `output`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this memory operation.
+    unsafe fn fetch_memory(
+        &mut self,
+        addr: &Self::Variable,
+        output: Self::Position,
+    ) -> Self::Variable;
+
+    /// Set the memory value at address `addr` to `value`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this memory operation.
+    unsafe fn push_memory(&mut self, addr: &Self::Variable, value: Self::Variable);
+
+    /// Fetch the last 'access index' that the memory at address `addr` was written at, and store
+    /// it in local position `output`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this memory operation.
+    unsafe fn fetch_memory_access(
+        &mut self,
+        addr: &Self::Variable,
+        output: Self::Position,
+    ) -> Self::Variable;
+
+    /// Set the last 'access index' for the memory at address `addr` to `value`.
+    ///
+    /// # Safety
+    ///
+    /// No lookups or other constraints are added as part of this operation. The caller must
+    /// manually add the lookups for this memory operation.
+    unsafe fn push_memory_access(&mut self, addr: &Self::Variable, value: Self::Variable);
+
+    /// Access the memory address `addr`, adding constraints asserting that the old value was
+    /// `old_value` and that the new value will be `new_value`.
+    ///
+    /// # Safety
+    ///
+    /// Callers of this function must manually update the memory if required; this function
+    /// only updates the access counter.
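+    ///
+    /// A sketch of the intended usage, mirroring the safe wrappers `read_memory` and
+    /// `write_memory` defined below (illustration only, not compiled):
+    /// ```text
+    /// // Read: fetch the current value, then access with old = new = value.
+    /// let value = unsafe { env.fetch_memory(&addr, value_location) };
+    /// unsafe { env.access_memory(&addr, &value, &value) };
+    /// // Write: access with (old, new), then push the new value.
+    /// unsafe { env.access_memory(&addr, &old_value, &new_value) };
+    /// unsafe { env.push_memory(&addr, new_value) };
+    /// ```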
+    unsafe fn access_memory(
+        &mut self,
+        addr: &Self::Variable,
+        old_value: &Self::Variable,
+        new_value: &Self::Variable,
+    ) {
+        let last_accessed = {
+            let last_accessed_location = self.alloc_scratch();
+            unsafe { self.fetch_memory_access(addr, last_accessed_location) }
+        };
+        let instruction_counter = self.instruction_counter();
+        let elapsed_time = instruction_counter.clone() - last_accessed.clone();
+        let new_accessed = {
+            // Here, we write as if the memory had been written *at the start of the next
+            // instruction*. This ensures that we can't 'time travel' within this
+            // instruction, and claim to read the value that we're about to write!
+            instruction_counter + Self::constant(1)
+        };
+        unsafe { self.push_memory_access(addr, new_accessed.clone()) };
+        self.add_lookup(Lookup::write_one(
+            LookupTableIDs::MemoryLookup,
+            vec![addr.clone(), last_accessed, old_value.clone()],
+        ));
+        self.add_lookup(Lookup::read_one(
+            LookupTableIDs::MemoryLookup,
+            vec![addr.clone(), new_accessed, new_value.clone()],
+        ));
+        self.range_check64(&elapsed_time);
+
+        // Update instruction counter after accessing a memory address.
+        self.increase_instruction_counter();
+    }
+
+    fn read_memory(&mut self, addr: &Self::Variable) -> Self::Variable {
+        let value = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_memory(addr, value_location) }
+        };
+        unsafe {
+            self.access_memory(addr, &value, &value);
+        };
+        value
+    }
+
+    fn write_memory(&mut self, addr: &Self::Variable, new_value: Self::Variable) {
+        let old_value = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_memory(addr, value_location) }
+        };
+        unsafe {
+            self.access_memory(addr, &old_value, &new_value);
+        };
+        unsafe {
+            self.push_memory(addr, new_value);
+        };
+    }
+
+    /// Adds a lookup to the RangeCheck16Lookup table
+    fn lookup_16bits(&mut self, value: &Self::Variable) {
+        self.add_lookup(Lookup::read_one(
+            LookupTableIDs::RangeCheck16Lookup,
+            vec![value.clone()],
+        ));
+    }
+
+    /// Range-checks, with 2 lookups to the RangeCheck16Lookup table, that a value
+    /// is at most 2^`bits` - 1 (`bits` <= 16).
+    fn range_check16(&mut self, value: &Self::Variable, bits: u32) {
+        assert!(bits <= 16);
+        // 0 <= value < 2^bits
+        // First, check the lower bound: 0 <= value < 2^16
+        self.lookup_16bits(value);
+        // Second, check the upper bound: value + 2^16 - 2^bits < 2^16
+        self.lookup_16bits(&(value.clone() + Self::constant(1 << 16) - Self::constant(1 << bits)));
+    }
+
+    /// Adds a lookup to the ByteLookup table
+    fn lookup_8bits(&mut self, value: &Self::Variable) {
+        self.add_lookup(Lookup::read_one(
+            LookupTableIDs::ByteLookup,
+            vec![value.clone()],
+        ));
+    }
+
+    /// Range-checks, with 2 lookups to the ByteLookup table, that a value
+    /// is at most 2^`bits` - 1 (`bits` <= 8).
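+    ///
+    /// The second lookup implements the usual shifted range-check trick:
+    /// `value + 2^8 - 2^bits` is a byte exactly when `value < 2^bits`.
+    /// Worked example with `bits = 5`: `value = 31` gives `31 + 256 - 32 = 255`,
+    /// which is in the byte table, while `value = 32` gives `32 + 256 - 32 = 256`,
+    /// which is not.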
+    fn range_check8(&mut self, value: &Self::Variable, bits: u32) {
+        assert!(bits <= 8);
+        // 0 <= value < 2^bits
+        // First, check the lower bound: 0 <= value < 2^8
+        self.lookup_8bits(value);
+        // Second, check the upper bound: value + 2^8 - 2^bits < 2^8
+        self.lookup_8bits(&(value.clone() + Self::constant(1 << 8) - Self::constant(1 << bits)));
+    }
+
+    /// Adds a lookup to the AtMost4Lookup table
+    fn lookup_2bits(&mut self, value: &Self::Variable) {
+        self.add_lookup(Lookup::read_one(
+            LookupTableIDs::AtMost4Lookup,
+            vec![value.clone()],
+        ));
+    }
+
+    fn range_check64(&mut self, _value: &Self::Variable) {
+        // TODO: implement the range check; this is currently a no-op.
+    }
+
+    fn set_instruction_pointer(&mut self, ip: Self::Variable) {
+        let idx = Self::constant(REGISTER_CURRENT_IP as u32);
+        let new_accessed = self.instruction_counter() + Self::constant(1);
+        unsafe {
+            self.push_register_access(&idx, new_accessed.clone());
+        }
+        unsafe {
+            self.push_register(&idx, ip.clone());
+        }
+        self.add_lookup(Lookup::read_one(
+            LookupTableIDs::RegisterLookup,
+            vec![idx, new_accessed, ip],
+        ));
+    }
+
+    fn get_instruction_pointer(&mut self) -> Self::Variable {
+        let idx = Self::constant(REGISTER_CURRENT_IP as u32);
+        let ip = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_register(&idx, value_location) }
+        };
+        self.add_lookup(Lookup::write_one(
+            LookupTableIDs::RegisterLookup,
+            vec![idx, self.instruction_counter(), ip.clone()],
+        ));
+        ip
+    }
+
+    fn set_next_instruction_pointer(&mut self, ip: Self::Variable) {
+        let idx = Self::constant(REGISTER_NEXT_IP as u32);
+        let new_accessed = self.instruction_counter() + Self::constant(1);
+        unsafe {
+            self.push_register_access(&idx, new_accessed.clone());
+        }
+        unsafe {
+            self.push_register(&idx, ip.clone());
+        }
+        self.add_lookup(Lookup::read_one(
+            LookupTableIDs::RegisterLookup,
+            vec![idx, new_accessed, ip],
+        ));
+    }
+
+    fn get_next_instruction_pointer(&mut self) -> Self::Variable {
+        let idx = Self::constant(REGISTER_NEXT_IP as u32);
+        let ip = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_register(&idx, value_location) }
+        };
+        self.add_lookup(Lookup::write_one(
+            LookupTableIDs::RegisterLookup,
+            vec![idx, self.instruction_counter(), ip.clone()],
+        ));
+        ip
+    }
+
+    /// Build a variable from the constant `x`.
+    fn constant(x: u32) -> Self::Variable;
+
+    /// Extract the bits from the variable `x` between `highest_bit` and `lowest_bit`, and store
+    /// the result in `position`.
+    /// `lowest_bit` becomes the least-significant bit of the resulting value.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and that the returned value fits in `highest_bit - lowest_bit`
+    /// bits.
+    ///
+    /// Do not call this function with `highest_bit - lowest_bit >= 32`.
+    // TODO: embed the range check in the function when highest_bit - lowest_bit <= 16?
+    unsafe fn bitmask(
+        &mut self,
+        x: &Self::Variable,
+        highest_bit: u32,
+        lowest_bit: u32,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Return the result of shifting `x` left by `by`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and the shift amount `by`.
+    unsafe fn shift_left(
+        &mut self,
+        x: &Self::Variable,
+        by: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Return the result of shifting `x` right by `by` (logical shift), storing the result in
+    /// `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and the shift amount `by`.
+    unsafe fn shift_right(
+        &mut self,
+        x: &Self::Variable,
+        by: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Return the result of shifting `x` right by `by` (arithmetic shift), storing the result in
+    /// `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// the source variable `x` and the shift amount `by`.
+    unsafe fn shift_right_arithmetic(
+        &mut self,
+        x: &Self::Variable,
+        by: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns 1 if `x` is 0, or 0 otherwise, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// `x`.
+    unsafe fn test_zero(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable;
+
+    /// Returns `x^(-1)`, or `0` if `x` is `0`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert the relationship with
+    /// `x`.
+    ///
+    /// The value returned may be a placeholder; callers should be careful not to depend directly
+    /// on the value stored in the variable.
+    unsafe fn inverse_or_zero(
+        &mut self,
+        x: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns 1 if `x` is 0, or 0 otherwise.
+    fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable;
+
+    /// Returns 1 if `x` is equal to `y`, or 0 otherwise.
+    fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable;
+
+    /// Returns 1 if `x < y` as unsigned integers, or 0 otherwise, storing the result in
+    /// `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert that the value
+    /// correctly represents the relationship between `x` and `y`.
+    unsafe fn test_less_than(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns 1 if `x < y` as signed integers, or 0 otherwise, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must assert that the value
+    /// correctly represents the relationship between `x` and `y`.
+    unsafe fn test_less_than_signed(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x and y`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn and_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x or y`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn or_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x nor y`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
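+    ///
+    /// For instance, on 32-bit values, `0 nor 0 = 0xFFFF_FFFF`, and `x nor x`
+    /// is the bitwise negation of `x`.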
+    unsafe fn nor_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x xor y`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn xor_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x + y` and the overflow bit, storing the results in `out_position` and
+    /// `overflow_position` respectively.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned values; callers must manually add constraints to
+    /// ensure that they are correctly constructed.
+    unsafe fn add_witness(
+        &mut self,
+        y: &Self::Variable,
+        x: &Self::Variable,
+        out_position: Self::Position,
+        overflow_position: Self::Position,
+    ) -> (Self::Variable, Self::Variable);
+
+    /// Returns `x - y` and the underflow bit, storing the results in `out_position` and
+    /// `underflow_position` respectively.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned values; callers must manually add constraints to
+    /// ensure that they are correctly constructed.
+    unsafe fn sub_witness(
+        &mut self,
+        y: &Self::Variable,
+        x: &Self::Variable,
+        out_position: Self::Position,
+        underflow_position: Self::Position,
+    ) -> (Self::Variable, Self::Variable);
+
+    /// Returns `x * y`, where `x` and `y` are treated as signed integers, storing the result in
+    /// `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn mul_signed_witness(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `(x * y) >> 32`, where `x` and `y` are treated as signed integers, storing the
+    /// result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn mul_hi_signed(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `(x * y) & ((1 << 32) - 1)`, where `x` and `y` are treated as signed integers,
+    /// storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn mul_lo_signed(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `(x * y) >> 32`, where `x` and `y` are treated as unsigned integers, storing the
+    /// result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn mul_hi(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `(x * y) & ((1 << 32) - 1)`, where `x` and `y` are treated as unsigned integers,
+    /// storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn mul_lo(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `(x * y) >> 32`, where `x` is treated as a signed integer and `y` as an unsigned
+    /// integer, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn mul_hi_signed_unsigned(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x / y` (signed division), storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn div_signed(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x % y` (signed remainder), storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn mod_signed(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x / y` (unsigned division), storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn div(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns `x % y` (unsigned remainder), storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it corresponds to the given values `x` and `y`, and that it falls within the
+    /// desired range.
+    unsafe fn mod_unsigned(
+        &mut self,
+        x: &Self::Variable,
+        y: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns the number of leading 0s in `x`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
+    unsafe fn count_leading_zeros(
+        &mut self,
+        x: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Returns the number of leading 1s in `x`, storing the result in `position`.
+    ///
+    /// # Safety
+    ///
+    /// There are no constraints on the returned value; callers must manually add constraints to
+    /// ensure that it is correctly constructed.
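+    ///
+    /// For example, on 32-bit values, the number of leading 1s of `0xFFFF_0000`
+    /// is 16, and that of `0x7FFF_FFFF` is 0.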
+    unsafe fn count_leading_ones(
+        &mut self,
+        x: &Self::Variable,
+        position: Self::Position,
+    ) -> Self::Variable;
+
+    /// Copy the value of `x` into `position`, returning the corresponding variable.
+    fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable;
+
+    /// Increases the heap pointer by `by_amount` if `if_is_true` is `1`, and returns the previous
+    /// value of the heap pointer.
+    fn increase_heap_pointer(
+        &mut self,
+        by_amount: &Self::Variable,
+        if_is_true: &Self::Variable,
+    ) -> Self::Variable {
+        let idx = Self::constant(REGISTER_HEAP_POINTER as u32);
+        let old_ptr = {
+            let value_location = self.alloc_scratch();
+            unsafe { self.fetch_register(&idx, value_location) }
+        };
+        let new_ptr = old_ptr.clone() + by_amount.clone();
+        unsafe {
+            self.access_register_if(&idx, &old_ptr, &new_ptr, if_is_true);
+        };
+        unsafe {
+            self.push_register_if(&idx, new_ptr, if_is_true);
+        };
+        old_ptr
+    }
+
+    /// Set the halted flag to `flag`.
+    fn set_halted(&mut self, flag: Self::Variable);
+
+    /// Given a variable `x` holding a value of `bitlength` bits, sign-extend it to a 32-bit
+    /// value. Note that `bitlength` must be strictly smaller than 32, otherwise the shift
+    /// below overflows.
+    fn sign_extend(&mut self, x: &Self::Variable, bitlength: u32) -> Self::Variable {
+        // FIXME: Constrain `high_bit`
+        let high_bit = {
+            let pos = self.alloc_scratch();
+            unsafe { self.bitmask(x, bitlength, bitlength - 1, pos) }
+        };
+        high_bit * Self::constant(((1 << (32 - bitlength)) - 1) << bitlength) + x.clone()
+    }
+
+    /// Report the exit code of the program.
+    fn report_exit(&mut self, exit_code: &Self::Variable);
+
+    fn reset(&mut self);
+}
+
+pub fn interpret_instruction<Env: InterpreterEnv>(env: &mut Env, instr: Instruction) {
+    env.activate_selector(instr);
+    match instr {
+        Instruction::RType(rtype) => interpret_rtype(env, rtype),
+        Instruction::IType(itype) => interpret_itype(env, itype),
+        Instruction::SType(stype) => interpret_stype(env, stype),
+        Instruction::SBType(sbtype) => interpret_sbtype(env, sbtype),
+        Instruction::UType(utype) => interpret_utype(env, utype),
+        Instruction::UJType(ujtype) => interpret_ujtype(env, ujtype),
+        Instruction::SyscallType(syscall) => interpret_syscall(env, syscall),
+        Instruction::MType(mtype) => interpret_mtype(env, mtype),
+    }
+}
+
+/// Interpret an R-type instruction.
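+/// Instructions are fetched as four little-endian bytes. As a worked example
+/// (illustrative values): if the bytes at `pc..pc+3` are `0x33 0x85 0xC8 0x00`,
+/// the fetched word is `0x00C88533`, which decodes under the layout below to
+/// `add a0, a7, a2` (opcode `0110011`, rd = 10, funct3 = 0, rs1 = 17, rs2 = 12,
+/// funct7 = 0).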
+/// The encoding of an R-type instruction is as follows:
+/// ```text
+/// | 31 25 | 24 20 | 19 15 | 14 12 | 11 7 | 6 0 |
+/// | funct5 & funct2 | rs2 | rs1 | funct3 | rd | opcode |
+/// ```
+/// Following the documentation found
+/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
+pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let next_instruction_pointer = env.get_next_instruction_pointer();
+
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v3 * Env::constant(1 << 24))
+            + (v2 * Env::constant(1 << 16))
+            + (v1 * Env::constant(1 << 8))
+            + v0
+    };
+
+    // FIXME: constrain the opcode to match the instruction given as a parameter
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 7, 0, pos) }
+    };
+    env.range_check8(&opcode, 7);
+
+    let rd = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 12, 7, pos) }
+    };
+    env.range_check8(&rd, 5);
+
+    let funct3 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 15, 12, pos) }
+    };
+    env.range_check8(&funct3, 3);
+
+    let rs1 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 20, 15, pos) }
+    };
+    env.range_check8(&rs1, 5);
+
+    let rs2 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 25, 20, pos) }
+    };
+    env.range_check8(&rs2, 5);
+
+    let funct2 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 27, 25, pos) }
+    };
+    env.range_check8(&funct2, 2);
+
+    let funct5 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 27, pos) }
+    };
+    env.range_check8(&funct5, 5);
+
+    // Check correctness of decomposition
+    env.add_constraint(
+        instruction
+            - (opcode.clone() * Env::constant(1 << 0)) // opcode at bits 0-6
+            - (rd.clone() * Env::constant(1 << 7)) // rd at bits 7-11
+            - (funct3.clone() * Env::constant(1 << 12)) // funct3 at bits 12-14
+            - (rs1.clone() * Env::constant(1 << 15)) // rs1 at bits 15-19
+            - (rs2.clone() * Env::constant(1 << 20)) // rs2 at bits 20-24
+            - (funct2.clone() * Env::constant(1 << 25)) // funct2 at bits 25-26
+            - (funct5.clone() * Env::constant(1 << 27)), // funct5 at bits 27-31
+    );
+
+    match instr {
+        RInstruction::Add => {
+            // add: x[rd] = x[rs1] + x[rs2]
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let overflow_scratch = env.alloc_scratch();
+            let rd_scratch = env.alloc_scratch();
+            let local_rd = unsafe {
+                let (local_rd, _overflow) =
+                    env.add_witness(&local_rs1, &local_rs2, rd_scratch, overflow_scratch);
+                local_rd
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::Sub => {
+            /* sub: x[rd] = x[rs1] - x[rs2] */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let underflow_scratch = env.alloc_scratch();
+            let rd_scratch = env.alloc_scratch();
+            let local_rd = unsafe {
+                let (local_rd, _underflow) =
+                    env.sub_witness(&local_rs1, &local_rs2, rd_scratch, underflow_scratch);
+                local_rd
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::ShiftLeftLogical => {
+            /* sll: x[rd] = x[rs1] << x[rs2] */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let local_rd = unsafe {
+                let rd_scratch = env.alloc_scratch();
+                env.shift_left(&local_rs1, &local_rs2, rd_scratch)
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::SetLessThan => {
+            /* slt: x[rd] = (x[rs1] < x[rs2]) ? 1 : 0 */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let local_rd = unsafe {
+                let rd_scratch = env.alloc_scratch();
+                env.test_less_than_signed(&local_rs1, &local_rs2, rd_scratch)
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::SetLessThanUnsigned => {
+            /* sltu: x[rd] = ((u)x[rs1] < (u)x[rs2]) ? 1 : 0 */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let local_rd = unsafe {
+                let pos = env.alloc_scratch();
+                env.test_less_than(&local_rs1, &local_rs2, pos)
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::Xor => {
+            /* xor: x[rd] = x[rs1] ^ x[rs2] */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let local_rd = unsafe {
+                let pos = env.alloc_scratch();
+                env.xor_witness(&local_rs1, &local_rs2, pos)
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::ShiftRightLogical => {
+            /* srl: x[rd] = x[rs1] >> x[rs2] */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let local_rd = unsafe {
+                let pos = env.alloc_scratch();
+                env.shift_right(&local_rs1, &local_rs2, pos)
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::ShiftRightArithmetic => {
+            /* sra: x[rd] = x[rs1] >>s x[rs2] (arithmetic shift) */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let local_rd = unsafe {
+                let pos = env.alloc_scratch();
+                env.shift_right_arithmetic(&local_rs1, &local_rs2, pos)
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::Or => {
+            /* or: x[rd] = x[rs1] | x[rs2] */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let local_rd = unsafe {
+                let pos = env.alloc_scratch();
+                env.or_witness(&local_rs1, &local_rs2, pos)
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::And => {
+            /* and: x[rd] = x[rs1] & x[rs2] */
+            let local_rs1 = env.read_register(&rs1);
+            let local_rs2 = env.read_register(&rs2);
+            let local_rd = unsafe {
+                let pos = env.alloc_scratch();
+                env.and_witness(&local_rs1, &local_rs2, pos)
+            };
+            env.write_register(&rd, local_rd);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        RInstruction::Fence => {
+            unimplemented!("Fence")
+        }
+        RInstruction::FenceI => {
+            unimplemented!("FenceI")
+        }
+    };
+}
+
+/// Interpret an I-type instruction.
+/// The encoding of an I-type instruction is as follows:
+/// ```text
+/// | 31 20 | 19 15 | 14 12 | 11 7 | 6 0 |
+/// | immediate | rs1 | funct3 | rd | opcode |
+/// ```
+/// Following the documentation found
+/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
+pub fn interpret_itype<Env: InterpreterEnv>(env: &mut Env, instr: IInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let next_instruction_pointer = env.get_next_instruction_pointer();
+
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v3 * Env::constant(1 << 24))
+            + (v2 * Env::constant(1 << 16))
+            + (v1 * Env::constant(1 << 8))
+            + v0
+    };
+
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 7, 0, pos) }
+    };
+    env.range_check8(&opcode, 7);
+
+    let rd = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 12, 7, pos) }
+    };
+    env.range_check8(&rd, 5);
+
+    let funct3 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 15, 12, pos) }
+    };
+    env.range_check8(&funct3, 3);
+
+    let rs1 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 20, 15, pos) }
+    };
+    env.range_check8(&rs1, 5);
+
+    let imm = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 20, pos) }
+    };
+
+    env.range_check16(&imm, 12);
+
+    // Check correctness of decomposition
+    env.add_constraint(
+        instruction
+            - (opcode.clone() * Env::constant(1 << 0)) // opcode at bits 0-6
+            - (rd.clone() * Env::constant(1 << 7)) // rd at bits 7-11
+            - (funct3.clone() * Env::constant(1 << 12)) // funct3 at bits 12-14
+            - (rs1.clone() * Env::constant(1 << 15)) // rs1 at bits 15-19
+            - (imm.clone() * Env::constant(1 << 20)), // imm at bits 20-31
+    );
+
+    match instr {
+        IInstruction::LoadByte => {
+            // lb: x[rd] = sext(M[x[rs1] + sext(offset)][7:0])
+            let local_rs1 = env.read_register(&rs1);
+            let local_imm = env.sign_extend(&imm, 12);
+            let address = {
+                let address_scratch = env.alloc_scratch();
+                let overflow_scratch = env.alloc_scratch();
+                let (address, _overflow) = unsafe {
+                    env.add_witness(&local_rs1, &local_imm, address_scratch, overflow_scratch)
+                };
+                address
+            };
+            // FIXME: add a range check here for the address
+            let value = env.read_memory(&address);
+            let value = env.sign_extend(&value, 8);
+            env.write_register(&rd, value);
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        IInstruction::LoadHalf => {
+            // lh: x[rd] = sext(M[x[rs1] + sext(offset)][15:0])
+            let local_rs1 = env.read_register(&rs1);
+            let local_imm = env.sign_extend(&imm, 12);
+            let address = {
+                let address_scratch = env.alloc_scratch();
+                let overflow_scratch = env.alloc_scratch();
+                let (address, _overflow) = unsafe {
+                    env.add_witness(&local_rs1, &local_imm, address_scratch, overflow_scratch)
+                };
+                address
+            };
+            // FIXME: add a range check here for the address
+            let v0 = env.read_memory(&address);
+            let v1 = env.read_memory(&(address.clone() + Env::constant(1)));
+            // Compose the halfword little-endian, consistently with the instruction fetch above.
+            let value = (v1 * Env::constant(1 << 8)) + v0;
+            let value = env.sign_extend(&value, 16);
+            env.write_register(&rd, value);
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        IInstruction::LoadWord => {
+            // lw: x[rd] = M[x[rs1] + sext(offset)][31:0]
+            let base = env.read_register(&rs1);
+            let offset = env.sign_extend(&imm, 12);
+            let address = {
+                let address_scratch = env.alloc_scratch();
+                let overflow_scratch = env.alloc_scratch();
+                let (address, _overflow) =
+                    unsafe { env.add_witness(&base, &offset, address_scratch, overflow_scratch) };
+                address
+            };
+            // FIXME: add a range check here for the address
+            let v0 = env.read_memory(&address);
+            let v1 = env.read_memory(&(address.clone() + Env::constant(1)));
+            let v2 = env.read_memory(&(address.clone() + Env::constant(2)));
+            let v3 = env.read_memory(&(address.clone() + Env::constant(3)));
+            // Compose the word little-endian, consistently with the instruction fetch above.
+            let value = (v3 * Env::constant(1 << 24))
+                + (v2 * Env::constant(1 << 16))
+                + (v1 * Env::constant(1 << 8))
+                + v0;
+            // A full 32-bit word is loaded; no sign extension is required.
+            env.write_register(&rd, value);
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        IInstruction::LoadByteUnsigned => {
+            // lbu: x[rd] = M[x[rs1] + sext(offset)][7:0]
+            let local_rs1 = env.read_register(&rs1);
+            let local_imm = env.sign_extend(&imm, 12);
+            let address = {
+                let address_scratch = env.alloc_scratch();
+                let overflow_scratch = env.alloc_scratch();
+                let (address, _overflow) = unsafe {
+                    env.add_witness(&local_rs1, &local_imm, address_scratch, overflow_scratch)
+                };
+                address
+            };
+            // FIXME: add a range check here for the address
+            let value = env.read_memory(&address);
+            env.write_register(&rd, value);
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        IInstruction::LoadHalfUnsigned => {
+            // lhu: x[rd] = M[x[rs1] + sext(offset)][15:0]
+            let local_rs1 = env.read_register(&rs1);
+            let local_imm = env.sign_extend(&imm, 12);
+            let address = {
+                let address_scratch = env.alloc_scratch();
+                let overflow_scratch = env.alloc_scratch();
+                let (address, _overflow) = unsafe {
+                    env.add_witness(&local_rs1, &local_imm, address_scratch, overflow_scratch)
+                };
+                address
+            };
+            // FIXME: add a range check here for the address
+            let v0 = env.read_memory(&address);
+            let v1 = env.read_memory(&(address.clone() + Env::constant(1)));
+            // Compose the halfword little-endian, consistently with the instruction fetch above.
+            let value = (v1 * Env::constant(1 << 8)) + v0;
+            env.write_register(&rd, value);
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        IInstruction::ShiftLeftLogicalImmediate => {
+            unimplemented!("ShiftLeftLogicalImmediate")
+        }
+        IInstruction::ShiftRightLogicalImmediate => {
+            unimplemented!("ShiftRightLogicalImmediate")
+        }
+        IInstruction::ShiftRightArithmeticImmediate => {
+            unimplemented!("ShiftRightArithmeticImmediate")
+        }
+        IInstruction::SetLessThanImmediate => {
+            unimplemented!("SetLessThanImmediate")
+        }
+        IInstruction::SetLessThanImmediateUnsigned => {
unimplemented!("SetLessThanImmediateUnsigned") + } + IInstruction::AddImmediate => { + // addi: x[rd] = x[rs1] + sext(immediate) + let local_rs1 = env.read_register(&(rs1.clone())); + let local_imm = env.sign_extend(&imm, 12); + let overflow_scratch = env.alloc_scratch(); + let rd_scratch = env.alloc_scratch(); + let local_rd = unsafe { + let (local_rd, _overflow) = + env.add_witness(&local_rs1, &local_imm, rd_scratch, overflow_scratch); + local_rd + }; + env.write_register(&rd, local_rd); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + IInstruction::XorImmediate => { + // xori: x[rd] = x[rs1] ^ sext(immediate) + let local_rs1 = env.read_register(&rs1); + let local_imm = env.sign_extend(&imm, 12); + let rd_scratch = env.alloc_scratch(); + let local_rd = unsafe { env.xor_witness(&local_rs1, &local_imm, rd_scratch) }; + env.write_register(&rd, local_rd); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + IInstruction::OrImmediate => { + // ori: x[rd] = x[rs1] | sext(immediate) + let local_rs1 = env.read_register(&rs1); + let local_imm = env.sign_extend(&imm, 12); + let rd_scratch = env.alloc_scratch(); + let local_rd = unsafe { env.or_witness(&local_rs1, &local_imm, rd_scratch) }; + env.write_register(&rd, local_rd); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + IInstruction::AndImmediate => { + // andi: x[rd] = x[rs1] & sext(immediate) + let local_rs1 = env.read_register(&rs1); + let local_imm = env.sign_extend(&imm, 12); + let rd_scratch = env.alloc_scratch(); + let local_rd = unsafe { env.and_witness(&local_rs1, &local_imm, rd_scratch) }; + env.write_register(&rd, local_rd); + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + IInstruction::JumpAndLinkRegister => { + let addr = env.read_register(&rs1); + // jalr: + // t = pc+4; + // pc = (x[rs1] + sext(offset)) & ∼1; <- NOT NOW + // pc = (x[rs1] + sext(offset)); <- PLEASE FIXME + // x[rd] = t + let offset = env.sign_extend(&imm, 12); + let new_addr = { + let res_scratch = env.alloc_scratch(); + let overflow_scratch = env.alloc_scratch(); + let (res, _overflow) = + unsafe { env.add_witness(&addr, &offset, res_scratch, overflow_scratch) }; + res + }; + env.write_register(&rd, next_instruction_pointer.clone()); + env.set_instruction_pointer(new_addr.clone()); + env.set_next_instruction_pointer(new_addr.clone() + Env::constant(4u32)); + } + }; +} + +/// Interpret an S-type instruction. 
+/// The encoding of an S-type instruction is as follows:
+/// ```text
+/// | 31 25 | 24 20 | 19 15 | 14 12 | 11 7 | 6 0 |
+/// | imm2 | rs2 | rs1 | funct3 | imm1 | opcode |
+/// ```
+/// Following the documentation found
+/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
+pub fn interpret_stype<Env: InterpreterEnv>(env: &mut Env, instr: SInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let _next_instruction_pointer = env.get_next_instruction_pointer();
+
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v3 * Env::constant(1 << 24))
+            + (v2 * Env::constant(1 << 16))
+            + (v1 * Env::constant(1 << 8))
+            + v0
+    };
+
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 7, 0, pos) }
+    };
+    env.range_check8(&opcode, 7);
+
+    let imm1 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 12, 7, pos) }
+    };
+    env.range_check8(&imm1, 5);
+
+    let funct3 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 15, 12, pos) }
+    };
+    env.range_check8(&funct3, 3);
+
+    let rs1 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 20, 15, pos) }
+    };
+    env.range_check8(&rs1, 5);
+
+    let rs2 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 25, 20, pos) }
+    };
+    env.range_check8(&rs2, 5);
+
+    let imm2 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 25, pos) }
+    };
+    env.range_check16(&imm2, 12);
+
+    // FIXME: check correctness of decomposition
+
+    match instr {
+        SInstruction::StoreByte => {
+            unimplemented!("StoreByte")
+        }
+        SInstruction::StoreHalf => {
+            unimplemented!("StoreHalf")
+        }
+        SInstruction::StoreWord => {
+            unimplemented!("StoreWord")
+        }
+    };
+}
+
+/// Interpret an SB-type instruction.
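+/// The 13-bit branch offset (always a multiple of 2) is scrambled across the
+/// instruction: bits 31-25 hold `imm[12|10:5]` and bits 11-7 hold `imm[4:1|11]`,
+/// which is why the immediate decoding below is marked as trickier.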
+/// The encoding of an SB-type instruction is as follows:
+/// ```text
+/// | 31 25 | 24 20 | 19 15 | 14 12 | 11 7 | 6 0 |
+/// | imm2 | rs2 | rs1 | funct3 | imm1 | opcode |
+/// ```
+/// Following the documentation found
+/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
+pub fn interpret_sbtype<Env: InterpreterEnv>(env: &mut Env, instr: SBInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let _next_instruction_pointer = env.get_next_instruction_pointer();
+
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v3 * Env::constant(1 << 24))
+            + (v2 * Env::constant(1 << 16))
+            + (v1 * Env::constant(1 << 8))
+            + v0
+    };
+
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 7, 0, pos) }
+    };
+    env.range_check8(&opcode, 7);
+
+    // FIXME: trickier
+    let imm1 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 12, 7, pos) }
+    };
+    env.range_check8(&imm1, 5);
+
+    let funct3 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 15, 12, pos) }
+    };
+    env.range_check8(&funct3, 3);
+
+    let rs1 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 20, 15, pos) }
+    };
+    env.range_check8(&rs1, 5);
+
+    let rs2 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 25, 20, pos) }
+    };
+    env.range_check8(&rs2, 5);
+
+    // FIXME: trickier
+    let imm2 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 25, pos) }
+    };
+    env.range_check16(&imm2, 12);
+
+    // FIXME: check correctness of decomposition
+
+    match instr {
+        SBInstruction::BranchEq => {
+            unimplemented!("BranchEq")
+        }
+        SBInstruction::BranchNeq => {
+            unimplemented!("BranchNeq")
+        }
+        SBInstruction::BranchLessThan => {
+            unimplemented!("BranchLessThan")
+        }
+        SBInstruction::BranchGreaterThanEqual => {
+            unimplemented!("BranchGreaterThanEqual")
+        }
+        SBInstruction::BranchLessThanUnsigned => {
+            unimplemented!("BranchLessThanUnsigned")
+        }
+        SBInstruction::BranchGreaterThanEqualUnsigned => {
+            unimplemented!("BranchGreaterThanEqualUnsigned")
+        }
+    };
+}
+
+/// Interpret a U-type instruction.
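+/// The U-type immediate occupies the upper 20 bits of the instruction: `lui`
+/// sets `x[rd] = imm << 12`, and `auipc` sets `x[rd] = pc + (imm << 12)`.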
+/// The encoding of a U-type instruction is as follows:
+/// ```text
+/// | 31 12 | 11 7 | 6 0 |
+/// | immediate | rd | opcode |
+/// ```
+/// Following the documentation found
+/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
+pub fn interpret_utype<Env: InterpreterEnv>(env: &mut Env, instr: UInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let _next_instruction_pointer = env.get_next_instruction_pointer();
+
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v3 * Env::constant(1 << 24))
+            + (v2 * Env::constant(1 << 16))
+            + (v1 * Env::constant(1 << 8))
+            + v0
+    };
+
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 7, 0, pos) }
+    };
+    env.range_check8(&opcode, 7);
+
+    let rd = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 12, 7, pos) }
+    };
+    env.range_check8(&rd, 5);
+
+    let imm = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 12, pos) }
+    };
+    // FIXME: range check the immediate
+
+    // Check correctness of the decomposition of the U-type instruction
+    env.add_constraint(
+        instruction
+            - (opcode.clone() * Env::constant(1 << 0)) // opcode at bits 0-6
+            - (rd.clone() * Env::constant(1 << 7)) // rd at bits 7-11
+            - (imm.clone() * Env::constant(1 << 12)), // imm at bits 12-31
+    );
+
+    match instr {
+        UInstruction::LoadUpperImmediate => {
+            unimplemented!("LoadUpperImmediate")
+        }
+        UInstruction::AddUpperImmediate => {
+            unimplemented!("AddUpperImmediate")
+        }
+    };
+}
+
+/// Interpret a UJ-type instruction.
+/// The encoding of a UJ-type instruction is as follows:
+/// ```text
+/// | 31 12 | 11 7 | 6 0 |
+/// | immediate | rd | opcode |
+/// ```
+/// Following the documentation found
+/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
+pub fn interpret_ujtype<Env: InterpreterEnv>(env: &mut Env, instr: UJInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let _next_instruction_pointer = env.get_next_instruction_pointer();
+
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v3 * Env::constant(1 << 24))
+            + (v2 * Env::constant(1 << 16))
+            + (v1 * Env::constant(1 << 8))
+            + v0
+    };
+
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 7, 0, pos) }
+    };
+    env.range_check8(&opcode, 7);
+
+    let rd = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 12, 7, pos) }
+    };
+    env.range_check8(&rd, 5);
+
+    // FIXME: trickier
+    let _imm = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 12, pos) }
+    };
+
+    // FIXME: check correctness of decomposition
+    match instr {
+        UJInstruction::JumpAndLink => {
+            unimplemented!("JumpAndLink")
+        }
+    };
+}
+
+pub fn interpret_syscall<Env: InterpreterEnv>(env: &mut Env, _instr: SyscallInstruction) {
+    // FIXME: check if it is a syscall success. There is only one syscall at the moment.
+    env.set_halted(Env::constant(1));
+}
+
+/// Interpret an M-type instruction.
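+/// In the RV32M encoding, `funct7` is fixed to `0000001` (split below into
+/// `funct5 = 00000` and `funct2 = 01`), and `funct3` selects the operation
+/// among mul/mulh/mulhsu/mulhu/div/divu/rem/remu.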
+/// The encoding of an M-type instruction is as follows:
+/// ```text
+/// | 31 27 | 26 25 | 24 20 | 19 15 | 14 12 | 11 7 | 6 0 |
+/// | 00000 | 01 | rs2 | rs1 | funct3 | rd | opcode |
+/// ```
+/// Following the documentation found
+/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
+pub fn interpret_mtype<Env: InterpreterEnv>(env: &mut Env, instr: MInstruction) {
+    let instruction_pointer = env.get_instruction_pointer();
+    let next_instruction_pointer = env.get_next_instruction_pointer();
+
+    let instruction = {
+        let v0 = env.read_memory(&instruction_pointer);
+        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
+        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
+        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
+        (v3 * Env::constant(1 << 24))
+            + (v2 * Env::constant(1 << 16))
+            + (v1 * Env::constant(1 << 8))
+            + v0
+    };
+
+    let opcode = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 7, 0, pos) }
+    };
+    env.range_check8(&opcode, 7);
+
+    let rd = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 12, 7, pos) }
+    };
+    env.range_check8(&rd, 5);
+
+    let funct3 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 15, 12, pos) }
+    };
+    env.range_check8(&funct3, 3);
+
+    let rs1 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 20, 15, pos) }
+    };
+    env.range_check8(&rs1, 5);
+
+    let rs2 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 25, 20, pos) }
+    };
+    env.range_check8(&rs2, 5);
+
+    let funct2 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 27, 25, pos) }
+    };
+    // FIXME: check it is equal to 01?
+    env.range_check8(&funct2, 2);
+
+    let funct5 = {
+        let pos = env.alloc_scratch();
+        unsafe { env.bitmask(&instruction, 32, 27, pos) }
+    };
+    // FIXME: check it is equal to 00000?
+ env.range_check8(&funct5, 5); + + // Check decomposition of M type instruction + env.add_constraint( + instruction + - (opcode.clone() * Env::constant(1 << 0)) // opcode at bits 0-6 + - (rd.clone() * Env::constant(1 << 7)) // rd at bits 7-11 + - (funct3.clone() * Env::constant(1 << 12)) // funct3 at bits 12-14 + - (rs1.clone() * Env::constant(1 << 15)) // rs1 at bits 15-19 + - (rs2.clone() * Env::constant(1 << 20)) // rs2 at bits 20-24 + - (funct2.clone() * Env::constant(1 << 25)) // funct2 at bits 25-26 + - (funct5.clone() * Env::constant(1 << 27)), // funct5 at bits 27-31 + ); + + match instr { + MInstruction::Mul => { + // x[rd] = x[rs1] * x[rs2] + let rs1 = env.read_register(&rs1); + let rs2 = env.read_register(&rs2); + // FIXME: constrain + let res = { + let pos = env.alloc_scratch(); + unsafe { env.mul_lo_signed(&rs1, &rs2, pos) } + }; + env.write_register(&rd, res); + + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + MInstruction::Mulh => { + // x[rd] = (signed(x[rs1]) * signed(x[rs2])) >> 32 + let rs1 = env.read_register(&rs1); + let rs2 = env.read_register(&rs2); + // FIXME: constrain + let res = { + let pos = env.alloc_scratch(); + unsafe { env.mul_hi_signed(&rs1, &rs2, pos) } + }; + env.write_register(&rd, res); + + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + MInstruction::Mulhsu => { + // x[rd] = (signed(x[rs1]) * x[rs2]) >> 32 + let rs1 = env.read_register(&rs1); + let rs2 = env.read_register(&rs2); + // FIXME: constrain + let res = { + let pos = env.alloc_scratch(); + unsafe { env.mul_hi_signed_unsigned(&rs1, &rs2, pos) } + }; + env.write_register(&rd, res); + + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + MInstruction::Mulhu => { + // x[rd] = (x[rs1] * x[rs2]) >> 32 + let rs1 = env.read_register(&rs1); + let rs2 = env.read_register(&rs2); + // FIXME: constrain + let res = { + let pos = env.alloc_scratch(); + unsafe { env.mul_hi(&rs1, &rs2, pos) } + }; + env.write_register(&rd, res); + + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + MInstruction::Div => { + // x[rd] = signed(x[rs1]) / signed(x[rs2]) + let rs1 = env.read_register(&rs1); + let rs2 = env.read_register(&rs2); + // FIXME: constrain + let res = { + let pos = env.alloc_scratch(); + unsafe { env.div_signed(&rs1, &rs2, pos) } + }; + env.write_register(&rd, res); + + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + MInstruction::Divu => { + // x[rd] = x[rs1] / x[rs2] + let rs1 = env.read_register(&rs1); + let rs2 = env.read_register(&rs2); + // FIXME: constrain + let res = { + let pos = env.alloc_scratch(); + unsafe { env.div(&rs1, &rs2, pos) } + }; + env.write_register(&rd, res); + + env.set_instruction_pointer(next_instruction_pointer.clone()); + env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32)); + } + MInstruction::Rem => { + // x[rd] = signed(x[rs1]) % signed(x[rs2]) + let rs1 = env.read_register(&rs1); + let rs2 = env.read_register(&rs2); + // FIXME: constrain + let res = { + let pos = env.alloc_scratch(); + unsafe { env.mod_signed(&rs1, &rs2, pos) } + }; + 
+            env.write_register(&rd, res);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+        MInstruction::Remu => {
+            // x[rd] = x[rs1] % x[rs2]
+            let rs1 = env.read_register(&rs1);
+            let rs2 = env.read_register(&rs2);
+            // FIXME: constrain
+            let res = {
+                let pos = env.alloc_scratch();
+                unsafe { env.mod_unsigned(&rs1, &rs2, pos) }
+            };
+            env.write_register(&rd, res);
+
+            env.set_instruction_pointer(next_instruction_pointer.clone());
+            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
+        }
+    }
+}
diff --git a/o1vm/src/interpreters/riscv32im/mod.rs b/o1vm/src/interpreters/riscv32im/mod.rs
new file mode 100644
index 0000000000..59c855694b
--- /dev/null
+++ b/o1vm/src/interpreters/riscv32im/mod.rs
@@ -0,0 +1,26 @@
+/// The minimal number of columns required for the VM
+// FIXME: the value will be updated when the interpreter is fully
+// implemented. Using a small value for now.
+pub const SCRATCH_SIZE: usize = 80;
+
+/// Number of instructions in the ISA
+pub const INSTRUCTION_SET_SIZE: usize = 48;
+
+pub const PAGE_ADDRESS_SIZE: u32 = 12;
+pub const PAGE_SIZE: u32 = 1 << PAGE_ADDRESS_SIZE;
+pub const PAGE_ADDRESS_MASK: u32 = PAGE_SIZE - 1;
+
+/// List all columns used by the interpreter
+pub mod column;
+
+pub mod constraints;
+
+pub mod interpreter;
+
+/// All the registers used by the ISA
+pub mod registers;
+
+pub mod witness;
+
+#[cfg(test)]
+mod tests;
diff --git a/o1vm/src/interpreters/riscv32im/registers.rs b/o1vm/src/interpreters/riscv32im/registers.rs
new file mode 100644
index 0000000000..ca5906bd54
--- /dev/null
+++ b/o1vm/src/interpreters/riscv32im/registers.rs
@@ -0,0 +1,66 @@
+use std::ops::{Index, IndexMut};
+
+use serde::{Deserialize, Serialize};
+
+pub const N_GP_REGISTERS: usize = 32;
+// FIXME:
+pub const REGISTER_CURRENT_IP: usize = N_GP_REGISTERS + 1;
+pub const REGISTER_NEXT_IP: usize = N_GP_REGISTERS + 2;
+pub const REGISTER_HEAP_POINTER: usize = N_GP_REGISTERS + 3;
+
+/// This represents the internal state of the virtual machine.
+#[derive(Clone, Default, Debug, Serialize, Deserialize)]
+pub struct Registers<T> {
+    /// There are 32 general purpose registers.
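+    /// Their conventional (RISC-V ABI) uses are: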
+    /// - x0: hard-wired zero
+    /// - x1: return address
+    /// - x2: stack pointer
+    /// - x3: global pointer
+    /// - x4: thread pointer
+    /// - x5: temporary/alternate register
+    /// - x6-x7: temporaries
+    /// - x8: saved register/frame pointer
+    /// - x9: saved register
+    /// - x10-x11: function arguments/results
+    /// - x12-x17: function arguments
+    /// - x18-x27: saved registers
+    /// - x28-x31: temporaries
+    pub general_purpose: [T; N_GP_REGISTERS],
+    pub current_instruction_pointer: T,
+    pub next_instruction_pointer: T,
+    pub heap_pointer: T,
+}
+
+impl<T> Index<usize> for Registers<T> {
+    type Output = T;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        if index < N_GP_REGISTERS {
+            &self.general_purpose[index]
+        } else if index == REGISTER_CURRENT_IP {
+            &self.current_instruction_pointer
+        } else if index == REGISTER_NEXT_IP {
+            &self.next_instruction_pointer
+        } else if index == REGISTER_HEAP_POINTER {
+            &self.heap_pointer
+        } else {
+            panic!("Index out of bounds");
+        }
+    }
+}
+
+impl<T> IndexMut<usize> for Registers<T> {
+    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+        if index < N_GP_REGISTERS {
+            &mut self.general_purpose[index]
+        } else if index == REGISTER_CURRENT_IP {
+            &mut self.current_instruction_pointer
+        } else if index == REGISTER_NEXT_IP {
+            &mut self.next_instruction_pointer
+        } else if index == REGISTER_HEAP_POINTER {
+            &mut self.heap_pointer
+        } else {
+            panic!("Index out of bounds");
+        }
+    }
+}
diff --git a/o1vm/src/interpreters/riscv32im/tests.rs b/o1vm/src/interpreters/riscv32im/tests.rs
new file mode 100644
index 0000000000..e382dee701
--- /dev/null
+++ b/o1vm/src/interpreters/riscv32im/tests.rs
@@ -0,0 +1,421 @@
+use super::{registers::Registers, witness::Env, INSTRUCTION_SET_SIZE, PAGE_SIZE, SCRATCH_SIZE};
+use crate::interpreters::riscv32im::{
+    constraints,
+    interpreter::{
+        IInstruction, Instruction, MInstruction, RInstruction, SBInstruction, SInstruction,
+        SyscallInstruction, UInstruction, UJInstruction,
+    },
+};
+use ark_ff::Zero;
+use mina_curves::pasta::Fp;
+use rand::{CryptoRng, Rng, RngCore};
+use strum::EnumCount;
+
+// Sanity check that we have as many selectors as we have instructions
+#[test]
+fn test_regression_selectors_for_instructions() {
+    let mips_con_env = constraints::Env::<Fp>::default();
+    let constraints = mips_con_env.get_selector_constraints();
+    assert_eq!(
+        // We subtract 1 as we have one boolean check per selector
+        // and one constraint checking that exactly one
+        // selector is activated
+        constraints.len() - 1,
+        // This should match the list in
+        // crate::interpreters::riscv32im::interpreter::Instruction
+        RInstruction::COUNT
+            + IInstruction::COUNT
+            + SInstruction::COUNT
+            + SBInstruction::COUNT
+            + UInstruction::COUNT
+            + UJInstruction::COUNT
+            + SyscallInstruction::COUNT
+            + MInstruction::COUNT
+    );
+    // All instructions are degree 1 or 2.
+    constraints
+        .iter()
+        .for_each(|c| assert!(c.degree(1, 0) == 2 || c.degree(1, 0) == 1));
+}
+
+pub fn dummy_env() -> Env<Fp> {
+    Env {
+        instruction_counter: 0,
+        memory: vec![(0, vec![0; PAGE_SIZE.try_into().unwrap()])],
+        last_memory_accesses: [0; 3],
+        memory_write_index: vec![(0, vec![0; PAGE_SIZE.try_into().unwrap()])],
+        last_memory_write_index_accesses: [0; 3],
+        registers: Registers::default(),
+        registers_write_index: Registers::default(),
+        scratch_state_idx: 0,
+        scratch_state: [Fp::zero(); SCRATCH_SIZE],
+        halt: false,
+        selector: INSTRUCTION_SET_SIZE,
+    }
+}
+
+pub fn generate_random_add_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b000;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b00000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_sub_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b000;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b01000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_sll_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b001;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b00000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_slt_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b010;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b00000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_xor_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b100;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b00000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_sltu_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b011;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b00000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_srl_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b101;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b00000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_sra_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b101;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b01000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_or_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b110;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b00000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+pub fn generate_random_and_instruction<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> [u8; 4] {
+    let opcode = 0b0110011;
+    let rd = rng.gen_range(0..32);
+    let funct3 = 0b111;
+    let rs1 = rng.gen_range(0..32);
+    let rs2 = rng.gen_range(0..32);
+    let funct2 = 0b00;
+    let funct5 = 0b00000;
+    let instruction = opcode
+        | (rd << 7)
+        | (funct3 << 12)
+        | (rs1 << 15)
+        | (rs2 << 20)
+        | (funct2 << 25)
+        | (funct5 << 27);
+    [
+        instruction as u8,
+        (instruction >> 8) as u8,
+        (instruction >> 16) as u8,
+        (instruction >> 24) as u8,
+    ]
+}
+
+#[test]
+pub fn test_instruction_decoding_add() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_add_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(opcode, Instruction::RType(RInstruction::Add));
+}
+
+#[test]
+pub fn test_instruction_decoding_sub() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_sub_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(opcode, Instruction::RType(RInstruction::Sub));
+}
+
+#[test]
+pub fn test_instruction_decoding_sll() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_sll_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(opcode, Instruction::RType(RInstruction::ShiftLeftLogical));
+}
+
+#[test]
+pub fn test_instruction_decoding_slt() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_slt_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(opcode, Instruction::RType(RInstruction::SetLessThan));
+}
+
+#[test]
+pub fn test_instruction_decoding_sltu() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_sltu_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(
+        opcode,
+        Instruction::RType(RInstruction::SetLessThanUnsigned)
+    );
+}
+
+#[test]
+pub fn test_instruction_decoding_xor() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_xor_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(opcode, Instruction::RType(RInstruction::Xor));
+}
+
+#[test]
+pub fn test_instruction_decoding_srl() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_srl_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(opcode, Instruction::RType(RInstruction::ShiftRightLogical));
+}
+
+#[test]
+pub fn test_instruction_decoding_sra() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_sra_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(
+        opcode,
+        Instruction::RType(RInstruction::ShiftRightArithmetic)
+    );
+}
+
+#[test]
+pub fn test_instruction_decoding_or() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_or_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(opcode, Instruction::RType(RInstruction::Or));
+}
+
+#[test]
+pub fn test_instruction_decoding_and() {
+    let mut env: Env<Fp> = dummy_env();
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let instruction = generate_random_and_instruction(&mut rng);
+    env.memory[0].1[0] = instruction[0];
+    env.memory[0].1[1] = instruction[1];
+    env.memory[0].1[2] = instruction[2];
+    env.memory[0].1[3] = instruction[3];
+    let (opcode, _instruction) = env.decode_instruction();
+    assert_eq!(opcode, Instruction::RType(RInstruction::And));
+}
diff --git a/o1vm/src/interpreters/riscv32im/witness.rs b/o1vm/src/interpreters/riscv32im/witness.rs
new file mode 100644
index 0000000000..e9a976ca13
--- /dev/null
+++ b/o1vm/src/interpreters/riscv32im/witness.rs
@@ -0,0 +1,901 @@
+// TODO: do we want to be more restrictive and refer to the number of accesses
+// to the SAME register/memory address?
+use super::{
+    column::Column,
+    interpreter::{
+        self, IInstruction, Instruction, InterpreterEnv, RInstruction, SBInstruction, SInstruction,
+        SyscallInstruction, UInstruction, UJInstruction,
+    },
+    registers::Registers,
+    INSTRUCTION_SET_SIZE, SCRATCH_SIZE,
+};
+use crate::{
+    cannon::{State, PAGE_ADDRESS_MASK, PAGE_ADDRESS_SIZE, PAGE_SIZE},
+    lookups::Lookup,
+};
+use ark_ff::Field;
+use std::array;
+
+/// Maximum number of register accesses per instruction (based on demo)
+// FIXME: can be different
+pub const MAX_NB_REG_ACC: u64 = 7;
+/// Maximum number of memory accesses per instruction (based on demo)
+// FIXME: can be different
+pub const MAX_NB_MEM_ACC: u64 = 12;
+/// Maximum number of memory or register accesses per instruction
+pub const MAX_ACC: u64 = MAX_NB_REG_ACC + MAX_NB_MEM_ACC;
+
+pub const NUM_GLOBAL_LOOKUP_TERMS: usize = 1;
+pub const NUM_DECODING_LOOKUP_TERMS: usize = 2;
+pub const NUM_INSTRUCTION_LOOKUP_TERMS: usize = 5;
+pub const NUM_LOOKUP_TERMS: usize =
+    NUM_GLOBAL_LOOKUP_TERMS + NUM_DECODING_LOOKUP_TERMS + NUM_INSTRUCTION_LOOKUP_TERMS;
+
+/// This structure represents the environment the virtual machine state will use
+/// to transition. This environment will be used by the interpreter. The virtual
+/// machine has access to its internal state and some external memory.
+pub struct Env<Fp> {
+    pub instruction_counter: u64,
+    pub memory: Vec<(u32, Vec<u8>)>,
+    pub last_memory_accesses: [usize; 3],
+    pub memory_write_index: Vec<(u32, Vec<u64>)>,
+    pub last_memory_write_index_accesses: [usize; 3],
+    pub registers: Registers<u32>,
+    pub registers_write_index: Registers<u64>,
+    pub scratch_state_idx: usize,
+    pub scratch_state: [Fp; SCRATCH_SIZE],
+    pub halt: bool,
+    pub selector: usize,
+}
+
+fn fresh_scratch_state<Fp: Field, const N: usize>() -> [Fp; N] {
+    array::from_fn(|_| Fp::zero())
+}
+
+impl<Fp: Field> InterpreterEnv for Env<Fp> {
+    type Position = Column;
+
+    fn alloc_scratch(&mut self) -> Self::Position {
+        let scratch_idx = self.scratch_state_idx;
+        self.scratch_state_idx += 1;
+        Column::ScratchState(scratch_idx)
+    }
+
+    type Variable = u64;
+
+    fn variable(&self, _column: Self::Position) -> Self::Variable {
+        todo!()
+    }
+
+    fn add_constraint(&mut self, _assert_equals_zero: Self::Variable) {
+        // No-op for witness
+        // Do not assert that _assert_equals_zero is zero here!
+        // Some variables may have placeholders that do not faithfully
+        // represent the underlying values.
+    }
+
+    fn activate_selector(&mut self, instruction: Instruction) {
+        self.selector = instruction.into();
+    }
+
+    fn check_is_zero(assert_equals_zero: &Self::Variable) {
+        assert_eq!(*assert_equals_zero, 0);
+    }
+
+    fn check_equal(x: &Self::Variable, y: &Self::Variable) {
+        assert_eq!(*x, *y);
+    }
+
+    fn check_boolean(x: &Self::Variable) {
+        if !(*x == 0 || *x == 1) {
+            panic!("The value {} is not a boolean", *x);
+        }
+    }
+
+    fn add_lookup(&mut self, _lookup: Lookup<Self::Variable>) {
+        // No-op, constraints only
+        // TODO: keep track of multiplicities of fixed tables here as in Keccak?
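+        // A sketch of what such tracking could look like, assuming a
+        // hypothetical `lookup_multiplicities: HashMap<LookupTableIDs, u64>`
+        // field (named here for illustration only):
+        //   *self.lookup_multiplicities.entry(_lookup.table_id).or_insert(0) += 1;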
+ } + + fn instruction_counter(&self) -> Self::Variable { + self.instruction_counter + } + + fn increase_instruction_counter(&mut self) { + self.instruction_counter += 1; + } + + unsafe fn fetch_register( + &mut self, + idx: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + let res = self.registers[*idx as usize] as u64; + self.write_column(output, res); + res + } + + unsafe fn push_register_if( + &mut self, + idx: &Self::Variable, + value: Self::Variable, + if_is_true: &Self::Variable, + ) { + let value: u32 = value.try_into().unwrap(); + if *if_is_true == 1 { + self.registers[*idx as usize] = value + } else if *if_is_true == 0 { + // No-op + } else { + panic!("Bad value for flag in push_register: {}", *if_is_true); + } + } + + unsafe fn fetch_register_access( + &mut self, + idx: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + let res = self.registers_write_index[*idx as usize]; + self.write_column(output, res); + res + } + + unsafe fn push_register_access_if( + &mut self, + idx: &Self::Variable, + value: Self::Variable, + if_is_true: &Self::Variable, + ) { + if *if_is_true == 1 { + self.registers_write_index[*idx as usize] = value + } else if *if_is_true == 0 { + // No-op + } else { + panic!("Bad value for flag in push_register: {}", *if_is_true); + } + } + + unsafe fn fetch_memory( + &mut self, + addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + let addr: u32 = (*addr).try_into().unwrap(); + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_page_idx = self.get_memory_page_index(page); + let value = self.memory[memory_page_idx].1[page_address]; + self.write_column(output, value.into()); + value.into() + } + + unsafe fn push_memory(&mut self, addr: &Self::Variable, value: Self::Variable) { + let addr: u32 = (*addr).try_into().unwrap(); + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_page_idx = self.get_memory_page_index(page); + self.memory[memory_page_idx].1[page_address] = + value.try_into().expect("push_memory values fit in a u8"); + } + + unsafe fn fetch_memory_access( + &mut self, + addr: &Self::Variable, + output: Self::Position, + ) -> Self::Variable { + let addr: u32 = (*addr).try_into().unwrap(); + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_write_index_page_idx = self.get_memory_access_page_index(page); + let value = self.memory_write_index[memory_write_index_page_idx].1[page_address]; + self.write_column(output, value); + value + } + + unsafe fn push_memory_access(&mut self, addr: &Self::Variable, value: Self::Variable) { + let addr = *addr as u32; + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_write_index_page_idx = self.get_memory_access_page_index(page); + self.memory_write_index[memory_write_index_page_idx].1[page_address] = value; + } + + fn constant(x: u32) -> Self::Variable { + x as u64 + } + + unsafe fn bitmask( + &mut self, + x: &Self::Variable, + highest_bit: u32, + lowest_bit: u32, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let res = (x >> lowest_bit) & ((1 << (highest_bit - lowest_bit)) - 1); + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn shift_left( + &mut self, + x: &Self::Variable, + by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: 
u32 = (*x).try_into().unwrap(); + let by: u32 = (*by).try_into().unwrap(); + let res = x << by; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn shift_right( + &mut self, + x: &Self::Variable, + by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let by: u32 = (*by).try_into().unwrap(); + let res = x >> by; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn shift_right_arithmetic( + &mut self, + x: &Self::Variable, + by: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let by: u32 = (*by).try_into().unwrap(); + let res = ((x as i32) >> by) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn test_zero(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable { + let res = if *x == 0 { 1 } else { 0 }; + self.write_column(position, res); + res + } + + unsafe fn inverse_or_zero( + &mut self, + x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + if *x == 0 { + self.write_column(position, 0); + 0 + } else { + self.write_field_column(position, Fp::from(*x).inverse().unwrap()); + 1 // Placeholder value + } + } + + fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable { + // write the result + let pos = self.alloc_scratch(); + let res = if *x == 0 { 1 } else { 0 }; + self.write_column(pos, res); + // write the non deterministic advice inv_or_zero + let pos = self.alloc_scratch(); + let inv_or_zero = if *x == 0 { + Fp::zero() + } else { + Fp::inverse(&Fp::from(*x)).unwrap() + }; + self.write_field_column(pos, inv_or_zero); + // return the result + res + } + + fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable { + // To avoid subtraction overflow in the witness interpreter for u32 + if x > y { + self.is_zero(&(*x - *y)) + } else { + self.is_zero(&(*y - *x)) + } + } + + unsafe fn test_less_than( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = if x < y { 1 } else { 0 }; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn test_less_than_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = if (x as i32) < (y as i32) { 1 } else { 0 }; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn and_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = x & y; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn nor_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = !(x | y); + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn or_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = x | y; + let res = res as u64; + 
self.write_column(position, res); + res + } + + unsafe fn xor_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = x ^ y; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn add_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + out_position: Self::Position, + overflow_position: Self::Position, + ) -> (Self::Variable, Self::Variable) { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + // https://doc.rust-lang.org/std/primitive.u32.html#method.overflowing_add + let res = x.overflowing_add(y); + let (res_, overflow) = (res.0 as u64, res.1 as u64); + self.write_column(out_position, res_); + self.write_column(overflow_position, overflow); + (res_, overflow) + } + + unsafe fn sub_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + out_position: Self::Position, + underflow_position: Self::Position, + ) -> (Self::Variable, Self::Variable) { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + // https://doc.rust-lang.org/std/primitive.u32.html#method.overflowing_sub + let res = x.overflowing_sub(y); + let (res_, underflow) = (res.0 as u64, res.1 as u64); + self.write_column(out_position, res_); + self.write_column(underflow_position, underflow); + (res_, underflow) + } + + unsafe fn mul_signed_witness( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = ((x as i32) * (y as i32)) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn mul_hi_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: i32 = (*x).try_into().unwrap(); + let y: i32 = (*y).try_into().unwrap(); + let res = (x as i64) * (y as i64); + let res = (res >> 32) as i32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn mul_lo_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: i32 = (*x).try_into().unwrap(); + let y: i32 = (*y).try_into().unwrap(); + let res = ((x as i64) * (y as i64)) as u64; + let res = (res & ((1 << 32) - 1)) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn mul_hi( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = (x as u64) * (y as u64); + let res = (res >> 32) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn mul_hi_signed_unsigned( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = (((x as i32) as i64) * (y as i64)) as u64; + let res = (res >> 32) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn div_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: i32 = (*x).try_into().unwrap(); + let y: i32 = (*y).try_into().unwrap(); + let res = (x / y) as u32; + let res = res as 
u64; + self.write_column(position, res); + res + } + + unsafe fn mul_lo( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = (x as u64) * (y as u64); + let res = (res & ((1 << 32) - 1)) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn mod_signed( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: i32 = (*x).try_into().unwrap(); + let y: i32 = (*y).try_into().unwrap(); + let res = (x % y) as u32; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn div( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = x / y; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn mod_unsigned( + &mut self, + x: &Self::Variable, + y: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let y: u32 = (*y).try_into().unwrap(); + let res = x % y; + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn count_leading_zeros( + &mut self, + x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let res = x.leading_zeros(); + let res = res as u64; + self.write_column(position, res); + res + } + + unsafe fn count_leading_ones( + &mut self, + x: &Self::Variable, + position: Self::Position, + ) -> Self::Variable { + let x: u32 = (*x).try_into().unwrap(); + let res = x.leading_ones(); + let res = res as u64; + self.write_column(position, res); + res + } + + fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable { + self.write_column(position, *x); + *x + } + + fn set_halted(&mut self, flag: Self::Variable) { + if flag == 0 { + self.halt = false + } else if flag == 1 { + self.halt = true + } else { + panic!("Bad value for flag in set_halted: {}", flag); + } + } + + fn report_exit(&mut self, exit_code: &Self::Variable) { + println!( + "Exited with code {} at step {}", + *exit_code, + self.normalized_instruction_counter() + ); + } + + fn reset(&mut self) { + self.scratch_state_idx = 0; + self.scratch_state = fresh_scratch_state(); + self.selector = INSTRUCTION_SET_SIZE; + } +} + +impl Env { + pub fn create(page_size: usize, state: State) -> Self { + let initial_instruction_pointer = state.pc; + let next_instruction_pointer = state.next_pc; + + let selector = INSTRUCTION_SET_SIZE; + + let mut initial_memory: Vec<(u32, Vec)> = state + .memory + .into_iter() + // Check that the conversion from page data is correct + .map(|page| (page.index, page.data)) + .collect(); + + for (_address, initial_memory) in initial_memory.iter_mut() { + initial_memory.extend((0..(page_size - initial_memory.len())).map(|_| 0u8)); + assert_eq!(initial_memory.len(), page_size); + } + + let memory_offsets = initial_memory + .iter() + .map(|(offset, _)| *offset) + .collect::>(); + + let initial_registers = { + Registers { + general_purpose: state.registers, + current_instruction_pointer: initial_instruction_pointer, + next_instruction_pointer, + heap_pointer: state.heap, + } + }; + + let mut registers = initial_registers.clone(); + registers[2] = 0x408004f0; + // set the stack pointer to the top of the stack + + Env 
{ + instruction_counter: state.step, + memory: initial_memory.clone(), + last_memory_accesses: [0usize; 3], + memory_write_index: memory_offsets + .iter() + .map(|offset| (*offset, vec![0u64; page_size])) + .collect(), + last_memory_write_index_accesses: [0usize; 3], + registers, + registers_write_index: Registers::default(), + scratch_state_idx: 0, + scratch_state: fresh_scratch_state(), + halt: state.exited, + selector, + } + } + + pub fn next_instruction_counter(&self) -> u64 { + (self.normalized_instruction_counter() + 1) * MAX_ACC + } + + pub fn decode_instruction(&mut self) -> (Instruction, u32) { + /* https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf */ + let instruction = + ((self.get_memory_direct(self.registers.current_instruction_pointer) as u32) << 24) + | ((self.get_memory_direct(self.registers.current_instruction_pointer + 1) as u32) + << 16) + | ((self.get_memory_direct(self.registers.current_instruction_pointer + 2) as u32) + << 8) + | (self.get_memory_direct(self.registers.current_instruction_pointer + 3) as u32); + let instruction = instruction.to_be(); // convert to big endian for more straightforward decoding + let opcode = { + match instruction & 0b1111111 // bits 0-6 + { + 0b0110111 => Instruction::UType(UInstruction::LoadUpperImmediate), + 0b0010111 => Instruction::UType(UInstruction::AddUpperImmediate), + 0b1101111 => Instruction::UJType(UJInstruction::JumpAndLink), + 0b1100011 => + match (instruction >> 12) & 0x7 // bits 12-14 for func3 + { + 0b000 => Instruction::SBType(SBInstruction::BranchEq), + 0b001 => Instruction::SBType(SBInstruction::BranchNeq), + 0b100 => Instruction::SBType(SBInstruction::BranchLessThan), + 0b101 => Instruction::SBType(SBInstruction::BranchGreaterThanEqual), + 0b110 => Instruction::SBType(SBInstruction::BranchLessThanUnsigned), + 0b111 => Instruction::SBType(SBInstruction::BranchGreaterThanEqualUnsigned), + _ => panic!("Unknown SBType instruction with full inst {}", instruction), + }, + 0b1100111 => Instruction::IType(IInstruction::JumpAndLinkRegister), + 0b0000011 => + match (instruction >> 12) & 0x7 // bits 12-14 for func3 + { + 0b000 => Instruction::IType(IInstruction::LoadByte), + 0b001 => Instruction::IType(IInstruction::LoadHalf), + 0b010 => Instruction::IType(IInstruction::LoadWord), + 0b100 => Instruction::IType(IInstruction::LoadByteUnsigned), + 0b101 => Instruction::IType(IInstruction::LoadHalfUnsigned), + _ => panic!("Unknown IType instruction with full inst {}", instruction), + }, + 0b0100011 => + match (instruction >> 12) & 0x7 // bits 12-14 for func3 + { + 0b000 => Instruction::SType(SInstruction::StoreByte), + 0b001 => Instruction::SType(SInstruction::StoreHalf), + 0b010 => Instruction::SType(SInstruction::StoreWord), + _ => panic!("Unknown SType instruction with full inst {}", instruction), + }, + 0b0010011 => + match (instruction >> 12) & 0x7 // bits 12-14 for func3 + { + 0b000 => Instruction::IType(IInstruction::AddImmediate), + 0b010 => Instruction::IType(IInstruction::SetLessThanImmediate), + 0b011 => Instruction::IType(IInstruction::SetLessThanImmediateUnsigned), + 0b100 => Instruction::IType(IInstruction::XorImmediate), + 0b110 => Instruction::IType(IInstruction::OrImmediate), + 0b111 => Instruction::IType(IInstruction::AndImmediate), + 0b001 => Instruction::IType(IInstruction::ShiftLeftLogicalImmediate), + 0b101 => + match (instruction >> 30) & 0x1 // bit 30 in simm component of IType + { + 0b0 => Instruction::IType(IInstruction::ShiftRightLogicalImmediate), + 0b1 => 
Instruction::IType(IInstruction::ShiftRightArithmeticImmediate), + _ => panic!("Unknown IType in shift right instructions with full inst {}", instruction), + }, + _ => panic!("Unknown IType instruction with full inst {}", instruction), + }, + 0b0110011 => + match (instruction >> 12) & 0x7 // bits 12-14 for func3 + { + 0b000 => + match (instruction >> 30) & 0x1 // bit 30 of funct5 component in RType + { + 0b0 => Instruction::RType(RInstruction::Add), + 0b1 => Instruction::RType(RInstruction::Sub), + _ => panic!("Unknown RType in add/sub instructions with full inst {}", instruction), + }, + 0b001 => Instruction::RType(RInstruction::ShiftLeftLogical), + 0b010 => Instruction::RType(RInstruction::SetLessThan), + 0b011 => Instruction::RType(RInstruction::SetLessThanUnsigned), + 0b100 => Instruction::RType(RInstruction::Xor), + 0b101 => + match (instruction >> 30) & 0x1 // bit 30 of funct5 component in RType + { + 0b0 => Instruction::RType(RInstruction::ShiftRightLogical), + 0b1 => Instruction::RType(RInstruction::ShiftRightArithmetic), + _ => panic!("Unknown RType in shift right instructions with full inst {}", instruction), + }, + 0b110 => Instruction::RType(RInstruction::Or), + 0b111 => Instruction::RType(RInstruction::And), + _ => panic!("Unknown RType 0110011 instruction with full inst {}", instruction), + }, + 0b0001111 => + match (instruction >> 12) & 0x7 // bits 12-14 for func3 + { + 0b000 => Instruction::RType(RInstruction::Fence), + 0b001 => Instruction::RType(RInstruction::FenceI), + _ => panic!("Unknown RType 0001111 (Fence) instruction with full inst {}", instruction), + }, + // FIXME: we should implement more syscalls here, and check the register state. + // Even better, only one constructor call ecall, and in the + // interpreter, we do the action depending on it + 0b1110011 => Instruction::SyscallType(SyscallInstruction::SyscallSuccess), + _ => panic!("Unknown instruction with full inst {:b}, and opcode {:b}", instruction, instruction & 0b1111111), + } + }; + (opcode, instruction) + } + + /// Execute a single step in the RISCV32i program + pub fn step(&mut self) -> Instruction { + self.reset_scratch_state(); + let (opcode, _instruction) = self.decode_instruction(); + + interpreter::interpret_instruction(self, opcode); + + self.instruction_counter = self.next_instruction_counter(); + + // Integer division by MAX_ACC to obtain the actual instruction count + if self.halt { + println!( + "Halted at step={} instruction={:?}", + self.normalized_instruction_counter(), + opcode + ); + } + opcode + } + + pub fn reset_scratch_state(&mut self) { + self.scratch_state_idx = 0; + self.scratch_state = fresh_scratch_state(); + self.selector = INSTRUCTION_SET_SIZE; + } + + pub fn write_column(&mut self, column: Column, value: u64) { + self.write_field_column(column, value.into()) + } + + pub fn write_field_column(&mut self, column: Column, value: Fp) { + match column { + Column::ScratchState(idx) => self.scratch_state[idx] = value, + Column::InstructionCounter => panic!("Cannot overwrite the column {:?}", column), + Column::Selector(s) => self.selector = s, + } + } + + pub fn update_last_memory_access(&mut self, i: usize) { + let [i_0, i_1, _] = self.last_memory_accesses; + self.last_memory_accesses = [i, i_0, i_1] + } + + pub fn get_memory_page_index(&mut self, page: u32) -> usize { + for &i in self.last_memory_accesses.iter() { + if self.memory_write_index[i].0 == page { + return i; + } + } + for (i, (page_index, _memory)) in self.memory.iter_mut().enumerate() { + if *page_index == page { + 
self.update_last_memory_access(i); + return i; + } + } + + // Memory not found; dynamically allocate + let memory = vec![0u8; PAGE_SIZE as usize]; + self.memory.push((page, memory)); + let i = self.memory.len() - 1; + self.update_last_memory_access(i); + i + } + + pub fn update_last_memory_write_index_access(&mut self, i: usize) { + let [i_0, i_1, _] = self.last_memory_write_index_accesses; + self.last_memory_write_index_accesses = [i, i_0, i_1] + } + + pub fn get_memory_access_page_index(&mut self, page: u32) -> usize { + for &i in self.last_memory_write_index_accesses.iter() { + if self.memory_write_index[i].0 == page { + return i; + } + } + for (i, (page_index, _memory_write_index)) in self.memory_write_index.iter_mut().enumerate() + { + if *page_index == page { + self.update_last_memory_write_index_access(i); + return i; + } + } + + // Memory not found; dynamically allocate + let memory_write_index = vec![0u64; PAGE_SIZE as usize]; + self.memory_write_index.push((page, memory_write_index)); + let i = self.memory_write_index.len() - 1; + self.update_last_memory_write_index_access(i); + i + } + + pub fn get_memory_direct(&mut self, addr: u32) -> u8 { + let page = addr >> PAGE_ADDRESS_SIZE; + let page_address = (addr & PAGE_ADDRESS_MASK) as usize; + let memory_idx = self.get_memory_page_index(page); + self.memory[memory_idx].1[page_address] + } + + /// The actual number of instructions executed results from dividing the + /// instruction counter by MAX_ACC (floor). + /// + /// NOTE: actually, in practice it will be less than that, as there is no + /// single instruction that performs all of them. + pub fn normalized_instruction_counter(&self) -> u64 { + self.instruction_counter / MAX_ACC + } +} diff --git a/o1vm/src/legacy/folding.rs b/o1vm/src/legacy/folding.rs new file mode 100644 index 0000000000..6062077e94 --- /dev/null +++ b/o1vm/src/legacy/folding.rs @@ -0,0 +1,657 @@ +use ark_ec::AffineRepr; +use ark_ff::FftField; +use ark_poly::{Evaluations, Radix2EvaluationDomain}; +use folding::{ + instance_witness::Foldable, Alphas, FoldingConfig, FoldingEnv, Instance, Side, Witness, +}; +use kimchi::circuits::{berkeley_columns::BerkeleyChallengeTerm, gate::CurrOrNext}; +use kimchi_msm::witness::Witness as GenericWitness; +use poly_commitment::commitment::CommitmentCurve; +use std::{array, ops::Index}; +use strum::EnumCount; +use strum_macros::{EnumCount as EnumCountMacro, EnumIter}; + +// Simple type alias as ScalarField/BaseField is often used. Reduce type +// complexity for clippy. +// Should be moved into FoldingConfig, but associated type defaults are unstable +// at the moment. 
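+// As a reading aid: for a concrete config `C` with `type Curve = G`, the
+// aliases below are just `G::ScalarField` and `G::BaseField`; the double
+// projection only exists so the aliases can stay generic over `C`.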
+pub(crate) type ScalarField<C> = <<C as FoldingConfig>::Curve as AffineRepr>::ScalarField;
+pub(crate) type BaseField<C> = <<C as FoldingConfig>::Curve as AffineRepr>::BaseField;
+
+// Does not contain alpha because this one should be provided by folding itself
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, EnumIter, EnumCountMacro)]
+pub enum Challenge {
+    Beta,
+    Gamma,
+    JointCombiner,
+}
+
+// Needed to transform from expressions to folding expressions
+impl From<BerkeleyChallengeTerm> for Challenge {
+    fn from(chal: BerkeleyChallengeTerm) -> Self {
+        match chal {
+            BerkeleyChallengeTerm::Beta => Challenge::Beta,
+            BerkeleyChallengeTerm::Gamma => Challenge::Gamma,
+            BerkeleyChallengeTerm::JointCombiner => Challenge::JointCombiner,
+            BerkeleyChallengeTerm::Alpha => panic!("Alpha not allowed in folding expressions"),
+        }
+    }
+}
+
+/// Folding instance containing the commitment to a witness of N columns,
+/// challenges for the proof, and the alphas
+#[derive(Debug, Clone)]
+pub struct FoldingInstance<const N: usize, G: CommitmentCurve> {
+    /// Commitments to the witness columns, including the dynamic selectors
+    pub commitments: [G; N],
+    /// Challenges for the proof.
+    /// We use three challenges:
+    /// - β as the evaluation point for the logup argument
+    /// - j: the joint combiner for vector lookups
+    /// - γ (set to 0 for now)
+    pub challenges: [<G as AffineRepr>::ScalarField; Challenge::COUNT],
+    /// Reuses the Alphas defined in the example of folding
+    pub alphas: Alphas<<G as AffineRepr>::ScalarField>,
+
+    /// Blinder used in the polynomial commitment scheme
+    pub blinder: <G as AffineRepr>::ScalarField,
+}
+
+impl<const N: usize, G: CommitmentCurve> Foldable<G::ScalarField> for FoldingInstance<N, G> {
+    fn combine(a: Self, b: Self, challenge: G::ScalarField) -> Self {
+        FoldingInstance {
+            commitments: array::from_fn(|i| {
+                (a.commitments[i] + b.commitments[i].mul(challenge)).into()
+            }),
+            challenges: array::from_fn(|i| a.challenges[i] + challenge * b.challenges[i]),
+            alphas: Alphas::combine(a.alphas, b.alphas, challenge),
+            blinder: a.blinder + challenge * b.blinder,
+        }
+    }
+}
+
+impl<const N: usize, G: CommitmentCurve> Instance<G> for FoldingInstance<N, G> {
+    fn to_absorb(&self) -> (Vec<G::ScalarField>, Vec<G>) {
+        // FIXME: check!!!!
+        let mut scalars = Vec::new();
+        let mut points = Vec::new();
+        points.extend(self.commitments);
+        scalars.extend(self.challenges);
+        scalars.extend(self.alphas.clone().powers());
+        (scalars, points)
+    }
+
+    fn get_alphas(&self) -> &Alphas<G::ScalarField> {
+        &self.alphas
+    }
+
+    fn get_blinder(&self) -> G::ScalarField {
+        self.blinder
+    }
+}
+
+impl<const N: usize, G: CommitmentCurve> Index<Challenge> for FoldingInstance<N, G> {
+    type Output = G::ScalarField;
+
+    fn index(&self, index: Challenge) -> &Self::Output {
+        match index {
+            Challenge::Beta => &self.challenges[0],
+            Challenge::Gamma => &self.challenges[1],
+            Challenge::JointCombiner => &self.challenges[2],
+        }
+    }
+}
+
+/// Includes the data witness columns and also the dynamic selector columns
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub struct FoldingWitness<const N: usize, F: FftField> {
+    pub witness: GenericWitness<N, Evaluations<F, Radix2EvaluationDomain<F>>>,
+}
+
+impl<const N: usize, F: FftField> Foldable<F> for FoldingWitness<N, F> {
+    fn combine(a: Self, b: Self, challenge: F) -> Self {
+        Self {
+            witness: GenericWitness::combine(a.witness, b.witness, challenge),
+        }
+    }
+}
+
+impl<const N: usize, G: CommitmentCurve> Witness<G> for FoldingWitness<N, G::ScalarField> {}
+
+/// Environment for the decomposable folding protocol, for a given number of
+/// witness columns and selectors.
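+/// Presumably `N` is the total number of columns (relation columns plus
+/// dynamic selectors), `N_REL` the number of relation columns and `N_DSEL`
+/// the number of dynamic selector columns, matching the instantiations
+/// below (e.g. `N_ZKVM_KECCAK_COLS`, `N_ZKVM_KECCAK_REL_COLS`,
+/// `N_ZKVM_KECCAK_SEL_COLS`).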
+pub struct DecomposedFoldingEnvironment< + const N: usize, + const N_REL: usize, + const N_DSEL: usize, + C: FoldingConfig, + Structure, +> { + pub structure: Structure, + /// Commitments to the witness columns, for both sides + pub instances: [FoldingInstance; 2], + /// Corresponds to the omega evaluations, for both sides + pub curr_witnesses: [FoldingWitness>; 2], + /// Corresponds to the zeta*omega evaluations, for both sides + /// This is curr_witness but left shifted by 1 + pub next_witnesses: [FoldingWitness>; 2], +} + +impl< + const N: usize, + const N_REL: usize, + const N_SEL: usize, + C: FoldingConfig, + // FIXME: Clone should not be used. Only a reference should be stored + Structure: Clone, + > + FoldingEnv< + ScalarField, + FoldingInstance, + FoldingWitness>, + C::Column, + Challenge, + C::Selector, + > for DecomposedFoldingEnvironment +where + // Used by col and selector + FoldingWitness>: Index< + C::Column, + Output = Evaluations, Radix2EvaluationDomain>>, + >, + FoldingWitness>: Index< + C::Selector, + Output = Evaluations, Radix2EvaluationDomain>>, + >, +{ + type Structure = Structure; + + fn new( + structure: &Self::Structure, + instances: [&FoldingInstance; 2], + witnesses: [&FoldingWitness>; 2], + ) -> Self { + let curr_witnesses = [witnesses[0].clone(), witnesses[1].clone()]; + let mut next_witnesses = curr_witnesses.clone(); + for side in next_witnesses.iter_mut() { + for col in side.witness.cols.iter_mut() { + col.evals.rotate_left(1); + } + } + DecomposedFoldingEnvironment { + // FIXME: This is a clone, but it should be a reference + structure: structure.clone(), + instances: [instances[0].clone(), instances[1].clone()], + curr_witnesses, + next_witnesses, + } + } + + fn col(&self, col: C::Column, curr_or_next: CurrOrNext, side: Side) -> &[ScalarField] { + let wit = match curr_or_next { + CurrOrNext::Curr => &self.curr_witnesses[side as usize], + CurrOrNext::Next => &self.next_witnesses[side as usize], + }; + // The following is possible because Index is implemented for our circuit witnesses + &wit[col].evals + } + + fn challenge(&self, challenge: Challenge, side: Side) -> ScalarField { + match challenge { + Challenge::Beta => self.instances[side as usize].challenges[0], + Challenge::Gamma => self.instances[side as usize].challenges[1], + Challenge::JointCombiner => self.instances[side as usize].challenges[2], + } + } + + fn selector(&self, s: &C::Selector, side: Side) -> &[ScalarField] { + let witness = &self.curr_witnesses[side as usize]; + &witness[*s].evals + } +} + +pub struct FoldingEnvironment { + /// Structure of the folded circuit + pub structure: Structure, + /// Commitments to the witness columns, for both sides + pub instances: [FoldingInstance; 2], + /// Corresponds to the evaluations at ω, for both sides + pub curr_witnesses: [FoldingWitness>; 2], + /// Corresponds to the evaluations at ζω, for both sides + /// This is curr_witness but left shifted by 1 + pub next_witnesses: [FoldingWitness>; 2], +} + +impl< + const N: usize, + C: FoldingConfig, + // FIXME: Clone should not be used. 
Only a reference should be stored + Structure: Clone, + > + FoldingEnv< + ScalarField, + FoldingInstance, + FoldingWitness>, + C::Column, + Challenge, + (), + > for FoldingEnvironment +where + // Used by col and selector + FoldingWitness>: Index< + C::Column, + Output = Evaluations, Radix2EvaluationDomain>>, + >, +{ + type Structure = Structure; + + fn new( + structure: &Self::Structure, + instances: [&FoldingInstance; 2], + witnesses: [&FoldingWitness>; 2], + ) -> Self { + let curr_witnesses = [witnesses[0].clone(), witnesses[1].clone()]; + let mut next_witnesses = curr_witnesses.clone(); + for side in next_witnesses.iter_mut() { + for col in side.witness.cols.iter_mut() { + col.evals.rotate_left(1); + } + } + FoldingEnvironment { + // FIXME: This is a clone, but it should be a reference + structure: structure.clone(), + instances: [instances[0].clone(), instances[1].clone()], + curr_witnesses, + next_witnesses, + } + } + + fn col(&self, col: C::Column, curr_or_next: CurrOrNext, side: Side) -> &[ScalarField] { + let wit = match curr_or_next { + CurrOrNext::Curr => &self.curr_witnesses[side as usize], + CurrOrNext::Next => &self.next_witnesses[side as usize], + }; + // The following is possible because Index is implemented for our circuit witnesses + &wit[col].evals + } + + fn challenge(&self, challenge: Challenge, side: Side) -> ScalarField { + match challenge { + Challenge::Beta => self.instances[side as usize].challenges[0], + Challenge::Gamma => self.instances[side as usize].challenges[1], + Challenge::JointCombiner => self.instances[side as usize].challenges[2], + } + } + + fn selector(&self, _s: &(), _side: Side) -> &[ScalarField] { + unimplemented!("Selector not implemented for FoldingEnvironment. No selectors are supposed to be used when there is only one instruction.") + } +} + +pub mod keccak { + use std::ops::Index; + + use ark_poly::{Evaluations, Radix2EvaluationDomain}; + use folding::{ + checker::{Checker, ExtendedProvider, Provider}, + expressions::FoldingColumnTrait, + FoldingConfig, + }; + use kimchi_msm::columns::Column; + use poly_commitment::kzg::PairingSRS; + + use crate::{ + interpreters::keccak::{ + column::{ + ColumnAlias as KeccakColumn, N_ZKVM_KECCAK_COLS, N_ZKVM_KECCAK_REL_COLS, + N_ZKVM_KECCAK_SEL_COLS, + }, + Steps, + }, + legacy::{Curve, Fp, Pairing}, + }; + + use super::{Challenge, DecomposedFoldingEnvironment, FoldingInstance, FoldingWitness}; + + pub type KeccakFoldingEnvironment = DecomposedFoldingEnvironment< + N_ZKVM_KECCAK_COLS, + N_ZKVM_KECCAK_REL_COLS, + N_ZKVM_KECCAK_SEL_COLS, + KeccakConfig, + (), + >; + + pub type KeccakFoldingWitness = FoldingWitness; + pub type KeccakFoldingInstance = FoldingInstance; + + impl Index for KeccakFoldingWitness { + type Output = Evaluations>; + + fn index(&self, index: KeccakColumn) -> &Self::Output { + &self.witness.cols[usize::from(index)] + } + } + + // Implemented for decomposable folding compatibility + impl Index for KeccakFoldingWitness { + type Output = Evaluations>; + + /// Map a selector column to the corresponding witness column. + fn index(&self, index: Steps) -> &Self::Output { + &self.witness.cols[usize::from(index)] + } + } + + // Implementing this so that generic constraints can be used in folding + impl Index for KeccakFoldingWitness { + type Output = Evaluations>; + + /// Map a column alias to the corresponding witness column. 
+ fn index(&self, index: Column) -> &Self::Output { + match index { + Column::Relation(ix) => &self.witness.cols[ix], + // Even if `Column::DynamicSelector(ix)` would correspond to + // `&self.witness.cols[N_ZKVM_KECCAK_REL_COLS + ix]`, the + // current design of constraints should not include the dynamic + // selectors. Instead, folding will add them in the + // `DecomposableFoldingScheme` as extended selector columns, and + // the `selector()` function inside the `FoldingEnv` will return + // the actual witness column values. + _ => panic!("Undesired column type inside expressions"), + } + } + } + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + pub struct KeccakConfig; + + impl FoldingColumnTrait for KeccakColumn { + fn is_witness(&self) -> bool { + // dynamic selectors KeccakColumn::Selector() count as witnesses + true + } + } + + impl FoldingConfig for KeccakConfig { + type Column = Column; + type Selector = Steps; + type Challenge = Challenge; + type Curve = Curve; + type Srs = PairingSRS; + type Instance = KeccakFoldingInstance; + type Witness = KeccakFoldingWitness; + type Structure = (); + type Env = KeccakFoldingEnvironment; + } + + // IMPLEMENT CHECKER TRAITS + + impl Checker for ExtendedProvider {} + impl Checker for Provider {} +} + +pub mod mips { + use std::ops::Index; + + use ark_poly::{Evaluations, Radix2EvaluationDomain}; + use folding::{expressions::FoldingColumnTrait, FoldingConfig}; + use kimchi_msm::columns::Column; + use poly_commitment::kzg::PairingSRS; + + use crate::{ + interpreters::mips::{ + column::{ColumnAlias as MIPSColumn, N_MIPS_COLS, N_MIPS_REL_COLS, N_MIPS_SEL_COLS}, + Instruction, + }, + legacy::{Curve, Fp, Pairing}, + }; + + use super::{Challenge, DecomposedFoldingEnvironment, FoldingInstance, FoldingWitness}; + + // Decomposable folding compatibility + pub type DecomposableMIPSFoldingEnvironment = DecomposedFoldingEnvironment< + N_MIPS_COLS, + N_MIPS_REL_COLS, + N_MIPS_SEL_COLS, + DecomposableMIPSFoldingConfig, + (), + >; + + pub type MIPSFoldingWitness = FoldingWitness; + pub type MIPSFoldingInstance = FoldingInstance; + + // -- Start indexer implementations + // Implement indexers over columns and selectors to implement an abstract + // folding environment over selectors, see [crate::folding::FoldingEnvironment] + // for more details + + impl Index for FoldingWitness { + type Output = Evaluations>; + + fn index(&self, index: Column) -> &Self::Output { + match index { + Column::Relation(ix) => &self.witness.cols[ix], + _ => panic!("Invalid column type"), + } + } + } + + impl Index for MIPSFoldingWitness { + type Output = Evaluations>; + + fn index(&self, index: MIPSColumn) -> &Self::Output { + &self.witness.cols[usize::from(index)] + } + } + + // Implemented for decomposable folding compatibility + impl Index for MIPSFoldingWitness { + type Output = Evaluations>; + + /// Map a selector column to the corresponding witness column. + fn index(&self, index: Instruction) -> &Self::Output { + &self.witness.cols[usize::from(index)] + } + } + + // Implementing this so that generic constraints can be used in folding + impl Index for MIPSFoldingWitness { + type Output = Evaluations>; + + /// Map a column alias to the corresponding witness column. 
+ fn index(&self, index: Column) -> &Self::Output { + match index { + Column::Relation(ix) => &self.witness.cols[ix], + Column::DynamicSelector(ix) => &self.witness.cols[N_MIPS_REL_COLS + ix], + _ => panic!("Invalid column type"), + } + } + } + // -- End of indexer implementations + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + pub struct DecomposableMIPSFoldingConfig; + + impl FoldingColumnTrait for MIPSColumn { + fn is_witness(&self) -> bool { + // All MIPS columns are witness columns + true + } + } + + impl FoldingConfig for DecomposableMIPSFoldingConfig { + type Column = Column; + type Selector = Instruction; + type Challenge = Challenge; + type Curve = Curve; + type Srs = PairingSRS; + type Instance = MIPSFoldingInstance; + type Witness = MIPSFoldingWitness; + type Structure = (); + type Env = DecomposableMIPSFoldingEnvironment; + } +} + +#[cfg(test)] +mod tests { + use crate::legacy::{ + folding::{FoldingInstance, FoldingWitness, *}, + Curve, Fp, Pairing, + }; + use ark_poly::{Evaluations, Radix2EvaluationDomain}; + use folding::{ + expressions::{FoldingColumnTrait, FoldingCompatibleExpr, FoldingCompatibleExprInner}, + FoldingConfig, + }; + use kimchi::{ + circuits::expr::{ + ConstantExprInner, ConstantTerm, Constants, Expr, ExprInner, Literal, Variable, + }, + curve::KimchiCurve, + }; + use poly_commitment::kzg::PairingSRS; + use std::ops::Index; + + #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialOrd, PartialEq)] + enum TestColumn { + X, + Y, + Z, + } + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct TestConfig; + + type TestWitness = kimchi_msm::witness::Witness<3, T>; + type TestFoldingWitness = FoldingWitness<3, Fp>; + type TestFoldingInstance = FoldingInstance<3, Curve>; + type TestFoldingEnvironment = FoldingEnvironment<3, TestConfig, ()>; + + impl Index for TestFoldingWitness { + type Output = Evaluations>; + + fn index(&self, index: TestColumn) -> &Self::Output { + &self.witness[index] + } + } + + impl FoldingColumnTrait for TestColumn { + fn is_witness(&self) -> bool { + true + } + } + + impl Index for TestWitness { + type Output = T; + fn index(&self, index: TestColumn) -> &Self::Output { + match index { + TestColumn::X => &self.cols[0], + TestColumn::Y => &self.cols[1], + TestColumn::Z => &self.cols[2], + } + } + } + + impl FoldingConfig for TestConfig { + type Column = TestColumn; + type Challenge = Challenge; + type Selector = (); + type Curve = Curve; + type Srs = PairingSRS; + type Instance = TestFoldingInstance; + type Witness = TestFoldingWitness; + type Structure = (); + type Env = TestFoldingEnvironment; + } + + #[test] + fn test_conversion() { + use super::*; + use kimchi::circuits::berkeley_columns::BerkeleyChallengeTerm; + + // Check that the conversion from ChallengeTerm to Challenge works as expected + assert_eq!(Challenge::Beta, BerkeleyChallengeTerm::Beta.into()); + assert_eq!(Challenge::Gamma, BerkeleyChallengeTerm::Gamma.into()); + assert_eq!( + Challenge::JointCombiner, + BerkeleyChallengeTerm::JointCombiner.into() + ); + + // Create my special constants + let constants = Constants { + endo_coefficient: Fp::from(3), + mds: &Curve::sponge_params().mds, + zk_rows: 0, + }; + + // Define variables to be used in larger expressions + let x = Expr::Atom(ExprInner::Cell::< + ConstantExprInner, + TestColumn, + >(Variable { + col: TestColumn::X, + row: CurrOrNext::Curr, + })); + let y = Expr::Atom(ExprInner::Cell::< + ConstantExprInner, + TestColumn, + >(Variable { + col: TestColumn::Y, + row: CurrOrNext::Curr, + })); + let z = 
Expr::Atom(ExprInner::Cell::< + ConstantExprInner, + TestColumn, + >(Variable { + col: TestColumn::Z, + row: CurrOrNext::Curr, + })); + let endo = Expr::Atom(ExprInner::< + ConstantExprInner, + TestColumn, + >::Constant(ConstantExprInner::Constant( + ConstantTerm::EndoCoefficient, + ))); + + // Define variables with folding expressions + let x_f = + FoldingCompatibleExpr::::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: TestColumn::X, + row: CurrOrNext::Curr, + })); + let y_f = + FoldingCompatibleExpr::::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: TestColumn::Y, + row: CurrOrNext::Curr, + })); + let z_f = + FoldingCompatibleExpr::::Atom(FoldingCompatibleExprInner::Cell(Variable { + col: TestColumn::Z, + row: CurrOrNext::Curr, + })); + + // Check conversion of general expressions + let xyz = x.clone() * y * z; + let xyz_f = FoldingCompatibleExpr::::Mul( + Box::new(FoldingCompatibleExpr::::Mul( + Box::new(x_f.clone()), + Box::new(y_f), + )), + Box::new(z_f), + ); + assert_eq!(FoldingCompatibleExpr::::from(xyz), xyz_f); + + let x_endo = x + endo; + let x_endo_f = FoldingCompatibleExpr::::Add( + Box::new(x_f), + Box::new(FoldingCompatibleExpr::::Atom( + FoldingCompatibleExprInner::Constant(constants.endo_coefficient), + )), + ); + let x_endo_lit = x_endo.as_literal(&constants); + assert_eq!( + FoldingCompatibleExpr::::from(x_endo_lit), + x_endo_f + ); + } +} diff --git a/o1vm/src/legacy/main.rs b/o1vm/src/legacy/main.rs new file mode 100644 index 0000000000..749f5722b6 --- /dev/null +++ b/o1vm/src/legacy/main.rs @@ -0,0 +1,331 @@ +use ark_ff::UniformRand; +use folding::decomposable_folding::DecomposableFoldingScheme; +use kimchi::o1_utils; +use kimchi_msm::{proof::ProofInputs, prover::prove, verifier::verify, witness::Witness}; +use log::debug; +use o1vm::{ + cannon::{self, Meta, Start, State}, + cannon_cli, + interpreters::{ + keccak::{ + column::{Steps, N_ZKVM_KECCAK_COLS, N_ZKVM_KECCAK_REL_COLS, N_ZKVM_KECCAK_SEL_COLS}, + environment::KeccakEnv, + }, + mips::{ + column::{N_MIPS_COLS, N_MIPS_REL_COLS, N_MIPS_SEL_COLS, SCRATCH_SIZE}, + constraints as mips_constraints, + interpreter::Instruction, + witness::{self as mips_witness}, + }, + }, + legacy::{ + folding::mips::DecomposableMIPSFoldingConfig, + proof, + trace::{ + keccak::DecomposedKeccakTrace, mips::DecomposedMIPSTrace, DecomposableTracer, Foldable, + Tracer, + }, + BaseSponge, Fp, OpeningProof, ScalarSponge, + }, + lookups::LookupTableIDs, + preimage_oracle::PreImageOracle, +}; +use poly_commitment::SRS as _; +use std::{cmp::Ordering, collections::HashMap, fs::File, io::BufReader, process::ExitCode}; +use strum::IntoEnumIterator; + +/// Domain size shared by the Keccak evaluations, MIPS evaluation and main +/// program. +pub const DOMAIN_SIZE: usize = 1 << 15; + +pub fn main() -> ExitCode { + let cli = cannon_cli::main_cli(); + + let configuration = cannon_cli::read_configuration(&cli.get_matches()); + + let file = + File::open(&configuration.input_state_file).expect("Error opening input state file "); + + let reader = BufReader::new(file); + // Read the JSON contents of the file as an instance of `State`. 
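+ // NOTE: a deserialization failure aborts right away. `state.step` is read + // below to initialize the `Start` bookkeeping, before the state is moved + // into the MIPS witness environment.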
+ let state: State = serde_json::from_reader(reader).expect("Error reading input state file"); + + let meta_file = File::open(&configuration.metadata_file).unwrap_or_else(|_| { + panic!( + "Could not open metadata file {}", + &configuration.metadata_file + ) + }); + + let meta: Meta = serde_json::from_reader(BufReader::new(meta_file)).unwrap_or_else(|_| { + panic!( + "Error deserializing metadata file {}", + &configuration.metadata_file + ) + }); + + let mut po = PreImageOracle::create(&configuration.host); + let _child = po.start(); + + // Initialize some data used for statistical computations + let start = Start::create(state.step as usize); + + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + + let domain = kimchi::circuits::domains::EvaluationDomains::::create(DOMAIN_SIZE).unwrap(); + + let mut rng = o1_utils::tests::make_test_rng(None); + + let srs = { + // FIXME: toxic waste is generated in `create`. This is unsafe for prod. + let srs = poly_commitment::kzg::PairingSRS::create(DOMAIN_SIZE); + srs.get_lagrange_basis(domain.d1); + srs + }; + + // Initialize the environments. The Keccak environment is extracted inside + // the loop below. + let mut mips_wit_env = + mips_witness::Env::::create(cannon::PAGE_SIZE as usize, state, po); + let mut mips_con_env = mips_constraints::Env::::default(); + + // Initialize the circuits. Includes pre-folding witnesses. + let mut mips_trace = DecomposedMIPSTrace::new(DOMAIN_SIZE, &mut mips_con_env); + let mut keccak_trace = DecomposedKeccakTrace::new(DOMAIN_SIZE, &mut KeccakEnv::::default()); + + let _mips_folding = { + DecomposableFoldingScheme::::new( + >::folding_constraints(&mips_trace), + vec![], + &srs, + domain.d1, + &(), + ) + }; + + // Initialize folded instances of the sub-circuits + let mut mips_folded_instance = HashMap::new(); + for instr in Instruction::iter().flat_map(|x| x.into_iter()) { + mips_folded_instance.insert( + instr, + ProofInputs::::default(), + ); + } + + let mut keccak_folded_instance = HashMap::new(); + for step in Steps::iter().flat_map(|x| x.into_iter()) { + keccak_folded_instance.insert( + step, + ProofInputs::::default(), + ); + } + + while !mips_wit_env.halt { + let instr = mips_wit_env.step(&configuration, &meta, &start); + + if let Some(ref mut keccak_env) = mips_wit_env.keccak_env { + // Run all steps of the hash + while keccak_env.step.is_some() { + // Get the current standardized step that is being executed + let step = keccak_env.selector(); + // Run the interpreter, which sets the witness columns + keccak_env.step(); + // Add the witness row to the Keccak circuit for this step + keccak_trace.push_row(step, &keccak_env.witness_env.witness.cols); + + // If the witness is full, fold it and reset the pre-folding witness + if keccak_trace.number_of_rows(step) == DOMAIN_SIZE { + // Set to zero all selectors except for the one corresponding to the current instruction + keccak_trace.set_selector_column::(step, DOMAIN_SIZE); + proof::fold::( + domain, + &srs, + keccak_folded_instance.get_mut(&step).unwrap(), + &keccak_trace[step].witness, + ); + keccak_trace.reset(step); + } + } + // When the Keccak interpreter is finished, we can reset the environment + mips_wit_env.keccak_env = None; + } + + // TODO: unify witness of MIPS to include scratch state, instruction counter, and error + for i in 0..N_MIPS_REL_COLS { + match i.cmp(&SCRATCH_SIZE) { + Ordering::Less => mips_trace.trace.get_mut(&instr).unwrap().witness.cols[i]
.push(mips_wit_env.scratch_state[i]), + Ordering::Equal => mips_trace.trace.get_mut(&instr).unwrap().witness.cols[i] + .push(Fp::from(mips_wit_env.instruction_counter)), + Ordering::Greater => { + // TODO: error + mips_trace.trace.get_mut(&instr).unwrap().witness.cols[i] + .push(Fp::rand(&mut rand::rngs::OsRng)) + } + } + } + + if mips_trace.number_of_rows(instr) == DOMAIN_SIZE { + // Set to zero all selectors except for the one corresponding to the current instruction + mips_trace.set_selector_column::(instr, DOMAIN_SIZE); + proof::fold::( + domain, + &srs, + mips_folded_instance.get_mut(&instr).unwrap(), + &mips_trace[instr].witness, + ); + mips_trace.reset(instr); + } + } + + // Pad any possible remaining rows if the execution was not a multiple of the domain size + for instr in Instruction::iter().flat_map(|x| x.into_iter()) { + // Start by padding with the first row + let needs_folding = mips_trace.pad_dummy(instr) != 0; + if needs_folding { + // Then set the selector columns (all of them, none has selectors set) + mips_trace.set_selector_column::(instr, DOMAIN_SIZE); + + // Finally fold instance + proof::fold::( + domain, + &srs, + mips_folded_instance.get_mut(&instr).unwrap(), + &mips_trace[instr].witness, + ); + } + } + for step in Steps::iter().flat_map(|x| x.into_iter()) { + let needs_folding = keccak_trace.pad_dummy(step) != 0; + if needs_folding { + keccak_trace.set_selector_column::(step, DOMAIN_SIZE); + + proof::fold::( + domain, + &srs, + keccak_folded_instance.get_mut(&step).unwrap(), + &keccak_trace[step].witness, + ); + } + } + + { + // MIPS + for instr in Instruction::iter().flat_map(|x| x.into_iter()) { + // Prove only if the instruction was executed + // and if the number of constraints is nonzero (otherwise quotient polynomial cannot be created) + if mips_trace.in_circuit(instr) && !mips_trace[instr].constraints.is_empty() { + debug!("Checking MIPS circuit {:?}", instr); + let mips_result = prove::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, + _, + N_MIPS_COLS, + N_MIPS_REL_COLS, + N_MIPS_SEL_COLS, + 0, + LookupTableIDs, + >( + domain, + &srs, + &mips_trace[instr].constraints, + Box::new([]), + mips_folded_instance[&instr].clone(), + &mut rng, + ); + let mips_proof = mips_result.unwrap(); + debug!("Generated a MIPS {:?} proof:", instr); + let mips_verifies = verify::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, + N_MIPS_COLS, + N_MIPS_REL_COLS, + N_MIPS_SEL_COLS, + 0, + 0, + LookupTableIDs, + >( + domain, + &srs, + &mips_trace[instr].constraints, + Box::new([]), + &mips_proof, + Witness::zero_vec(DOMAIN_SIZE), + ); + if mips_verifies { + debug!("The MIPS {:?} proof verifies\n", instr) + } else { + debug!("The MIPS {:?} proof doesn't verify\n", instr) + } + } + } + } + + { + // KECCAK + // FIXME: when folding is applied, the error term will be created to satisfy the folded witness + for step in Steps::iter().flat_map(|x| x.into_iter()) { + // Prove only if the instruction was executed + if keccak_trace.in_circuit(step) { + debug!("Checking Keccak circuit {:?}", step); + let keccak_result = prove::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, + _, + N_ZKVM_KECCAK_COLS, + N_ZKVM_KECCAK_REL_COLS, + N_ZKVM_KECCAK_SEL_COLS, + 0, + LookupTableIDs, + >( + domain, + &srs, + &keccak_trace[step].constraints, + Box::new([]), + keccak_folded_instance[&step].clone(), + &mut rng, + ); + let keccak_proof = keccak_result.unwrap(); + debug!("Generated a Keccak {:?} proof:", step); + let keccak_verifies = verify::< + _, + OpeningProof, + BaseSponge, + ScalarSponge, 
+ N_ZKVM_KECCAK_COLS, + N_ZKVM_KECCAK_REL_COLS, + N_ZKVM_KECCAK_SEL_COLS, + 0, + 0, + LookupTableIDs, + >( + domain, + &srs, + &keccak_trace[step].constraints, + Box::new([]), + &keccak_proof, + Witness::zero_vec(DOMAIN_SIZE), + ); + if keccak_verifies { + debug!("The Keccak {:?} proof verifies\n", step) + } else { + debug!("The Keccak {:?} proof doesn't verify\n", step) + } + } + } + } + + // TODO: Logic + ExitCode::SUCCESS +} diff --git a/o1vm/src/legacy/mod.rs b/o1vm/src/legacy/mod.rs new file mode 100644 index 0000000000..d4b9e8027d --- /dev/null +++ b/o1vm/src/legacy/mod.rs @@ -0,0 +1,37 @@ +//! This submodule provides the legacy flavor/interface of the o1vm, and is not +//! supposed to be used anymore. +//! +//! This o1vm flavor was supposed to use the [folding](folding) library, which +//! consists of reducing all constraints to degree 2, in addition to the `ivc` +//! library defined in this monorepo to support long traces. +//! The goal of this flavor was to support the curve `bn254`. For the time +//! being, the project has been stopped in favor of the pickles version defined +//! in [crate::pickles] and we do not aim to provide any support for now. +//! +//! You can still run the legacy flavor by using: +//! +//! ```bash +//! O1VM_FLAVOR=legacy bash run-code.sh +//! ``` + +use ark_ec::bn::Bn; +use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi, + sponge::{DefaultFqSponge, DefaultFrSponge}, +}; +use poly_commitment::kzg::KZGProof; + +/// Scalar field of BN254 +pub type Fp = ark_bn254::Fr; +/// Elliptic curve group of BN254 +pub type Curve = ark_bn254::G1Affine; +pub type Pairing = ark_bn254::Bn254; +pub type BaseSponge = DefaultFqSponge; +pub type ScalarSponge = DefaultFrSponge; +pub type SpongeParams = PlonkSpongeConstantsKimchi; +pub type OpeningProof = KZGProof>; + +pub mod folding; +pub mod proof; +pub mod trace; diff --git a/o1vm/src/legacy/proof.rs b/o1vm/src/legacy/proof.rs new file mode 100644 index 0000000000..5e352f4332 --- /dev/null +++ b/o1vm/src/legacy/proof.rs @@ -0,0 +1,108 @@ +use crate::lookups::LookupTableIDs; +use ark_poly::{Evaluations, Radix2EvaluationDomain as D}; +use kimchi::{circuits::domains::EvaluationDomains, curve::KimchiCurve, plonk_sponge::FrSponge}; +use kimchi_msm::{proof::ProofInputs, witness::Witness}; +use mina_poseidon::{sponge::ScalarChallenge, FqSponge}; +use poly_commitment::{commitment::absorb_commitment, OpenProof, SRS as _}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; + +/// FIXME: DUMMY FOLD FUNCTION THAT ONLY KEEPS THE LAST INSTANCE +pub fn fold< + const N: usize, + G: KimchiCurve, + OpeningProof: OpenProof, + EFqSponge: Clone + FqSponge, + EFrSponge: FrSponge, +>( + domain: EvaluationDomains, + srs: &OpeningProof::SRS, + accumulator: &mut ProofInputs, + inputs: &Witness>, +) where + >::SRS: std::marker::Sync, +{ + let commitments = { + inputs + .par_iter() + .map(|evals: &Vec| { + let evals = Evaluations::>::from_vec_and_domain( + evals.clone(), + domain.d1, + ); + srs.commit_evaluations_non_hiding(domain.d1, &evals) + }) + .collect::>() + }; + let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params()); + + commitments.into_iter().for_each(|comm| { + absorb_commitment(&mut fq_sponge, &comm); + }); + + // TODO: fold mvlookups as well + accumulator + .evaluations + .par_iter_mut() + .zip(inputs.par_iter()) + .for_each(|(accumulator, inputs)| { + accumulator + .par_iter_mut() + .zip(inputs.par_iter())
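+ // A plain overwrite: the accumulated evaluations are replaced by the + // incoming ones, so only the last instance is kept, as the FIXME above + // says. See `old_fold` below for the random-combination variant this + // placeholder stands in for.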
.for_each(|(accumulator, input)| *accumulator = *input); + }); +} + +#[allow(dead_code)] +/// This function folds the witness of the current circuit with the accumulated Keccak instance +/// with a random combination using a scaling challenge +// FIXME: This will have to be adapted when the folding library is available +fn old_fold< + const N: usize, + G: KimchiCurve, + OpeningProof: OpenProof, + EFqSponge: Clone + FqSponge, +>( + domain: EvaluationDomains, + srs: &OpeningProof::SRS, + accumulator: &mut ProofInputs, + inputs: &Witness>, +) where + >::SRS: std::marker::Sync, +{ + let commitments = { + inputs + .par_iter() + .map(|evals: &Vec| { + let evals = Evaluations::>::from_vec_and_domain( + evals.clone(), + domain.d1, + ); + srs.commit_evaluations_non_hiding(domain.d1, &evals) + }) + .collect::>() + }; + let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params()); + + commitments.into_iter().for_each(|comm| { + absorb_commitment(&mut fq_sponge, &comm); + }); + + let scaling_challenge = ScalarChallenge(fq_sponge.challenge()); + let (_, endo_r) = G::endos(); + let scaling_challenge = scaling_challenge.to_field(endo_r); + // TODO: fold mvlookups as well + accumulator + .evaluations + .par_iter_mut() + .zip(inputs.par_iter()) + .for_each(|(accumulator, inputs)| { + accumulator + .par_iter_mut() + .zip(inputs.par_iter()) + .for_each(|(accumulator, input)| { + *accumulator = *input + scaling_challenge * *accumulator + }); + }); +} diff --git a/o1vm/src/legacy/tests.rs b/o1vm/src/legacy/tests.rs new file mode 100644 index 0000000000..bce2990d86 --- /dev/null +++ b/o1vm/src/legacy/tests.rs @@ -0,0 +1,742 @@ +use crate::{ + interpreters::mips::{ + column::{N_MIPS_REL_COLS, SCRATCH_SIZE}, + constraints::Env as CEnv, + interpreter::{debugging::InstructionParts, interpret_itype}, + ITypeInstruction, + }, + legacy::{ + folding::{ + Challenge, DecomposedMIPSTrace, FoldingEnvironment, FoldingInstance, FoldingWitness, + }, + trace::Trace, + }, + BaseSponge, Curve, +}; + +use ark_ff::One; +use ark_poly::{EvaluationDomain as _, Evaluations, Radix2EvaluationDomain as D}; +use folding::{expressions::FoldingCompatibleExpr, Alphas, FoldingConfig, FoldingScheme}; +use itertools::Itertools; +use kimchi::{curve::KimchiCurve, o1_utils}; +use kimchi_msm::{columns::Column, witness::Witness}; +use mina_poseidon::FqSponge; +use poly_commitment::{commitment::absorb_commitment, kzg::PairingSRS, PolyComm, SRS as _}; +use rand::{CryptoRng, Rng, RngCore}; +use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; + +pub mod mips { + fn make_random_witness_for_addiu( + domain_size: usize, + rng: &mut RNG, + ) -> Witness> + where + RNG: RngCore + CryptoRng, + { + let mut dummy_env = dummy_env(rng); + let instr = ITypeInstruction::AddImmediateUnsigned; + let a = std::array::from_fn(|_| Vec::with_capacity(domain_size)); + let mut witness = Witness { cols: Box::new(a) }; + // Building trace for AddImmediateUnsigned + (0..domain_size).for_each(|_i| { + // Registers should not conflict because RAMlookup does not support + // reading from/writing into the same register in the same + // instruction + let reg_src = rng.gen_range(0..10); + let reg_dest = rng.gen_range(10..20); + // we always simulate the same instruction, with random values + // and registers + write_instruction( + &mut dummy_env, + InstructionParts { + op_code: 0b001001, + rs: reg_src, // source register + rt: reg_dest, // destination register + // The rest is the immediate value + rd: rng.gen_range(0..32), + shamt: rng.gen_range(0..32),
funct: rng.gen_range(0..64), + }, + ); + interpret_itype(&mut dummy_env, instr); + + for j in 0..SCRATCH_SIZE { + witness.cols[j].push(dummy_env.scratch_state[j]); + } + witness.cols[SCRATCH_SIZE].push(Fp::from(dummy_env.instruction_counter)); + witness.cols[SCRATCH_SIZE + 1].push(Fp::from(0)); + dummy_env.instruction_counter += 1; + + dummy_env.reset_scratch_state() + }); + // sanity check + witness + .cols + .iter() + .for_each(|x| assert_eq!(x.len(), domain_size)); + witness + } + + fn build_folding_instance( + witness: &FoldingWitness, + fq_sponge: &mut BaseSponge, + domain: D, + srs: &PairingSRS, + ) -> FoldingInstance { + let commitments: Witness> = (&witness.witness) + .into_par_iter() + .map(|w| srs.commit_evaluations_non_hiding(domain, w)) + .collect(); + + // Absorbing commitments + (&commitments) + .into_iter() + .for_each(|c| absorb_commitment(fq_sponge, c)); + + let commitments: [Curve; N_MIPS_REL_COLS] = commitments + .into_iter() + .map(|c| c.elems[0]) + .collect_vec() + .try_into() + .unwrap(); + + let beta = fq_sponge.challenge(); + let gamma = fq_sponge.challenge(); + let joint_combiner = fq_sponge.challenge(); + let alpha = fq_sponge.challenge(); + let challenges = [beta, gamma, joint_combiner]; + let alphas = Alphas::new(alpha); + let blinder = Fp::one(); + + FoldingInstance { + commitments, + challenges, + alphas, + blinder, + } + } + + // Manually change the number of constraints if they are modified in the + // interpreter + // FIXME: can be moved up, in interpreters::mips, without using the + // DecomposedMIPSTrace + #[test] + fn test_mips_number_constraints() { + let domain_size = 1 << 8; + + // Initialize the environment and run the interpreter + let mut constraints_env = Env:: { + scratch_state_idx: 0, + constraints: Vec::new(), + lookups: Vec::new(), + }; + + // Keep track of the constraints and lookups of the sub-circuits + let mips_circuit = DecomposedMIPSTrace::new(domain_size, &mut constraints_env); + + let assert_num_constraints = |instr: &Instruction, num: usize| { + assert_eq!( + mips_circuit.trace.get(instr).unwrap().constraints.len(), + num + ) + }; + + let mut i = 0; + for instr in Instruction::iter().flat_map(|x| x.into_iter()) { + match instr { + RType(rtype) => match rtype { + JumpRegister | SyscallExitGroup | Sync => assert_num_constraints(&instr, 1), + ShiftLeftLogical + | ShiftRightLogical + | ShiftRightArithmetic + | ShiftLeftLogicalVariable + | ShiftRightLogicalVariable + | ShiftRightArithmeticVariable + | JumpAndLinkRegister + | SyscallReadHint + | MoveFromHi + | MoveFromLo + | MoveToLo + | MoveToHi + | Add + | AddUnsigned + | Sub + | SubUnsigned + | And + | Or + | Xor + | Nor + | SetLessThan + | SetLessThanUnsigned + | MultiplyToRegister + | CountLeadingOnes + | CountLeadingZeros => assert_num_constraints(&instr, 4), + MoveZero | MoveNonZero => assert_num_constraints(&instr, 6), + SyscallReadOther | SyscallWriteHint | SyscallWriteOther | Multiply + | MultiplyUnsigned | Div | DivUnsigned => assert_num_constraints(&instr, 7), + SyscallOther => assert_num_constraints(&instr, 11), + SyscallMmap => assert_num_constraints(&instr, 12), + SyscallFcntl | SyscallReadPreimage => assert_num_constraints(&instr, 23), + // TODO: update SyscallReadPreimage to 31 when using self.equal() + SyscallWritePreimage => assert_num_constraints(&instr, 31), + }, + JType(jtype) => match jtype { + Jump => assert_num_constraints(&instr, 1), + JumpAndLink => assert_num_constraints(&instr, 4), + }, + IType(itype) => match itype { + BranchLeqZero | BranchGtZero |
BranchLtZero | BranchGeqZero | Store8 + | Store16 => assert_num_constraints(&instr, 1), + BranchEq | BranchNeq | Store32 => assert_num_constraints(&instr, 3), + AddImmediate + | AddImmediateUnsigned + | SetLessThanImmediate + | SetLessThanImmediateUnsigned + | AndImmediate + | OrImmediate + | XorImmediate + | LoadUpperImmediate + | Load8 + | Load16 + | Load32 + | Load8Unsigned + | Load16Unsigned + | Store32Conditional => assert_num_constraints(&instr, 4), + LoadWordLeft | LoadWordRight | StoreWordLeft | StoreWordRight => { + assert_num_constraints(&instr, 13) + } + }, + } + i += 1; + } + assert_eq!( + i, + RTypeInstruction::COUNT + JTypeInstruction::COUNT + ITypeInstruction::COUNT + ); + } + + #[test] + fn test_folding_mips_addiu_constraint() { + let mut fq_sponge: BaseSponge = FqSponge::new(Curve::other_curve_sponge_params()); + let mut rng = o1_utils::tests::make_test_rng(None); + + let domain_size = 1 << 3; + let domain: D = D::::new(domain_size).unwrap(); + + let srs = PairingSRS::::create(domain_size); + srs.get_lagrange_basis(domain); + + // Generating constraints + let constraints = { + // Initialize the environment and run the interpreter + let mut constraints_env = CEnv::::default(); + interpret_itype(&mut constraints_env, ITypeInstruction::AddImmediateUnsigned); + constraints_env.constraints + }; + // We have 3 constraints here. We can select only one. + // println!("Nb of constraints: {:?}", constraints.len()); + // + // You can select one of the constraints if you want to fold only one + // + // constraints + // .iter() + // .for_each(|constraint| + // println!("Degree: {:?}", constraint.degree(1, 0))); + // + // Selecting the first constraint for testing + // let constraints + // = vec![constraints.first().unwrap().clone()]; + + let witness_one = make_random_witness_for_addiu(domain_size, &mut rng); + let witness_two = make_random_witness_for_addiu(domain_size, &mut rng); + // FIXME: run PlonK here to check that it is satisfied. + + // Now, we will fold and use the folding scheme to only prove the + // aggregation instead of the individual instances + + #[derive(Clone, Debug, PartialEq, Eq, Hash)] + struct MIPSFoldingConfig; + + let trace_one: Trace = Trace { + domain_size, + // FIXME: do not use clone + witness: witness_one.clone(), + constraints: constraints.clone(), + lookups: vec![], + }; + + impl FoldingConfig for MIPSFoldingConfig { + type Column = Column; + type Selector = (); + type Challenge = Challenge; + type Curve = Curve; + type Srs = PairingSRS; + type Instance = FoldingInstance; + type Witness = FoldingWitness; + // The structure must be a provable Trace.
Here we use a single + // instruction trace + type Structure = Trace; + type Env = FoldingEnvironment< + N_MIPS_REL_COLS, + MIPSFoldingConfig, + Trace, + >; + } + + let folding_witness_one: FoldingWitness = { + let witness_one = (&witness_one) + .into_par_iter() + .map(|w| Evaluations::from_vec_and_domain(w.to_vec(), domain)) + .collect(); + FoldingWitness { + witness: witness_one, + } + }; + + let folding_witness_two: FoldingWitness = { + let witness_two = (&witness_two) + .into_par_iter() + .map(|w| Evaluations::from_vec_and_domain(w.to_vec(), domain)) + .collect(); + FoldingWitness { + witness: witness_two, + } + }; + + let folding_instance_one = + build_folding_instance(&folding_witness_one, &mut fq_sponge, domain, &srs); + let folding_instance_two = + build_folding_instance(&folding_witness_two, &mut fq_sponge, domain, &srs); + + let folding_compat_constraints: Vec> = constraints + .iter() + .map(|x| FoldingCompatibleExpr::from(x.clone())) + .collect::>(); + + let (folding_scheme, _) = FoldingScheme::::new( + folding_compat_constraints, + &srs, + domain, + &trace_one, + ); + + let one = (folding_instance_one, folding_witness_one); + let two = (folding_instance_two, folding_witness_two); + let (_relaxed_instance, _relaxed_witness) = folding_scheme + .fold_instance_witness_pair(one, two, &mut fq_sponge) + .pair(); + + // FIXME: add IVC + } +} + +pub mod keccak { + fn create_trace_all_steps(domain_size: usize, rng: &mut StdRng) -> DecomposedKeccakTrace { + let mut trace = >>::new( + domain_size, + &mut KeccakEnv::::default(), + ); + { + // 1 block preimages for Sponge(Absorb(Only)), Round(0), and Sponge(Squeeze) + for _ in 0..domain_size { + // random 1-block preimages + let bytelength = rng.gen_range(0..RATE_IN_BYTES); + let preimage: Vec = (0..bytelength).map(|_| rng.gen()).collect(); + // Initialize the environment and run the interpreter + let mut keccak_env = KeccakEnv::::new(0, &preimage); + while keccak_env.step.is_some() { + let step = keccak_env.step.unwrap(); + // Create the relation witness columns + keccak_env.step(); + match step { + Sponge(Absorb(Only)) | Round(0) | Sponge(Squeeze) => { + // Add the witness row to the circuit + trace.push_row(step, &keccak_env.witness_env.witness.cols); + } + _ => {} + } + } + } + // Check there is no need for padding because we reached domain_size rows for these selectors + assert!(trace.is_full(Sponge(Absorb(Only)))); + assert!(trace.is_full(Round(0))); + assert!(trace.is_full(Sponge(Squeeze))); + + // Add the columns of the selectors to the circuit + trace.set_selector_column::(Sponge(Absorb(Only)), domain_size); + trace.set_selector_column::(Round(0), domain_size); + trace.set_selector_column::(Sponge(Squeeze), domain_size); + } + { + // 3 block preimages for Sponge(Absorb(First)), Sponge(Absorb(Middle)), and Sponge(Absorb(Last)) + for _ in 0..domain_size { + // random 3-block preimages + let bytelength = rng.gen_range(2 * RATE_IN_BYTES..3 * RATE_IN_BYTES); + let preimage: Vec = (0..bytelength).map(|_| rng.gen()).collect(); + // Initialize the environment and run the interpreter + let mut keccak_env = KeccakEnv::::new(0, &preimage); + while keccak_env.step.is_some() { + let step = keccak_env.step.unwrap(); + // Create the relation witness columns + keccak_env.step(); + match step { + Sponge(Absorb(First)) | Sponge(Absorb(Middle)) | Sponge(Absorb(Last)) => { + // Add the witness row to the circuit + trace.push_row(step, &keccak_env.witness_env.witness.cols); + } + _ => {} + } + } + } + // Check there is no need for padding because we
reached domain_size rows for these selectors + assert!(trace.is_full(Sponge(Absorb(First)))); + assert!(trace.is_full(Sponge(Absorb(Middle)))); + assert!(trace.is_full(Sponge(Absorb(Last)))); + + // Add the columns of the selectors to the circuit + trace.set_selector_column::(Sponge(Absorb(First)), domain_size); + trace + .set_selector_column::(Sponge(Absorb(Middle)), domain_size); + trace.set_selector_column::(Sponge(Absorb(Last)), domain_size); + trace + } + } + + // Prover/Verifier test including the Keccak constraints + #[test] + fn test_keccak_prover_constraints() { + // guaranteed to have at least 30MB of stack + stacker::grow(30 * 1024 * 1024, || { + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_size = 1 << 8; + + // Generate 3 blocks of preimage data + let bytelength = rng.gen_range(2 * RATE_IN_BYTES..RATE_IN_BYTES * 3); + let preimage: Vec = (0..bytelength).map(|_| rng.gen()).collect(); + + // Initialize the environment and run the interpreter + let mut keccak_env = KeccakEnv::::new(0, &preimage); + + // Keep track of the constraints and lookups of the sub-circuits + let mut keccak_circuit = + >>::new( + domain_size, + &mut keccak_env, + ); + + while keccak_env.step.is_some() { + let step = keccak_env.selector(); + + // Run the interpreter, which sets the witness columns + keccak_env.step(); + + // Add the witness row to the circuit + keccak_circuit.push_row(step, &keccak_env.witness_env.witness.cols); + } + keccak_circuit.pad_witnesses(); + + for step in Steps::iter().flat_map(|x| x.into_iter()) { + if keccak_circuit.in_circuit(step) { + test_completeness_generic_no_lookups::< + N_ZKVM_KECCAK_COLS, + N_ZKVM_KECCAK_REL_COLS, + N_ZKVM_KECCAK_SEL_COLS, + 0, + _, + >( + keccak_circuit[step].constraints.clone(), + Box::new([]), + keccak_circuit[step].witness.clone(), + domain_size, + &mut rng, + ); + } + } + }); + } + + fn dummy_constraints() -> BTreeMap>> { + Steps::iter() + .flat_map(|x| x.into_iter()) + .map(|step| { + ( + step, + vec![FoldingCompatibleExpr::::Atom( + FoldingCompatibleExprInner::Constant(Fp::zero()), + )], + ) + }) + .collect() + } + + // (Instance, Witness) + type KeccakFoldingSide = ( + ::Instance, + ::Witness, + ); + + // (Step, Left, Right) + type KeccakFoldingPair = (Steps, KeccakFoldingSide, KeccakFoldingSide); + + type KeccakDefaultFqSponge = DefaultFqSponge; + + #[test] + fn heavy_test_keccak_folding() { + use crate::{keccak::folding::KeccakConfig, trace::Foldable, Curve}; + use ark_poly::{EvaluationDomain, Radix2EvaluationDomain as D}; + use folding::{checker::Checker, expressions::FoldingCompatibleExpr}; + use kimchi::curve::KimchiCurve; + use mina_poseidon::FqSponge; + use poly_commitment::kzg::PairingSRS; + + // guaranteed to have at least 30MB of stack + stacker::grow(30 * 1024 * 1024, || { + let mut rng = o1_utils::tests::make_test_rng(None); + let domain_size = 1 << 6; + + let domain = D::::new(domain_size).unwrap(); + let srs = PairingSRS::::create(domain_size); + srs.get_lagrange_basis(domain); + + // Create sponge + let mut fq_sponge = BaseSponge::new(Curve::other_curve_sponge_params()); + + // Create two instances for each selector to be folded + let keccak_trace: [DecomposedKeccakTrace; 2] = + array::from_fn(|_| create_trace_all_steps(domain_size, &mut rng)); + let trace = keccak_trace[0].clone(); + + // Store all constraints indexed by Step + let constraints = >::folding_constraints(&trace); + + // DEFINITIONS OF FUNCTIONS FOR TESTING PURPOSES + + let check_instance_satisfy_constraints = + |constraints:
&[FoldingCompatibleExpr], side: &KeccakFoldingSide| { + let (instance, witness) = side; + let checker = Provider::new(instance.clone(), witness.clone()); + constraints.iter().for_each(|c| { + checker.check(c, domain); + }); + }; + + let check_folding = + |left: &KeccakFoldingSide, + right: &KeccakFoldingSide, + constraints: &[FoldingCompatibleExpr], + fq_sponge: &mut KeccakDefaultFqSponge| { + // Create the folding scheme ignoring selectors + let (scheme, final_constraint) = + FoldingScheme::::new(constraints.to_vec(), &srs, domain, &()); + + // Fold both sides and check the constraints ignoring the selector columns + let fout = + scheme.fold_instance_witness_pair(left.clone(), right.clone(), fq_sponge); + + // We should always get 0 additional columns here: without selectors, + // the degree of the Keccak constraints is never higher than 2. + assert_eq!(scheme.get_number_of_additional_columns(), 0); + + let checker = ExtendedProvider::new(fout.folded_instance, fout.folded_witness); + checker.check(&final_constraint, domain); + }; + + let check_decomposable_folding_pair = + |step: Option, + left: &KeccakFoldingSide, + right: &KeccakFoldingSide, + scheme: &DecomposableFoldingScheme, + final_constraint: &FoldingCompatibleExpr, + quadri_cols: Option, + fq_sponge: &mut KeccakDefaultFqSponge| { + let fout = scheme.fold_instance_witness_pair( + left.clone(), + right.clone(), + step, + fq_sponge, + ); + + let extra_cols = scheme.get_number_of_additional_columns(); + if let Some(quadri_cols) = quadri_cols { + assert!(extra_cols == quadri_cols); + } + + // Check the constraints on the folded circuit applying selectors + let checker = ExtendedProvider::::new( + fout.folded_instance, + fout.folded_witness, + ); + checker.check(final_constraint, domain); + }; + + let check_decomposable_folding = + |pair: &KeccakFoldingPair, + constraints: BTreeMap>>, + common_constraints: Vec>, + quadri_cols: Option, + fq_sponge: &mut KeccakDefaultFqSponge| { + let (step, left, right) = pair; + let (dec_scheme, dec_final_constraint) = + DecomposableFoldingScheme::::new( + constraints, + common_constraints, + &srs, + domain, + &(), + ); + // Subcase A: Check the folded circuit of decomposable folding ignoring selectors (None) + check_decomposable_folding_pair( + None, + left, + right, + &dec_scheme, + &dec_final_constraint, + quadri_cols, + fq_sponge, + ); + // Subcase B: Check the folded circuit of decomposable folding applying selectors (Some) + check_decomposable_folding_pair( + Some(*step), + left, + right, + &dec_scheme, + &dec_final_constraint, + quadri_cols, + fq_sponge, + ); + }; + + let check_decomposable_folding_mix = + |steps: (Steps, Steps), fq_sponge: &mut KeccakDefaultFqSponge| { + let (dec_scheme, dec_final_constraint) = + DecomposableFoldingScheme::::new( + constraints.clone(), + vec![], + &srs, + domain, + &(), + ); + let left = { + let fout = dec_scheme.fold_instance_witness_pair( + keccak_trace[0].to_folding_pair(steps.0, fq_sponge, domain, &srs), + keccak_trace[1].to_folding_pair(steps.0, fq_sponge, domain, &srs), + Some(steps.0), + fq_sponge, + ); + let checker = ExtendedProvider::::new( + fout.folded_instance, + fout.folded_witness, + ); + (checker.instance, checker.witness) + }; + let right = { + let fout = dec_scheme.fold_instance_witness_pair( + keccak_trace[0].to_folding_pair(steps.1, fq_sponge, domain, &srs), + keccak_trace[1].to_folding_pair(steps.1, fq_sponge, domain, &srs), + Some(steps.1), + fq_sponge, + ); + let checker = ExtendedProvider::::new( + fout.folded_instance, + fout.folded_witness, + );
+ (checker.instance, checker.witness) + }; + let fout = dec_scheme.fold_instance_witness_pair(left, right, None, fq_sponge); + let checker = ExtendedProvider::new(fout.folded_instance, fout.folded_witness); + checker.check(&dec_final_constraint, domain); + }; + + // HERE STARTS THE TESTING + + // Sanity check that the number of constraints is as expected for each step + assert_eq!(constraints[&Sponge(Absorb(First))].len(), 332); + assert_eq!(constraints[&Sponge(Absorb(Middle))].len(), 232); + assert_eq!(constraints[&Sponge(Absorb(Last))].len(), 374); + assert_eq!(constraints[&Sponge(Absorb(Only))].len(), 474); + assert_eq!(constraints[&Sponge(Squeeze)].len(), 16); + assert_eq!(constraints[&Round(0)].len(), 389); + + // Total number of Keccak constraints of degree higher than 2 (should be 0) + let total_deg_higher_2 = + Steps::iter() + .flat_map(|x| x.into_iter()) + .fold(0, |acc, step| { + acc + trace[step] + .constraints + .iter() + .filter(|c| c.degree(1, 0) > 2) + .count() + }); + assert_eq!(total_deg_higher_2, 0); + + // Check folding constraints of individual steps ignoring selectors + for step in Steps::iter().flat_map(|x| x.into_iter()) { + // BTreeMap with constraints of this step + let mut step_constraints = BTreeMap::new(); + step_constraints.insert(step, constraints[&step].clone()); + + // Create sides for folding + let left = keccak_trace[0].to_folding_pair(step, &mut fq_sponge, domain, &srs); + let right = keccak_trace[1].to_folding_pair(step, &mut fq_sponge, domain, &srs); + + // CASE 0: Check instances satisfy the constraints, without folding them + check_instance_satisfy_constraints(&constraints[&step], &left); + check_instance_satisfy_constraints(&constraints[&step], &right); + + // CASE 1: Check constraints on folded circuit ignoring selectors with `FoldingScheme` + check_folding(&left, &right, &constraints[&step], &mut fq_sponge); + + // CASE 2: Check that `DecomposableFoldingScheme` works when passing the dummy zero constraint + // to each step, and an empty list of common constraints. + let pair = (step, left, right); + check_decomposable_folding( + &pair, + dummy_constraints(), + vec![], + Some(0), + &mut fq_sponge, + ); + + // CASE 3: Using a separate `DecomposableFoldingScheme` for each step, check each step's + // constraints using a dummy BTreeMap of `vec[0]` per-step constraints and + // common constraints set to each selector's constraints. + check_decomposable_folding( + &pair, + dummy_constraints(), + step_constraints[&step].clone(), + Some(0), + &mut fq_sponge, + ); + + // CASE 4: Using the same `DecomposableFoldingScheme` for all steps, initialized with a real + // BTreeMap of only the current step, and common constraints set to `vec[]`, check + // the folded circuit + check_decomposable_folding( + &pair, + step_constraints.clone(), + vec![], + None, + &mut fq_sponge, + ); + + // CASE 5: Using the same `DecomposableFoldingScheme` for all steps, initialized with a real + // BTreeMap of constraints per-step, and common constraints set to `vec[]`, check + // the folded circuit + check_decomposable_folding( + &pair, + constraints.clone(), + vec![], + Some(151), + &mut fq_sponge, + ); + } + // CASE 6: Fold mixed steps together and check the final constraints + check_decomposable_folding_mix((Sponge(Absorb(First)), Round(0)), &mut fq_sponge); + }); + } +} diff --git a/o1vm/src/legacy/trace.rs b/o1vm/src/legacy/trace.rs new file mode 100644 index 0000000000..2f5d1be497 --- /dev/null +++ b/o1vm/src/legacy/trace.rs @@ -0,0 +1,565 @@ +//!
This module defines structures and traits to build and manipulate traces. +//! A trace is a collection of data points that represent the execution of a +//! program. +//! Some traces can be seen as "decomposable" in the sense that they can be +//! divided into sub-traces that share the same columns, and sub-traces can be +//! selected using "selectors". + +use crate::{ + legacy::{ + folding::{BaseField, FoldingInstance, FoldingWitness, ScalarField}, + Curve, Pairing, + }, + lookups::Lookup, + E, +}; +use ark_ff::{One, Zero}; +use ark_poly::{Evaluations, Radix2EvaluationDomain as D}; +use folding::{expressions::FoldingCompatibleExpr, Alphas, FoldingConfig}; +use itertools::Itertools; +use kimchi::circuits::berkeley_columns::BerkeleyChallengeTerm; +use kimchi_msm::{columns::Column, witness::Witness}; +use mina_poseidon::sponge::FqSponge; +use poly_commitment::{commitment::absorb_commitment, PolyComm, SRS as _}; +use rayon::{iter::ParallelIterator, prelude::IntoParallelIterator}; +use std::{collections::BTreeMap, ops::Index}; + +/// Implement a trace for a single instruction. +// TODO: we should use the generic traits defined in [kimchi_msm]. +// For now, we want to have this to be able to test the folding library for a +// single instruction. +// It is not recommended to use this in production and it should not be +// maintained in the long term. +#[derive(Clone)] +pub struct Trace { + pub domain_size: usize, + pub witness: Witness>>, + pub constraints: Vec>>, + pub lookups: Vec>>>, +} + +/// Struct representing a circuit execution trace which is decomposable into +/// individual sub-circuits sharing the same columns. +/// It is parameterized by +/// - `N`: the total number of columns (constant), it must equal `N_REL + N_DSEL` +/// - `N_REL`: the number of relation columns (constant), +/// - `N_DSEL`: the number of dynamic selector columns (constant), +/// - `Selector`: an enum representing the different gate behaviours, +/// - `F`: the type of the witness data. +#[derive(Clone)] +pub struct DecomposedTrace { + /// The domain size of the circuit (should coincide with that of the traces) + pub domain_size: usize, + /// The traces are indexed by the selector + /// Inside the witness of the trace for a given selector, + /// - the last N_SEL columns represent the selector columns + /// and only the one for `Selector` should be all ones (the rest of selector columns should be all zeros) + pub trace: BTreeMap>, +} + +// Implementation of [Index] using `C::Selector` as the index for [DecomposedTrace] to access the trace directly. +impl Index for DecomposedTrace { + type Output = Trace; + + fn index(&self, index: C::Selector) -> &Self::Output { + &self.trace[&index] + } +} + +impl DecomposedTrace +where + usize: From<::Selector>, +{ + /// Returns the number of rows that have been instantiated for the given + /// selector. + /// It is important that the column used is a relation column because + /// selector columns are only instantiated at the very end, so their length + /// could be zero most times. + /// That is the reason that relation columns are located first. + pub fn number_of_rows(&self, opcode: C::Selector) -> usize { + self[opcode].witness.cols[0].len() + } + + /// Returns a boolean indicating whether the witness for the given selector + /// was ever found in the circuit or not.
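+ /// This is equivalent to checking that `number_of_rows` returns a nonzero + /// value for this selector.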
+ pub fn in_circuit(&self, opcode: C::Selector) -> bool { + self.number_of_rows(opcode) != 0 + } + + /// Returns whether the witness for the given selector has achieved a number + /// of rows that is equal to the domain size. + pub fn is_full(&self, opcode: C::Selector) -> bool { + self.domain_size == self.number_of_rows(opcode) + } + + /// Resets the witness after folding + pub fn reset(&mut self, opcode: C::Selector) { + (self.trace.get_mut(&opcode).unwrap().witness.cols.as_mut()) + .iter_mut() + .for_each(Vec::clear); + } + + /// Sets the selector column to all ones, and the rest to all zeros + pub fn set_selector_column( + &mut self, + selector: C::Selector, + number_of_rows: usize, + ) { + (N_REL..N).for_each(|i| { + if i == usize::from(selector) { + self.trace.get_mut(&selector).unwrap().witness.cols[i] + .extend((0..number_of_rows).map(|_| ScalarField::::one())) + } else { + self.trace.get_mut(&selector).unwrap().witness.cols[i] + .extend((0..number_of_rows).map(|_| ScalarField::::zero())) + } + }); + } +} + +/// The trait [Foldable] describes structures that can be folded. +/// For that, it requires implementing a way to return a folding instance and +/// a folding witness. +/// It is specialized for the [DecomposedTrace] struct for now and is expected +/// to fold individual instructions, selected with a specific `C::Selector`. +pub trait Foldable { + /// Returns the witness for the given selector as a folding witness and + /// folding instance pair. + /// Note that this function will also absorb all commitments to the columns + /// to coin challenges appropriately. + fn to_folding_pair( + &self, + selector: C::Selector, + fq_sponge: &mut Sponge, + domain: D>, + srs: &poly_commitment::kzg::PairingSRS, + ) -> ( + FoldingInstance, + FoldingWitness>, + ); + + /// Returns a map of constraints that are compatible with folding for each selector + fn folding_constraints(&self) -> BTreeMap>>; +} + +/// Implement the trait Foldable for the structure [DecomposedTrace] +impl, Sponge> + Foldable for DecomposedTrace +where + C::Selector: Into, + Sponge: FqSponge, C::Curve, ScalarField>, + ::Challenge: From, +{ + fn to_folding_pair( + &self, + selector: C::Selector, + fq_sponge: &mut Sponge, + domain: D>, + srs: &poly_commitment::kzg::PairingSRS, + ) -> ( + FoldingInstance, + FoldingWitness>, + ) { + let folding_witness = FoldingWitness { + witness: (&self[selector].witness) + .into_par_iter() + .map(|w| Evaluations::from_vec_and_domain(w.to_vec(), domain)) + .collect(), + }; + + let commitments: Witness> = (&folding_witness.witness) + .into_par_iter() + .map(|w| srs.commit_evaluations_non_hiding(domain, w)) + .collect(); + + // Absorbing commitments + (&commitments) + .into_iter() + .for_each(|c| absorb_commitment(fq_sponge, c)); + + let commitments: [C::Curve; N] = commitments + .into_iter() + .map(|c| c.get_first_chunk()) + .collect_vec() + .try_into() + .unwrap(); + + let beta = fq_sponge.challenge(); + let gamma = fq_sponge.challenge(); + let joint_combiner = fq_sponge.challenge(); + let alpha = fq_sponge.challenge(); + let challenges = [beta, gamma, joint_combiner]; + let alphas = Alphas::new(alpha); + let blinder = ScalarField::::one(); + let instance = FoldingInstance { + commitments, + challenges, + alphas, + blinder, + }; + + (instance, folding_witness) + } + + fn folding_constraints(&self) -> BTreeMap>> { + self.trace + .iter() + .map(|(k, instr)| { + ( + *k, + instr + .constraints + .iter() + .map(|x| FoldingCompatibleExpr::from(x.clone())) + .collect(), + ) + })
.collect() + } +} + +/// Tracer builds traces for some program executions. +/// The constant type `N_REL` is defined as the maximum number of relation +/// columns the trace can use per row. +/// The type `C` encodes the folding configuration, from which the selector, +/// which encodes the kind of information the trace records, and the scalar +/// field are derived. Examples of selectors are: +/// - For Keccak, `Step` encodes the row being performed at a given time: round, +/// squeeze, padding, etc... +/// - For MIPS, `Instruction` encodes the CPU instruction being executed: add, +/// sub, load, store, etc... +pub trait Tracer { + type Selector; + + /// Initialize a new trace with the given domain size, selector, and environment. + fn init(domain_size: usize, selector: C::Selector, env: &mut Env) -> Self; + + /// Add a witness row to the circuit (only for relation columns) + fn push_row(&mut self, selector: Self::Selector, row: &[ScalarField; N_REL]); + + /// Pad the rows of one opcode with the given row until + /// reaching the domain size if needed. + /// Returns the number of rows that were added. + /// It does not add selector columns. + fn pad_with_row(&mut self, selector: Self::Selector, row: &[ScalarField; N_REL]) -> usize; + + /// Pads the rows of one opcode with zero rows until + /// reaching the domain size if needed. + /// Returns the number of rows that were added. + /// It does not add selector columns. + fn pad_with_zeros(&mut self, selector: Self::Selector) -> usize; + + /// Pad the rows of one opcode with the first row until + /// reaching the domain size if needed. + /// It only tries to pad witnesses which are non-empty. + /// Returns the number of rows that were added. + /// It does not add selector columns. + /// - Use `None` for single traces + /// - Use `Some(selector)` for multi traces + fn pad_dummy(&mut self, selector: Self::Selector) -> usize; +} + +/// DecomposableTracer builds traces for some program executions. +/// The constant type `N_REL` is defined as the maximum number of relation +/// columns the trace can use per row. +/// The type `C` encodes the folding configuration, from which the selector +/// and scalar field are derived. Examples of selectors are: +/// - For Keccak, `Step` encodes the row being performed at a given time: round, +/// squeeze, padding, etc... +/// - For MIPS, `Instruction` encodes the CPU instruction being executed: add, +/// sub, load, store, etc... +pub trait DecomposableTracer { + /// Create a new decomposable trace with the given domain size and environment. + fn new(domain_size: usize, env: &mut Env) -> Self; + + /// Pads the rows of the witnesses until reaching the domain size using the first + /// row repeatedly. It does not add selector columns. + fn pad_witnesses(&mut self); +} + +/// Generic implementation of the [Tracer] trait for the [DecomposedTrace] struct. +/// It requires the [DecomposedTrace] to implement the [DecomposableTracer] trait, +/// and the [Trace] struct to implement the [Tracer] trait with Selector set to (), +/// and `usize` to implement the [From] trait with `C::Selector`.
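+/// +/// A rough sketch of the intended lifecycle, as driven by the legacy main +/// binary (a sketch for orientation, not a compiled example): +/// +/// ```text +/// let mut trace = DecomposedMIPSTrace::new(domain_size, &mut env); +/// trace.push_row(instr, &row); // fill relation columns row by row +/// trace.pad_dummy(instr); // pad a non-empty sub-trace to the domain size +/// trace.set_selector_column(instr, domain_size); // then set the selectors +/// // ... and fold or prove with `trace[instr].witness` +/// ```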
+impl Tracer + for DecomposedTrace +where + DecomposedTrace: DecomposableTracer, + Trace: Tracer, + usize: From<::Selector>, +{ + type Selector = C::Selector; + + fn init(domain_size: usize, _selector: C::Selector, env: &mut Env) -> Self { + >::new(domain_size, env) + } + + fn push_row(&mut self, selector: Self::Selector, row: &[ScalarField; N_REL]) { + self.trace.get_mut(&selector).unwrap().push_row((), row); + } + + fn pad_with_row(&mut self, selector: Self::Selector, row: &[ScalarField; N_REL]) -> usize { + // We only want to pad non-empty witnesses. + if !self.in_circuit(selector) { + 0 + } else { + self.trace.get_mut(&selector).unwrap().pad_with_row((), row) + } + } + + fn pad_with_zeros(&mut self, selector: Self::Selector) -> usize { + // We only want to pad non-empty witnesses. + if !self.in_circuit(selector) { + 0 + } else { + self.trace.get_mut(&selector).unwrap().pad_with_zeros(()) + } + } + + fn pad_dummy(&mut self, selector: Self::Selector) -> usize { + // We only want to pad non-empty witnesses. + if !self.in_circuit(selector) { + 0 + } else { + self.trace.get_mut(&selector).unwrap().pad_dummy(()) + } + } +} + +pub mod keccak { + use std::{array, collections::BTreeMap}; + + use ark_ff::Zero; + use kimchi_msm::witness::Witness; + use strum::IntoEnumIterator; + + use crate::{ + interpreters::keccak::{ + column::{Steps, N_ZKVM_KECCAK_COLS, N_ZKVM_KECCAK_REL_COLS}, + environment::KeccakEnv, + standardize, + }, + legacy::{ + folding::{keccak::KeccakConfig, ScalarField}, + trace::{DecomposableTracer, DecomposedTrace, Trace, Tracer}, + }, + }; + + /// A Keccak instruction trace + pub type KeccakTrace = Trace; + /// The Keccak circuit trace + pub type DecomposedKeccakTrace = DecomposedTrace; + + impl DecomposableTracer>> for DecomposedKeccakTrace { + fn new(domain_size: usize, env: &mut KeccakEnv>) -> Self { + let mut circuit = Self { + domain_size, + trace: BTreeMap::new(), + }; + for step in Steps::iter().flat_map(|step| step.into_iter()) { + circuit + .trace + .insert(step, KeccakTrace::init(domain_size, step, env)); + } + circuit + } + + fn pad_witnesses(&mut self) { + for opcode in Steps::iter().flat_map(|opcode| opcode.into_iter()) { + if self.in_circuit(opcode) { + self.trace.get_mut(&opcode).unwrap().pad_dummy(()); + } + } + } + } + + impl Tracer>> + for KeccakTrace + { + type Selector = (); + + fn init( + domain_size: usize, + selector: Steps, + _env: &mut KeccakEnv>, + ) -> Self { + // Make sure we are using the same round number to refer to round steps + let step = standardize(selector); + Self { + domain_size, + witness: Witness { + cols: Box::new(std::array::from_fn(|_| Vec::with_capacity(domain_size))), + }, + constraints: KeccakEnv::constraints_of(step), + lookups: KeccakEnv::lookups_of(step), + } + } + + fn push_row( + &mut self, + _selector: Self::Selector, + row: &[ScalarField; N_ZKVM_KECCAK_REL_COLS], + ) { + for (i, value) in row.iter().enumerate() { + if self.witness.cols[i].len() < self.witness.cols[i].capacity() { + self.witness.cols[i].push(*value); + } + } + } + + fn pad_with_row( + &mut self, + _selector: Self::Selector, + row: &[ScalarField; N_ZKVM_KECCAK_REL_COLS], + ) -> usize { + let len = self.witness.cols[0].len(); + assert!(len <= self.domain_size); + let rows_to_add = self.domain_size - len; + // When we reach the domain size, we don't need to pad anymore. 
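+ // (if the witness is already full, `rows_to_add` is 0 and the loop below + // is a no-op)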
+ for _ in 0..rows_to_add { + self.push_row((), row); + } + rows_to_add + } + + fn pad_with_zeros(&mut self, _selector: Self::Selector) -> usize { + let len = self.witness.cols[0].len(); + assert!(len <= self.domain_size); + let rows_to_add = self.domain_size - len; + // When we reach the domain size, we don't need to pad anymore. + for col in self.witness.cols.iter_mut() { + col.extend((0..rows_to_add).map(|_| ScalarField::::zero())); + } + rows_to_add + } + + fn pad_dummy(&mut self, _selector: Self::Selector) -> usize { + // We keep track of the first row of the non-empty witness, which is a real step witness. + let row = array::from_fn(|i| self.witness.cols[i][0]); + self.pad_with_row(_selector, &row) + } + } +} + +pub mod mips { + use crate::{ + interpreters::mips::{ + column::{N_MIPS_COLS, N_MIPS_REL_COLS}, + constraints::Env, + interpreter::{interpret_instruction, Instruction, InterpreterEnv}, + }, + legacy::{ + folding::{mips::DecomposableMIPSFoldingConfig, ScalarField}, + trace::{DecomposableTracer, DecomposedTrace, Trace, Tracer}, + }, + }; + use ark_ff::Zero; + use kimchi_msm::witness::Witness; + use std::{array, collections::BTreeMap}; + use strum::IntoEnumIterator; + + /// The MIPS instruction trace + pub type MIPSTrace = Trace; + /// The MIPS circuit trace + pub type DecomposedMIPSTrace = DecomposedTrace; + + impl DecomposableTracer>> for DecomposedMIPSTrace { + fn new( + domain_size: usize, + env: &mut Env>, + ) -> Self { + let mut circuit = Self { + domain_size, + trace: BTreeMap::new(), + }; + for instr in Instruction::iter().flat_map(|step| step.into_iter()) { + circuit + .trace + .insert(instr, ::init(domain_size, instr, env)); + } + circuit + } + + fn pad_witnesses(&mut self) { + for opcode in Instruction::iter().flat_map(|opcode| opcode.into_iter()) { + self.trace.get_mut(&opcode).unwrap().pad_dummy(()); + } + } + } + + impl + Tracer< + N_MIPS_REL_COLS, + DecomposableMIPSFoldingConfig, + Env>, + > for MIPSTrace + { + type Selector = (); + + fn init( + domain_size: usize, + instr: Instruction, + env: &mut Env>, + ) -> Self { + interpret_instruction(env, instr); + + let trace = Self { + domain_size, + witness: Witness { + cols: Box::new(std::array::from_fn(|_| Vec::with_capacity(domain_size))), + }, + constraints: env.get_constraints(), + lookups: env.get_lookups(), + }; + // Clear for the next instruction + env.reset(); + trace + } + + fn push_row( + &mut self, + _selector: Self::Selector, + row: &[ScalarField; N_MIPS_REL_COLS], + ) { + for (i, value) in row.iter().enumerate() { + if self.witness.cols[i].len() < self.witness.cols[i].capacity() { + self.witness.cols[i].push(*value); + } + } + } + + fn pad_with_row( + &mut self, + _selector: Self::Selector, + row: &[ScalarField; N_MIPS_REL_COLS], + ) -> usize { + let len = self.witness.cols[0].len(); + assert!(len <= self.domain_size); + let rows_to_add = self.domain_size - len; + // When we reach the domain size, we don't need to pad anymore. + for _ in 0..rows_to_add { + self.push_row(_selector, row); + } + rows_to_add + } + + fn pad_with_zeros(&mut self, _selector: Self::Selector) -> usize { + let len = self.witness.cols[0].len(); + assert!(len <= self.domain_size); + let rows_to_add = self.domain_size - len; + // When we reach the domain size, we don't need to pad anymore. 
+ for col in self.witness.cols.iter_mut() { + col.extend( + (0..rows_to_add).map(|_| ScalarField::::zero()), + ); + } + rows_to_add + } + + fn pad_dummy(&mut self, _selector: Self::Selector) -> usize { + // We keep track of the first row of the non-empty witness, which is a real step witness. + let row = array::from_fn(|i| self.witness.cols[i][0]); + self.pad_with_row(_selector, &row) + } + } +} diff --git a/o1vm/src/lib.rs b/o1vm/src/lib.rs new file mode 100644 index 0000000000..0ee9c5edd4 --- /dev/null +++ b/o1vm/src/lib.rs @@ -0,0 +1,50 @@ +/// Modules mimicking the defined structures used by Cannon CLI. +pub mod cannon; + +/// A CLI mimicking the Cannon CLI. +pub mod cannon_cli; + +/// A module to load ELF files. +pub mod elf_loader; + +pub mod interpreters; + +/// Legacy implementation of the recursive proof composition. +/// It does use the folding and ivc libraries defined in this monorepo, and aims +/// to be compatible with Ethereum natively, using the curve bn254. +pub mod legacy; + +/// Pickles flavor of the o1vm. +pub mod pickles; + +/// Instantiation of the lookups for the VM project. +pub mod lookups; + +/// Preimage oracle interface used by the zkVM. +pub mod preimage_oracle; + +/// The RAM lookup argument. +pub mod ramlookup; + +pub mod utils; + +use kimchi::circuits::{ + berkeley_columns::BerkeleyChallengeTerm, + expr::{ConstantExpr, Expr}, +}; +use kimchi_msm::columns::Column; +pub use ramlookup::{LookupMode as RAMLookupMode, RAMLookup}; + +/// Type to represent a constraint on the individual columns of the execution +/// trace. +/// As a reminder, a constraint can be formally defined as a multi-variate +/// polynomial over a finite field. The variables of the polynomial are defined +/// as `kimchi_msm::columns::Column`. +/// The `expression` framework defined in `kimchi::circuits::expr` is used to +/// describe the multi-variate polynomials. +/// For instance, a vanilla 3-wires PlonK constraint can be defined using the +/// multi-variate polynomial of degree 2 +/// `P(X, Y, Z) = q_x X + q_y Y + q_m X Y + q_o Z + q_c` +/// To represent this multi-variate polynomial using the expression framework, +/// we would use 3 different columns. +pub(crate) type E = Expr, Column>; diff --git a/o1vm/src/lookups.rs b/o1vm/src/lookups.rs new file mode 100644 index 0000000000..eff69771f3 --- /dev/null +++ b/o1vm/src/lookups.rs @@ -0,0 +1,278 @@ +//! Instantiation of the lookups for the VM project. + +use self::LookupTableIDs::*; +use crate::{interpreters::keccak::pad_blocks, ramlookup::RAMLookup}; +use ark_ff::{Field, PrimeField}; +use kimchi::{ + circuits::polynomials::keccak::{ + constants::{RATE_IN_BYTES, ROUNDS}, + Keccak, RC, + }, + o1_utils::{FieldHelpers, Two}, +}; +use kimchi_msm::{LogupTable, LogupWitness, LookupTableID}; + +/// The lookups struct based on RAMLookups for the VM table IDs +pub(crate) type Lookup = RAMLookup; + +#[allow(dead_code)] +/// Represents a witness of one instance of the lookup argument of the zkVM project +pub(crate) type LookupWitness = LogupWitness; + +/// The lookup table struct based on LogupTable for the VM table IDs +pub(crate) type LookupTable = LogupTable; + +/// All of the possible lookup table IDs used in the zkVM +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] +pub enum LookupTableIDs { + // PadLookup ID is 0 because this is the only fixed table whose first entry is not 0. + // This way, it is guaranteed that the 0 value is not always in the tables after the + // randomization with the joint combiner is applied. 
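+ // As a consequence, lookups into PadLookup are off by one: the row for + // padding length `len` sits at index `len - 1` (see `is_in_table` below).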
+    /// All [1..136] values of possible padding lengths, the value 2^len, and the 5 corresponding pad suffixes with the 10*1 rule
+    PadLookup = 0,
+    /// 24-row table with all possible values for round and their round constant in expanded form (in big endian) [0..=23]
+    RoundConstantsLookup = 1,
+    /// Values from 0 to 4 to check the number of bytes read from syscalls
+    AtMost4Lookup = 2,
+    /// All values that can be stored in a byte (amortized table, better than modeling it as RangeCheck16 (x and scaled x))
+    ByteLookup = 3,
+    // Read tables come first to allow indexing with the table ID for the multiplicities
+    /// Single-column table of all values in the range [0, 2^16)
+    RangeCheck16Lookup = 4,
+    /// Single-column table of 2^16 entries with the sparse representation of all values
+    SparseLookup = 5,
+    /// Dual-column table of all values in the range [0, 2^16) and their sparse representation
+    ResetLookup = 6,
+
+    // RAM Tables
+    MemoryLookup = 7,
+    RegisterLookup = 8,
+    /// Syscalls communication channel
+    SyscallLookup = 9,
+    /// Input/Output of Keccak steps
+    KeccakStepLookup = 10,
+}
+
+impl LookupTableID for LookupTableIDs {
+    fn to_u32(&self) -> u32 {
+        *self as u32
+    }
+
+    fn from_u32(value: u32) -> Self {
+        match value {
+            0 => PadLookup,
+            1 => RoundConstantsLookup,
+            2 => AtMost4Lookup,
+            3 => ByteLookup,
+            4 => RangeCheck16Lookup,
+            5 => SparseLookup,
+            6 => ResetLookup,
+            7 => MemoryLookup,
+            8 => RegisterLookup,
+            9 => SyscallLookup,
+            10 => KeccakStepLookup,
+            _ => panic!("Invalid table ID"),
+        }
+    }
+
+    fn length(&self) -> usize {
+        match self {
+            PadLookup => RATE_IN_BYTES,
+            RoundConstantsLookup => ROUNDS,
+            AtMost4Lookup => 5,
+            ByteLookup => 1 << 8,
+            RangeCheck16Lookup | SparseLookup | ResetLookup => 1 << 16,
+            MemoryLookup | RegisterLookup | SyscallLookup | KeccakStepLookup => {
+                panic!("RAM Tables do not have a fixed length")
+            }
+        }
+    }
+
+    fn is_fixed(&self) -> bool {
+        match self {
+            PadLookup | RoundConstantsLookup | AtMost4Lookup | ByteLookup | RangeCheck16Lookup
+            | SparseLookup | ResetLookup => true,
+            MemoryLookup | RegisterLookup | SyscallLookup | KeccakStepLookup => false,
+        }
+    }
+
+    fn runtime_create_column(&self) -> bool {
+        panic!("No runtime tables specified");
+    }
+
+    fn ix_by_value<F: PrimeField>(&self, _value: &[F]) -> Option<usize> {
+        todo!()
+    }
+
+    fn all_variants() -> Vec<Self> {
+        vec![
+            Self::PadLookup,
+            Self::RoundConstantsLookup,
+            Self::AtMost4Lookup,
+            Self::ByteLookup,
+            Self::RangeCheck16Lookup,
+            Self::SparseLookup,
+            Self::ResetLookup,
+            Self::MemoryLookup,
+            Self::RegisterLookup,
+            Self::SyscallLookup,
+            Self::KeccakStepLookup,
+        ]
+    }
+}
+
+/// Trait that creates all the fixed lookup tables used in the VM
+pub(crate) trait FixedLookupTables<F> {
+    /// Checks whether a value is in a table and returns its position if so, or None otherwise.
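+    // Example of the sparse representation used by SparseLookup (illustrative):
+    // the dense value 0b101 (= 5) is stored in sparse form as 0x101, i.e. each
+    // bit expands to one hex digit. Decoding inverts this:
+    // u64::from_str_radix(&format!("{:x}", 0x101), 2) == Ok(5), which is what
+    // `is_in_table` does below, while `table_sparse` encodes the other way
+    // with from_str_radix(&format!("{:b}", i), 16).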
+    fn is_in_table(table: &LookupTable<F>, value: Vec<F>) -> Option<usize>;
+    /// Returns the pad table
+    fn table_pad() -> LookupTable<F>;
+    /// Returns the round constants table
+    fn table_round_constants() -> LookupTable<F>;
+    /// Returns the at most 4 table
+    fn table_at_most_4() -> LookupTable<F>;
+    /// Returns the byte table
+    fn table_byte() -> LookupTable<F>;
+    /// Returns the range check 16 table
+    fn table_range_check_16() -> LookupTable<F>;
+    /// Returns the sparse table
+    fn table_sparse() -> LookupTable<F>;
+    /// Returns the reset table
+    fn table_reset() -> LookupTable<F>;
+}
+
+impl<F: PrimeField> FixedLookupTables<F> for LookupTable<F> {
+    fn is_in_table(table: &LookupTable<F>, value: Vec<F>) -> Option<usize> {
+        let id = table.table_id;
+        // In these tables, the first value of the vector is related to the index within the table.
+        let idx = value[0]
+            .to_bytes()
+            .iter()
+            .rev()
+            .fold(0u64, |acc, &x| acc * 256 + x as u64) as usize;
+
+        match id {
+            RoundConstantsLookup | AtMost4Lookup | ByteLookup | RangeCheck16Lookup
+            | ResetLookup => {
+                if idx < id.length() && table.entries[idx] == value {
+                    Some(idx)
+                } else {
+                    None
+                }
+            }
+            PadLookup => {
+                // Because this table starts with entry 1. The `idx > 0` guard
+                // also avoids an underflow when the value is 0.
+                if idx > 0 && idx - 1 < id.length() && table.entries[idx - 1] == value {
+                    Some(idx - 1)
+                } else {
+                    None
+                }
+            }
+            SparseLookup => {
+                let res = u64::from_str_radix(&format!("{:x}", idx), 2);
+                let dense = if let Ok(ok) = res {
+                    ok as usize
+                } else {
+                    id.length() // So that it returns None
+                };
+                if dense < id.length() && table.entries[dense] == value {
+                    Some(dense)
+                } else {
+                    None
+                }
+            }
+            MemoryLookup | RegisterLookup | SyscallLookup | KeccakStepLookup => None,
+        }
+    }
+
+    fn table_pad() -> Self {
+        Self {
+            table_id: PadLookup,
+            entries: (1..=PadLookup.length())
+                .map(|i| {
+                    let suffix = pad_blocks(i);
+                    vec![
+                        F::from(i as u64),
+                        F::two_pow(i as u64),
+                        suffix[0],
+                        suffix[1],
+                        suffix[2],
+                        suffix[3],
+                        suffix[4],
+                    ]
+                })
+                .collect(),
+        }
+    }
+
+    fn table_round_constants() -> Self {
+        Self {
+            table_id: RoundConstantsLookup,
+            entries: (0..RoundConstantsLookup.length())
+                .map(|i| {
+                    vec![
+                        F::from(i as u32),
+                        F::from(Keccak::sparse(RC[i])[3]),
+                        F::from(Keccak::sparse(RC[i])[2]),
+                        F::from(Keccak::sparse(RC[i])[1]),
+                        F::from(Keccak::sparse(RC[i])[0]),
+                    ]
+                })
+                .collect(),
+        }
+    }
+
+    fn table_at_most_4() -> LookupTable<F> {
+        Self {
+            table_id: AtMost4Lookup,
+            entries: (0..AtMost4Lookup.length())
+                .map(|i| vec![F::from(i as u32)])
+                .collect(),
+        }
+    }
+
+    fn table_byte() -> Self {
+        Self {
+            table_id: ByteLookup,
+            entries: (0..ByteLookup.length())
+                .map(|i| vec![F::from(i as u32)])
+                .collect(),
+        }
+    }
+
+    fn table_range_check_16() -> Self {
+        Self {
+            table_id: RangeCheck16Lookup,
+            entries: (0..RangeCheck16Lookup.length())
+                .map(|i| vec![F::from(i as u32)])
+                .collect(),
+        }
+    }
+
+    fn table_sparse() -> Self {
+        Self {
+            table_id: SparseLookup,
+            entries: (0..SparseLookup.length())
+                .map(|i| {
+                    vec![F::from(
+                        u64::from_str_radix(&format!("{:b}", i), 16).unwrap(),
+                    )]
+                })
+                .collect(),
+        }
+    }
+
+    fn table_reset() -> Self {
+        Self {
+            table_id: ResetLookup,
+            entries: (0..ResetLookup.length())
+                .map(|i| {
+                    vec![
+                        F::from(i as u32),
+                        F::from(u64::from_str_radix(&format!("{:b}", i), 16).unwrap()),
+                    ]
+                })
+                .collect(),
+        }
+    }
+}
diff --git a/o1vm/src/pickles/column_env.rs b/o1vm/src/pickles/column_env.rs
new file mode 100644
index 0000000000..61f72d05e1
--- /dev/null
+++ b/o1vm/src/pickles/column_env.rs
@@ -0,0 +1,131 @@
+use ark_ff::FftField;
+use ark_poly::{Evaluations, Radix2EvaluationDomain};
+use
kimchi_msm::columns::Column; + +use crate::{ + interpreters::mips::column::{N_MIPS_SEL_COLS, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE}, + pickles::proof::WitnessColumns, +}; +use kimchi::circuits::{ + berkeley_columns::{BerkeleyChallengeTerm, BerkeleyChallenges}, + domains::{Domain, EvaluationDomains}, + expr::{ColumnEnvironment as TColumnEnvironment, Constants}, +}; + +type Evals = Evaluations>; + +/// The collection of polynomials (all in evaluation form) and constants +/// required to evaluate an expression as a polynomial. +/// +/// All are evaluations. +pub struct ColumnEnvironment<'a, F: FftField> { + /// The witness column polynomials. Includes relation columns and dynamic + /// selector columns. + pub witness: &'a WitnessColumns, [Evals; N_MIPS_SEL_COLS]>, + /// The value `prod_{j != 1} (1 - ω^j)`, used for efficiently + /// computing the evaluations of the unnormalized Lagrange basis + /// polynomials. + pub l0_1: F, + /// Constant values required + pub constants: Constants, + /// Challenges from the IOP. + // FIXME: change for other challenges + pub challenges: BerkeleyChallenges, + /// The domains used in the PLONK argument. + pub domain: EvaluationDomains, +} + +pub fn get_all_columns() -> Vec { + let mut cols = + Vec::::with_capacity(SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2 + N_MIPS_SEL_COLS); + for i in 0..SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2 { + cols.push(Column::Relation(i)); + } + for i in 0..N_MIPS_SEL_COLS { + cols.push(Column::DynamicSelector(i)); + } + cols +} + +impl WitnessColumns { + pub fn get_column(&self, col: &Column) -> Option<&G> { + match *col { + Column::Relation(i) => { + if i < SCRATCH_SIZE { + let res = &self.scratch[i]; + Some(res) + } else if i < SCRATCH_SIZE + SCRATCH_SIZE_INVERSE { + let res = &self.scratch_inverse[i - SCRATCH_SIZE]; + Some(res) + } else if i == SCRATCH_SIZE + SCRATCH_SIZE_INVERSE { + let res = &self.instruction_counter; + Some(res) + } else if i == SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 { + let res = &self.error; + Some(res) + } else { + panic!("We should not have that many relation columns. We have {} columns and index {} was given", SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2, i); + } + } + Column::DynamicSelector(i) => { + assert!( + i < N_MIPS_SEL_COLS, + "We do not have that many dynamic selector columns. We have {} columns and index {} was given", + N_MIPS_SEL_COLS, + i + ); + let res = &self.selector[i]; + Some(res) + } + _ => { + panic!( + "We should not have any other type of columns. The column {:?} was given", + col + ); + } + } + } +} + +impl<'a, F: FftField> TColumnEnvironment<'a, F, BerkeleyChallengeTerm, BerkeleyChallenges> + for ColumnEnvironment<'a, F> +{ + // FIXME: do we change to the MIPS column type? + // We do not want to keep kimchi_msm/generic prover + type Column = Column; + + fn get_column(&self, col: &Self::Column) -> Option<&'a Evals> { + self.witness.get_column(col) + } + + fn get_domain(&self, d: Domain) -> Radix2EvaluationDomain { + match d { + Domain::D1 => self.domain.d1, + Domain::D2 => self.domain.d2, + Domain::D4 => self.domain.d4, + Domain::D8 => self.domain.d8, + } + } + + fn column_domain(&self, _col: &Self::Column) -> Domain { + Domain::D8 + } + + fn get_constants(&self) -> &Constants { + &self.constants + } + + fn get_challenges(&self) -> &BerkeleyChallenges { + &self.challenges + } + + fn vanishes_on_zero_knowledge_and_previous_rows( + &self, + ) -> &'a Evaluations> { + panic!("Not supposed to be used in MIPS. 
We do not support zero-knowledge for now") + } + + fn l0_1(&self) -> F { + self.l0_1 + } +} diff --git a/o1vm/src/pickles/main.rs b/o1vm/src/pickles/main.rs new file mode 100644 index 0000000000..5083d7d58d --- /dev/null +++ b/o1vm/src/pickles/main.rs @@ -0,0 +1,163 @@ +use ark_ff::UniformRand; +use kimchi::circuits::domains::EvaluationDomains; +use kimchi_msm::expr::E; +use log::debug; +use mina_curves::pasta::VestaParameters; +use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi, + sponge::{DefaultFqSponge, DefaultFrSponge}, +}; +use o1vm::{ + cannon::{self, Meta, Start, State}, + cannon_cli, + interpreters::mips::{ + column::N_MIPS_REL_COLS, + constraints as mips_constraints, + interpreter::{self, InterpreterEnv}, + witness::{self as mips_witness}, + Instruction, + }, + pickles::{proof::ProofInputs, prover, verifier}, + preimage_oracle::PreImageOracle, +}; +use poly_commitment::{ipa::SRS, SRS as _}; +use std::{fs::File, io::BufReader, process::ExitCode, time::Instant}; +use strum::IntoEnumIterator; + +use mina_curves::pasta::{Fp, Vesta}; + +pub const DOMAIN_SIZE: usize = 1 << 15; + +pub fn main() -> ExitCode { + let cli = cannon_cli::main_cli(); + + let mut rng = rand::thread_rng(); + + let configuration = cannon_cli::read_configuration(&cli.get_matches()); + + let file = + File::open(&configuration.input_state_file).expect("Error opening input state file "); + + let reader = BufReader::new(file); + // Read the JSON contents of the file as an instance of `State`. + let state: State = serde_json::from_reader(reader).expect("Error reading input state file"); + + let meta_file = File::open(&configuration.metadata_file).unwrap_or_else(|_| { + panic!( + "Could not open metadata file {}", + &configuration.metadata_file + ) + }); + + let meta: Meta = serde_json::from_reader(BufReader::new(meta_file)).unwrap_or_else(|_| { + panic!( + "Error deserializing metadata file {}", + &configuration.metadata_file + ) + }); + + let mut po = PreImageOracle::create(&configuration.host); + let _child = po.start(); + + // Initialize some data used for statistical computations + let start = Start::create(state.step as usize); + + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + + let domain_fp = EvaluationDomains::::create(DOMAIN_SIZE).unwrap(); + let srs: SRS = { + let srs = SRS::create(DOMAIN_SIZE); + srs.get_lagrange_basis(domain_fp.d1); + srs + }; + + // Initialize the environments + let mut mips_wit_env = + mips_witness::Env::::create(cannon::PAGE_SIZE as usize, state, po); + + let constraints = { + let mut mips_con_env = mips_constraints::Env::::default(); + let mut constraints = Instruction::iter() + .flat_map(|instr_typ| instr_typ.into_iter()) + .fold(vec![], |mut acc, instr| { + interpreter::interpret_instruction(&mut mips_con_env, instr); + let selector = mips_con_env.get_selector(); + let constraints_with_selector: Vec> = mips_con_env + .get_constraints() + .into_iter() + .map(|c| selector.clone() * c) + .collect(); + acc.extend(constraints_with_selector); + mips_con_env.reset(); + acc + }); + constraints.extend(mips_con_env.get_selector_constraints()); + constraints + }; + + let mut curr_proof_inputs: ProofInputs = ProofInputs::new(DOMAIN_SIZE); + while !mips_wit_env.halt { + let _instr: Instruction = mips_wit_env.step(&configuration, &meta, &start); + for (scratch, scratch_chunk) in mips_wit_env + .scratch_state + .iter() + .zip(curr_proof_inputs.evaluations.scratch.iter_mut()) + { + scratch_chunk.push(*scratch); + } + for (scratch, 
scratch_chunk) in mips_wit_env + .scratch_state_inverse + .iter() + .zip(curr_proof_inputs.evaluations.scratch_inverse.iter_mut()) + { + scratch_chunk.push(*scratch); + } + curr_proof_inputs + .evaluations + .instruction_counter + .push(Fp::from(mips_wit_env.instruction_counter)); + // FIXME: Might be another value + curr_proof_inputs.evaluations.error.push(Fp::rand(&mut rng)); + + curr_proof_inputs + .evaluations + .selector + .push(Fp::from((mips_wit_env.selector - N_MIPS_REL_COLS) as u64)); + + if curr_proof_inputs.evaluations.instruction_counter.len() == DOMAIN_SIZE { + // FIXME + let start_iteration = Instant::now(); + debug!("Limit of {DOMAIN_SIZE} reached. We make a proof, verify it (for testing) and start with a new chunk"); + let proof = prover::prove::< + Vesta, + DefaultFqSponge, + DefaultFrSponge, + _, + >(domain_fp, &srs, curr_proof_inputs, &constraints, &mut rng) + .unwrap(); + // FIXME: check that the proof is correct. This is for testing purposes. + // Leaving like this for now. + debug!( + "Proof generated in {elapsed} μs", + elapsed = start_iteration.elapsed().as_micros() + ); + { + let start_iteration = Instant::now(); + let verif = verifier::verify::< + Vesta, + DefaultFqSponge, + DefaultFrSponge, + >(domain_fp, &srs, &constraints, &proof); + debug!( + "Verification done in {elapsed} μs", + elapsed = start_iteration.elapsed().as_micros() + ); + assert!(verif); + } + + curr_proof_inputs = ProofInputs::new(DOMAIN_SIZE); + } + } + // TODO: Logic + ExitCode::SUCCESS +} diff --git a/o1vm/src/pickles/mod.rs b/o1vm/src/pickles/mod.rs new file mode 100644 index 0000000000..193162e90b --- /dev/null +++ b/o1vm/src/pickles/mod.rs @@ -0,0 +1,37 @@ +//! This is the pickles flavor of the o1vm. +//! The goal of this flavor is to run a version of the o1vm with selectors for +//! each instruction using the Pasta curves and the IPA PCS. +//! +//! A proof is generated for each set of N continuous instructions, where N is +//! the size of the supported SRS. The proofs will then be aggregated using +//! a modified version of pickles. +//! +//! You can run this flavor by using: +//! +//! ```bash +//! O1VM_FLAVOR=pickles bash run-code.sh +//! ``` + +pub mod column_env; +pub mod proof; +pub mod prover; +pub mod verifier; + +/// Maximum degree of the constraints. +/// It does include the additional degree induced by the multiplication of the +/// selectors. +pub const MAXIMUM_DEGREE_CONSTRAINTS: u64 = 6; + +/// Degree of the quotient polynomial. We do evaluate all polynomials on d8 +/// (because of the value of [MAXIMUM_DEGREE_CONSTRAINTS]), and therefore, we do +/// have a degree 7 for the quotient polynomial. +/// Used to keep track of the number of chunks we do have when we commit to the +/// quotient polynomial. +pub const DEGREE_QUOTIENT_POLYNOMIAL: u64 = 7; + +/// Total number of constraints for all instructions, including the constraints +/// added for the selectors. 
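+// Note (illustrative): with the 7 quotient chunks t_0..t_6, each of degree
+// < n (n = domain size), the verifier recombines the evaluations as
+//   t(ζ) = t_0(ζ) + ζ^n · t_1(ζ) + ... + ζ^{6n} · t_6(ζ)
+// and checks t(ζ) · (ζ^n - 1) against the combined constraint evaluation;
+// see the fold over `quotient_evaluations.zeta` in the verifier.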
+pub const TOTAL_NUMBER_OF_CONSTRAINTS: usize = 464;
+
+#[cfg(test)]
+mod tests;
diff --git a/o1vm/src/pickles/proof.rs b/o1vm/src/pickles/proof.rs
new file mode 100644
index 0000000000..40529a3409
--- /dev/null
+++ b/o1vm/src/pickles/proof.rs
@@ -0,0 +1,41 @@
+use kimchi::{curve::KimchiCurve, proof::PointEvaluations};
+use poly_commitment::{ipa::OpeningProof, PolyComm};
+
+use crate::interpreters::mips::column::{N_MIPS_SEL_COLS, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE};
+
+pub struct WitnessColumns<G, S> {
+    pub scratch: [G; SCRATCH_SIZE],
+    pub scratch_inverse: [G; SCRATCH_SIZE_INVERSE],
+    pub instruction_counter: G,
+    pub error: G,
+    pub selector: S,
+}
+
+pub struct ProofInputs<G: KimchiCurve> {
+    pub evaluations: WitnessColumns<Vec<G::ScalarField>, Vec<G::ScalarField>>,
+}
+
+impl<G: KimchiCurve> ProofInputs<G> {
+    pub fn new(domain_size: usize) -> Self {
+        ProofInputs {
+            evaluations: WitnessColumns {
+                scratch: std::array::from_fn(|_| Vec::with_capacity(domain_size)),
+                scratch_inverse: std::array::from_fn(|_| Vec::with_capacity(domain_size)),
+                instruction_counter: Vec::with_capacity(domain_size),
+                error: Vec::with_capacity(domain_size),
+                selector: Vec::with_capacity(domain_size),
+            },
+        }
+    }
+}
+
+// FIXME: should we blind the commitment?
+pub struct Proof<G: KimchiCurve> {
+    pub commitments: WitnessColumns<PolyComm<G>, [PolyComm<G>; N_MIPS_SEL_COLS]>,
+    pub zeta_evaluations: WitnessColumns<G::ScalarField, [G::ScalarField; N_MIPS_SEL_COLS]>,
+    pub zeta_omega_evaluations: WitnessColumns<G::ScalarField, [G::ScalarField; N_MIPS_SEL_COLS]>,
+    pub quotient_commitment: PolyComm<G>,
+    pub quotient_evaluations: PointEvaluations<Vec<G::ScalarField>>,
+    /// IPA opening proof
+    pub opening_proof: OpeningProof<G>,
+}
diff --git a/o1vm/src/pickles/prover.rs b/o1vm/src/pickles/prover.rs
new file mode 100644
index 0000000000..0863227ef5
--- /dev/null
+++ b/o1vm/src/pickles/prover.rs
@@ -0,0 +1,456 @@
+use std::array;
+
+use ark_ff::{One, PrimeField, Zero};
+use ark_poly::{univariate::DensePolynomial, Evaluations, Polynomial, Radix2EvaluationDomain as D};
+use kimchi::{
+    circuits::{
+        berkeley_columns::BerkeleyChallenges,
+        domains::EvaluationDomains,
+        expr::{l0_1, Constants},
+    },
+    curve::KimchiCurve,
+    groupmap::GroupMap,
+    plonk_sponge::FrSponge,
+    proof::PointEvaluations,
+};
+use log::debug;
+use mina_poseidon::{sponge::ScalarChallenge, FqSponge};
+use o1_utils::ExtendedDensePolynomial;
+use poly_commitment::{
+    commitment::{absorb_commitment, PolyComm},
+    ipa::{OpeningProof, SRS},
+    utils::DensePolynomialOrEvaluations,
+    OpenProof as _, SRS as _,
+};
+use rand::{CryptoRng, RngCore};
+use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
+
+use super::{
+    column_env::ColumnEnvironment,
+    proof::{Proof, ProofInputs, WitnessColumns},
+    DEGREE_QUOTIENT_POLYNOMIAL,
+};
+use crate::{interpreters::mips::column::N_MIPS_SEL_COLS, E};
+use thiserror::Error;
+
+/// Errors that can arise when creating a proof
+#[derive(Error, Debug, Clone)]
+pub enum ProverError {
+    #[error("the provided constraint has degree {0} > allowed {1}; expr: {2}")]
+    ConstraintDegreeTooHigh(u64, u64, String),
+}
+
+/// Make a PlonKish proof for the given circuit. As input, we get the execution
+/// trace, consisting of evaluations of polynomials over a certain domain
+/// `domain`.
+///
+/// The proof is made of the following steps:
+/// 1. For each column, we create a commitment and absorb it in the sponge.
+/// 2. We compute the quotient polynomial.
+/// 3. We evaluate each polynomial (columns + quotient) at the two points ζ and ζω.
+/// 4. We make a batch opening proof using the IPA PCS.
+///
+/// The final proof consists of the opening proof, the commitments and the
+/// evaluations at ζ and ζω.
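+// A usage sketch (mirroring the call sites in main.rs and tests.rs; the
+// sponge types are the Kimchi defaults imported there):
+//
+//   let proof = prove::<Vesta,
+//       DefaultFqSponge<VestaParameters, PlonkSpongeConstantsKimchi>,
+//       DefaultFrSponge<Fp, PlonkSpongeConstantsKimchi>,
+//       _>(domain, &srs, proof_inputs, &constraints, &mut rng)?;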
+pub fn prove< + G: KimchiCurve, + EFqSponge: FqSponge + Clone, + EFrSponge: FrSponge, + RNG, +>( + domain: EvaluationDomains, + srs: &SRS, + inputs: ProofInputs, + constraints: &[E], + rng: &mut RNG, +) -> Result, ProverError> +where + G::BaseField: PrimeField, + RNG: RngCore + CryptoRng, +{ + let num_chunks = 1; + let omega = domain.d1.group_gen; + + let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params()); + + //////////////////////////////////////////////////////////////////////////// + // Round 1: Creating and absorbing column commitments + //////////////////////////////////////////////////////////////////////////// + + debug!("Prover: interpolating all columns, including the selectors"); + let ProofInputs { evaluations } = inputs; + let polys: WitnessColumns< + DensePolynomial, + [DensePolynomial; N_MIPS_SEL_COLS], + > = { + let WitnessColumns { + scratch, + scratch_inverse, + instruction_counter, + error, + selector, + } = evaluations; + + let domain_size = domain.d1.size as usize; + + // Build the selectors + let selector: [Vec; N_MIPS_SEL_COLS] = array::from_fn(|i| { + let mut s_i = Vec::with_capacity(domain_size); + for s in &selector { + s_i.push(if G::ScalarField::from(i as u64) == *s { + G::ScalarField::one() + } else { + G::ScalarField::zero() + }) + } + s_i + }); + + let eval_col = |evals: Vec| { + Evaluations::>::from_vec_and_domain(evals, domain.d1) + .interpolate() + }; + // Doing in parallel + let scratch = scratch.into_par_iter().map(eval_col).collect::>(); + let scratch_inverse = scratch_inverse + .into_par_iter() + .map(|mut evals| { + ark_ff::batch_inversion(&mut evals); + eval_col(evals) + }) + .collect::>(); + let selector = selector.into_par_iter().map(eval_col).collect::>(); + WitnessColumns { + scratch: scratch.try_into().unwrap(), + scratch_inverse: scratch_inverse.try_into().unwrap(), + instruction_counter: eval_col(instruction_counter), + error: eval_col(error.clone()), + selector: selector.try_into().unwrap(), + } + }; + + debug!("Prover: committing to all columns, including the selectors"); + let commitments: WitnessColumns, [PolyComm; N_MIPS_SEL_COLS]> = { + let WitnessColumns { + scratch, + scratch_inverse, + instruction_counter, + error, + selector, + } = &polys; + + let comm = |poly: &DensePolynomial| { + srs.commit_custom( + poly, + num_chunks, + &PolyComm::new(vec![G::ScalarField::one()]), + ) + .unwrap() + .commitment + }; + // Doing in parallel + let scratch = scratch.par_iter().map(comm).collect::>(); + let scratch_inverse = scratch_inverse.par_iter().map(comm).collect::>(); + let selector = selector.par_iter().map(comm).collect::>(); + WitnessColumns { + scratch: scratch.try_into().unwrap(), + scratch_inverse: scratch_inverse.try_into().unwrap(), + instruction_counter: comm(instruction_counter), + error: comm(error), + selector: selector.try_into().unwrap(), + } + }; + + debug!("Prover: evaluating all columns, including the selectors, on d8"); + // We evaluate on a domain higher than d1 for the quotient polynomial. + // Based on the regression test + // `test_regression_constraints_with_selectors`, the highest degree is 6. + // Therefore, we do evaluate on d8. 
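+    // Sanity check on the degree bound (illustrative): each interpolated
+    // column has degree < n, so a degree-6 constraint has degree at most
+    // 6(n - 1) < 8n, and the 8n evaluation points of d8 determine it uniquely.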
+ let evaluations_d8 = { + let WitnessColumns { + scratch, + scratch_inverse, + instruction_counter, + error, + selector, + } = &polys; + let eval_d8 = + |poly: &DensePolynomial| poly.evaluate_over_domain_by_ref(domain.d8); + // Doing in parallel + let scratch = scratch.into_par_iter().map(eval_d8).collect::>(); + let scratch_inverse = scratch_inverse + .into_par_iter() + .map(eval_d8) + .collect::>(); + let selector = selector.into_par_iter().map(eval_d8).collect::>(); + WitnessColumns { + scratch: scratch.try_into().unwrap(), + scratch_inverse: scratch_inverse.try_into().unwrap(), + instruction_counter: eval_d8(instruction_counter), + error: eval_d8(error), + selector: selector.try_into().unwrap(), + } + }; + + // Absorbing the commitments - Fiat Shamir + // We do not parallelize as we need something deterministic. + for comm in commitments.scratch.iter() { + absorb_commitment(&mut fq_sponge, comm) + } + for comm in commitments.scratch_inverse.iter() { + absorb_commitment(&mut fq_sponge, comm) + } + absorb_commitment(&mut fq_sponge, &commitments.instruction_counter); + absorb_commitment(&mut fq_sponge, &commitments.error); + for comm in commitments.selector.iter() { + absorb_commitment(&mut fq_sponge, comm) + } + + //////////////////////////////////////////////////////////////////////////// + // Round 2: Creating and committing to the quotient polynomial + //////////////////////////////////////////////////////////////////////////// + + let (_, endo_r) = G::endos(); + + // Constraints combiner + let alpha: G::ScalarField = fq_sponge.challenge(); + + let zk_rows = 0; + let column_env: ColumnEnvironment<'_, G::ScalarField> = { + // FIXME: use a proper Challenge structure + let challenges = BerkeleyChallenges { + alpha, + // No permutation argument for the moment + beta: G::ScalarField::zero(), + gamma: G::ScalarField::zero(), + // No lookup for the moment + joint_combiner: G::ScalarField::zero(), + }; + ColumnEnvironment { + constants: Constants { + endo_coefficient: *endo_r, + mds: &G::sponge_params().mds, + zk_rows, + }, + challenges, + witness: &evaluations_d8, + l0_1: l0_1(domain.d1), + domain, + } + }; + + debug!("Prover: computing the quotient polynomial"); + // Hint: + // To debug individual constraint, you can revert the following commits that implement the + // check for individual constraints. + // ``` + // git revert 8e87244a98d55b90d175ad389611a3c98bd16b34 + // git revert 96d42c127ef025869c91e5fed680e0e383108706 + // ``` + let quotient_poly: DensePolynomial = { + // Compute ∑ α^i constraint_i as an expression + let combined_expr = + E::combine_constraints(0..(constraints.len() as u32), (constraints).to_vec()); + + // We want to compute the quotient polynomial, i.e. + // t(X) = (∑ α^i constraint_i(X)) / Z_H(X). + // The sum of the expressions is called the "constraint polynomial". + // We will use the evaluations points of the individual witness + // columns. + // Note that as the constraints might be of higher degree than N, the + // size of the set H we want the constraints to be verified on, we must + // have more than N evaluations points for each columns. This is handled + // in the ColumnEnvironment structure. + // Reminder: to compute P(X) = P_{1}(X) * P_{2}(X), from the evaluations + // of P_{1} and P_{2}, with deg(P_{1}) = deg(P_{2}(X)) = N, we must have + // 2N evaluation points to compute P as deg(P(X)) <= 2N. 
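+        // For instance, with N = 4: four points pin down P_1 and P_2
+        // individually, but P = P_1 * P_2 has degree up to 8, so evaluations
+        // on a domain at least twice as large are needed before interpolating.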
+        let expr_evaluation: Evaluations<G::ScalarField, D<G::ScalarField>> =
+            combined_expr.evaluations(&column_env);
+
+        // And we interpolate using the evaluations
+        let expr_evaluation_interpolated = expr_evaluation.interpolate();
+
+        let fail_final_q_division = || panic!("Fail division by vanishing poly");
+        let fail_remainder_not_zero =
+            || panic!("The constraints are not satisfied since the remainder is not zero");
+        // We compute the polynomial t(X) by dividing the constraints polynomial
+        // by the vanishing polynomial, i.e. Z_H(X).
+        let (quotient, rem) = expr_evaluation_interpolated
+            .divide_by_vanishing_poly(domain.d1)
+            .unwrap_or_else(fail_final_q_division);
+        // As the constraints must be verified on H, the remainder of the
+        // division must be equal to 0, since the constraints polynomial and
+        // Z_H(X) both vanish on H.
+        if !rem.is_zero() {
+            fail_remainder_not_zero();
+        }
+
+        quotient
+    };
+
+    let quotient_commitment = srs
+        .commit_custom(
+            &quotient_poly,
+            DEGREE_QUOTIENT_POLYNOMIAL as usize,
+            &PolyComm::new(vec![
+                G::ScalarField::one();
+                DEGREE_QUOTIENT_POLYNOMIAL as usize
+            ]),
+        )
+        .unwrap();
+    absorb_commitment(&mut fq_sponge, &quotient_commitment.commitment);
+
+    ////////////////////////////////////////////////////////////////////////////
+    // Round 3: Evaluations at ζ and ζω
+    ////////////////////////////////////////////////////////////////////////////
+
+    debug!("Prover: evaluating all columns, including the selectors, at ζ and ζω");
+    let zeta_chal = ScalarChallenge(fq_sponge.challenge());
+
+    let zeta = zeta_chal.to_field(endo_r);
+    let zeta_omega = zeta * omega;
+
+    let evals = |point| {
+        let WitnessColumns {
+            scratch,
+            scratch_inverse,
+            instruction_counter,
+            error,
+            selector,
+        } = &polys;
+        let eval = |poly: &DensePolynomial<G::ScalarField>| poly.evaluate(point);
+        let scratch = scratch.par_iter().map(eval).collect::<Vec<_>>();
+        let scratch_inverse = scratch_inverse.par_iter().map(eval).collect::<Vec<_>>();
+        let selector = selector.par_iter().map(eval).collect::<Vec<_>>();
+        WitnessColumns {
+            scratch: scratch.try_into().unwrap(),
+            scratch_inverse: scratch_inverse.try_into().unwrap(),
+            instruction_counter: eval(instruction_counter),
+            error: eval(error),
+            selector: selector.try_into().unwrap(),
+        }
+    };
+    // All evaluations at ζ
+    let zeta_evaluations: WitnessColumns<G::ScalarField, [G::ScalarField; N_MIPS_SEL_COLS]> =
+        evals(&zeta);
+
+    // All evaluations at ζω
+    let zeta_omega_evaluations: WitnessColumns<G::ScalarField, [G::ScalarField; N_MIPS_SEL_COLS]> =
+        evals(&zeta_omega);
+
+    let chunked_quotient = quotient_poly
+        .to_chunked_polynomial(DEGREE_QUOTIENT_POLYNOMIAL as usize, domain.d1.size as usize);
+    let quotient_evaluations = PointEvaluations {
+        zeta: chunked_quotient
+            .polys
+            .iter()
+            .map(|p| p.evaluate(&zeta))
+            .collect::<Vec<_>>(),
+        zeta_omega: chunked_quotient
+            .polys
+            .iter()
+            .map(|p| p.evaluate(&zeta_omega))
+            .collect(),
+    };
+
+    // Absorbing evaluations with a sponge for the other field
+    // We initialize the state with the previous state of the fq_sponge
+    let fq_sponge_before_evaluations = fq_sponge.clone();
+    let mut fr_sponge = EFrSponge::new(G::sponge_params());
+    fr_sponge.absorb(&fq_sponge.digest());
+
+    for (zeta_eval, zeta_omega_eval) in zeta_evaluations
+        .scratch
+        .iter()
+        .zip(zeta_omega_evaluations.scratch.iter())
+    {
+        fr_sponge.absorb(zeta_eval);
+        fr_sponge.absorb(zeta_omega_eval);
+    }
+    for (zeta_eval, zeta_omega_eval) in zeta_evaluations
+        .scratch_inverse
+        .iter()
+        .zip(zeta_omega_evaluations.scratch_inverse.iter())
+    {
+        fr_sponge.absorb(zeta_eval);
+        fr_sponge.absorb(zeta_omega_eval);
+    }
+    fr_sponge.absorb(&zeta_evaluations.instruction_counter);
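+    // NB: the absorption order of the evaluations below must match the
+    // verifier's transcript exactly, otherwise verification fails.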
fr_sponge.absorb(&zeta_omega_evaluations.instruction_counter); + fr_sponge.absorb(&zeta_evaluations.error); + fr_sponge.absorb(&zeta_omega_evaluations.error); + for (zeta_eval, zeta_omega_eval) in zeta_evaluations + .selector + .iter() + .zip(zeta_omega_evaluations.selector.iter()) + { + fr_sponge.absorb(zeta_eval); + fr_sponge.absorb(zeta_omega_eval); + } + for (quotient_zeta_eval, quotient_zeta_omega_eval) in quotient_evaluations + .zeta + .iter() + .zip(quotient_evaluations.zeta_omega.iter()) + { + fr_sponge.absorb(quotient_zeta_eval); + fr_sponge.absorb(quotient_zeta_omega_eval); + } + //////////////////////////////////////////////////////////////////////////// + // Round 4: Opening proof w/o linearization polynomial + //////////////////////////////////////////////////////////////////////////// + + let mut polynomials: Vec<_> = polys.scratch.into_iter().collect(); + polynomials.extend(polys.scratch_inverse); + polynomials.push(polys.instruction_counter); + polynomials.push(polys.error); + polynomials.extend(polys.selector); + + // Preparing the polynomials for the opening proof + let mut polynomials: Vec<_> = polynomials + .iter() + .map(|poly| { + ( + DensePolynomialOrEvaluations::DensePolynomial(poly), + // We do not have any blinder, therefore we set to 1. + PolyComm::new(vec![G::ScalarField::one()]), + ) + }) + .collect(); + // we handle the quotient separately because the number of blinders = + // number of chunks, which is different for just the quotient polynomial. + polynomials.push(( + DensePolynomialOrEvaluations::DensePolynomial("ient_poly), + quotient_commitment.blinders, + )); + + // poly scale + let v_chal = fr_sponge.challenge(); + let v = v_chal.to_field(endo_r); + // eval scale + let u_chal = fr_sponge.challenge(); + let u = u_chal.to_field(endo_r); + + let group_map = G::Map::setup(); + + debug!("Prover: computing the (batched) opening proof using the IPA PCS"); + // Computing the opening proof for the IPA PCS + let opening_proof = OpeningProof::open::<_, _, D>( + srs, + &group_map, + polynomials.as_slice(), + &[zeta, zeta_omega], + v, + u, + fq_sponge_before_evaluations, + rng, + ); + + Ok(Proof { + commitments, + zeta_evaluations, + zeta_omega_evaluations, + quotient_commitment: quotient_commitment.commitment, + quotient_evaluations, + opening_proof, + }) +} diff --git a/o1vm/src/pickles/tests.rs b/o1vm/src/pickles/tests.rs new file mode 100644 index 0000000000..b4bb102dc7 --- /dev/null +++ b/o1vm/src/pickles/tests.rs @@ -0,0 +1,134 @@ +use std::time::Instant; + +use super::{ + super::interpreters::mips::column::SCRATCH_SIZE, + proof::{ProofInputs, WitnessColumns}, + prover::prove, +}; +use crate::{ + interpreters::mips::{ + column::SCRATCH_SIZE_INVERSE, + constraints as mips_constraints, + interpreter::{self, InterpreterEnv}, + Instruction, + }, + pickles::{verifier::verify, MAXIMUM_DEGREE_CONSTRAINTS, TOTAL_NUMBER_OF_CONSTRAINTS}, +}; +use ark_ff::{Field, One, UniformRand, Zero}; +use kimchi::circuits::{domains::EvaluationDomains, expr::Expr, gate::CurrOrNext}; +use kimchi_msm::{columns::Column, expr::E}; +use log::debug; +use mina_curves::pasta::{Fp, Fq, Pallas, PallasParameters}; +use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi, + sponge::{DefaultFqSponge, DefaultFrSponge}, +}; +use o1_utils::tests::make_test_rng; +use poly_commitment::SRS; +use strum::IntoEnumIterator; + +#[test] +fn test_regression_constraints_with_selectors() { + let constraints = { + let mut mips_con_env = mips_constraints::Env::::default(); + let mut constraints = 
Instruction::iter() + .flat_map(|instr_typ| instr_typ.into_iter()) + .fold(vec![], |mut acc, instr| { + interpreter::interpret_instruction(&mut mips_con_env, instr); + let selector = mips_con_env.get_selector(); + let constraints_with_selector: Vec> = mips_con_env + .get_constraints() + .into_iter() + .map(|c| selector.clone() * c) + .collect(); + acc.extend(constraints_with_selector); + mips_con_env.reset(); + acc + }); + constraints.extend(mips_con_env.get_selector_constraints()); + constraints + }; + + assert_eq!(constraints.len(), TOTAL_NUMBER_OF_CONSTRAINTS); + + let max_degree = constraints.iter().map(|c| c.degree(1, 0)).max().unwrap(); + assert_eq!(max_degree, MAXIMUM_DEGREE_CONSTRAINTS); +} + +fn zero_to_n_minus_one(n: usize) -> Vec { + (0..n).map(|i| Fq::from((i) as u64)).collect() +} + +#[test] +fn test_small_circuit() { + let domain = EvaluationDomains::::create(8).unwrap(); + let srs = SRS::create(8); + let proof_input = ProofInputs:: { + evaluations: WitnessColumns { + scratch: std::array::from_fn(|_| zero_to_n_minus_one(8)), + scratch_inverse: std::array::from_fn(|_| (0..8).map(|_| Fq::zero()).collect()), + instruction_counter: zero_to_n_minus_one(8) + .into_iter() + .map(|x| x + Fq::one()) + .collect(), + error: (0..8) + .map(|i| -Fq::from((i * SCRATCH_SIZE + (i + 1)) as u64)) + .collect(), + selector: zero_to_n_minus_one(8), + }, + }; + let mut expr = Expr::zero(); + for i in 0..SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2 { + expr += Expr::cell(Column::Relation(i), CurrOrNext::Curr); + } + let mut rng = make_test_rng(None); + + type BaseSponge = DefaultFqSponge; + type ScalarSponge = DefaultFrSponge; + + let proof = prove::( + domain, + &srs, + proof_input, + &[expr.clone()], + &mut rng, + ) + .unwrap(); + + let instant_before_verification = Instant::now(); + let verif = verify::(domain, &srs, &[expr.clone()], &proof); + let instant_after_verification = Instant::now(); + debug!( + "Verification took: {} ms", + (instant_after_verification - instant_before_verification).as_millis() + ); + assert!(verif, "Verification fails"); +} + +#[test] +fn test_arkworks_batch_inversion_with_only_zeroes() { + let input = vec![Fq::zero(); 8]; + let exp_output = vec![Fq::zero(); 8]; + let mut output = input.clone(); + ark_ff::batch_inversion::(&mut output); + assert_eq!(output, exp_output); +} + +#[test] +fn test_arkworks_batch_inversion_with_zeroes_and_ones() { + let input: Vec = vec![Fq::zero(), Fq::one(), Fq::zero()]; + let exp_output: Vec = vec![Fq::zero(), Fq::one(), Fq::zero()]; + let mut output: Vec = input.clone(); + ark_ff::batch_inversion::(&mut output); + assert_eq!(output, exp_output); +} + +#[test] +fn test_arkworks_batch_inversion_with_zeroes_and_random() { + let mut rng = o1_utils::tests::make_test_rng(None); + let input: Vec = vec![Fq::zero(), Fq::rand(&mut rng), Fq::one()]; + let exp_output: Vec = vec![Fq::zero(), input[1].inverse().unwrap(), Fq::one()]; + let mut output: Vec = input.clone(); + ark_ff::batch_inversion::(&mut output); + assert_eq!(output, exp_output); +} diff --git a/o1vm/src/pickles/verifier.rs b/o1vm/src/pickles/verifier.rs new file mode 100644 index 0000000000..fbbaa912eb --- /dev/null +++ b/o1vm/src/pickles/verifier.rs @@ -0,0 +1,267 @@ +use ark_ec::AffineRepr; +use ark_ff::{Field, One, PrimeField, Zero}; +use rand::thread_rng; + +use kimchi::{ + circuits::{ + berkeley_columns::BerkeleyChallenges, + domains::EvaluationDomains, + expr::{ColumnEvaluations, Constants, Expr, ExprError, PolishToken}, + gate::CurrOrNext, + }, + curve::KimchiCurve, + 
groupmap::GroupMap, + plonk_sponge::FrSponge, + proof::PointEvaluations, +}; +use mina_poseidon::{sponge::ScalarChallenge, FqSponge}; +use poly_commitment::{ + commitment::{ + absorb_commitment, combined_inner_product, BatchEvaluationProof, Evaluation, PolyComm, + }, + ipa::OpeningProof, + OpenProof, +}; + +use super::{ + column_env::get_all_columns, + proof::{Proof, WitnessColumns}, +}; +use crate::{interpreters::mips::column::N_MIPS_SEL_COLS, E}; +use kimchi_msm::columns::Column; + +type CommitmentColumns = WitnessColumns, [PolyComm; N_MIPS_SEL_COLS]>; +type EvaluationColumns = WitnessColumns; + +struct ColumnEval<'a, G: AffineRepr> { + commitment: &'a CommitmentColumns, + zeta_eval: &'a EvaluationColumns, + zeta_omega_eval: &'a EvaluationColumns, +} + +impl ColumnEvaluations for ColumnEval<'_, G> { + type Column = Column; + fn evaluate( + &self, + col: Self::Column, + ) -> Result, ExprError> { + let ColumnEval { + commitment: _, + zeta_eval, + zeta_omega_eval, + } = self; + if let Some(&zeta) = zeta_eval.get_column(&col) { + if let Some(&zeta_omega) = zeta_omega_eval.get_column(&col) { + Ok(PointEvaluations { zeta, zeta_omega }) + } else { + Err(ExprError::MissingEvaluation(col, CurrOrNext::Next)) + } + } else { + Err(ExprError::MissingEvaluation(col, CurrOrNext::Curr)) + } + } +} + +pub fn verify< + G: KimchiCurve, + EFqSponge: Clone + FqSponge, + EFrSponge: FrSponge, +>( + domain: EvaluationDomains, + srs: & as OpenProof>::SRS, + constraints: &[E], + proof: &Proof, +) -> bool +where + ::BaseField: PrimeField, +{ + let Proof { + commitments, + zeta_evaluations, + zeta_omega_evaluations, + quotient_commitment, + quotient_evaluations, + opening_proof, + } = proof; + + //////////////////////////////////////////////////////////////////////////// + // TODO : public inputs + //////////////////////////////////////////////////////////////////////////// + + //////////////////////////////////////////////////////////////////////////// + // Absorbing all the commitments to the columns + //////////////////////////////////////////////////////////////////////////// + + let mut fq_sponge = EFqSponge::new(G::other_curve_sponge_params()); + for comm in commitments.scratch.iter() { + absorb_commitment(&mut fq_sponge, comm) + } + for comm in commitments.scratch_inverse.iter() { + absorb_commitment(&mut fq_sponge, comm) + } + absorb_commitment(&mut fq_sponge, &commitments.instruction_counter); + absorb_commitment(&mut fq_sponge, &commitments.error); + for comm in commitments.selector.iter() { + absorb_commitment(&mut fq_sponge, comm) + } + + // Sample α with the Fq-Sponge. 
+ let alpha = fq_sponge.challenge(); + + //////////////////////////////////////////////////////////////////////////// + // Quotient polynomial + //////////////////////////////////////////////////////////////////////////// + + absorb_commitment(&mut fq_sponge, quotient_commitment); + + // -- Preparing for opening proof verification + let zeta_chal = ScalarChallenge(fq_sponge.challenge()); + let (_, endo_r) = G::endos(); + let zeta: G::ScalarField = zeta_chal.to_field(endo_r); + let omega = domain.d1.group_gen; + let zeta_omega = zeta * omega; + + let column_eval = ColumnEval { + commitment: commitments, + zeta_eval: zeta_evaluations, + zeta_omega_eval: zeta_omega_evaluations, + }; + + // -- Absorb all commitments_and_evaluations + let fq_sponge_before_commitments_and_evaluations = fq_sponge.clone(); + let mut fr_sponge = EFrSponge::new(G::sponge_params()); + fr_sponge.absorb(&fq_sponge.digest()); + + for (zeta_eval, zeta_omega_eval) in zeta_evaluations + .scratch + .iter() + .zip(zeta_omega_evaluations.scratch.iter()) + { + fr_sponge.absorb(zeta_eval); + fr_sponge.absorb(zeta_omega_eval); + } + for (zeta_eval, zeta_omega_eval) in zeta_evaluations + .scratch_inverse + .iter() + .zip(zeta_omega_evaluations.scratch_inverse.iter()) + { + fr_sponge.absorb(zeta_eval); + fr_sponge.absorb(zeta_omega_eval); + } + fr_sponge.absorb(&zeta_evaluations.instruction_counter); + fr_sponge.absorb(&zeta_omega_evaluations.instruction_counter); + fr_sponge.absorb(&zeta_evaluations.error); + fr_sponge.absorb(&zeta_omega_evaluations.error); + for (zeta_eval, zeta_omega_eval) in zeta_evaluations + .selector + .iter() + .zip(zeta_omega_evaluations.selector.iter()) + { + fr_sponge.absorb(zeta_eval); + fr_sponge.absorb(zeta_omega_eval); + } + for (quotient_zeta_eval, quotient_zeta_omega_eval) in quotient_evaluations + .zeta + .iter() + .zip(quotient_evaluations.zeta_omega.iter()) + { + fr_sponge.absorb(quotient_zeta_eval); + fr_sponge.absorb(quotient_zeta_omega_eval); + } + + // FIXME: use a proper Challenge structure + let challenges = BerkeleyChallenges { + alpha, + // No permutation argument for the moment + beta: G::ScalarField::zero(), + gamma: G::ScalarField::zero(), + // No lookup for the moment + joint_combiner: G::ScalarField::zero(), + }; + let (_, endo_r) = G::endos(); + + let constants = Constants { + endo_coefficient: *endo_r, + mds: &G::sponge_params().mds, + zk_rows: 0, + }; + + let combined_expr = + Expr::combine_constraints(0..(constraints.len() as u32), constraints.to_vec()); + + let numerator_zeta = PolishToken::evaluate( + combined_expr.to_polish().as_slice(), + domain.d1, + zeta, + &column_eval, + &constants, + &challenges, + ) + .unwrap_or_else(|_| panic!("Could not evaluate quotient polynomial at zeta")); + + let v_chal = fr_sponge.challenge(); + let v = v_chal.to_field(endo_r); + let u_chal = fr_sponge.challenge(); + let u = u_chal.to_field(endo_r); + + let mut evaluations: Vec<_> = get_all_columns() + .into_iter() + .map(|column| { + let commitment = column_eval + .commitment + .get_column(&column) + .unwrap_or_else(|| panic!("Could not get `commitment` for `Evaluation`")) + .clone(); + + let evaluations = column_eval + .evaluate(column) + .unwrap_or_else(|_| panic!("Could not get `evaluations` for `Evaluation`")); + + Evaluation { + commitment, + evaluations: vec![vec![evaluations.zeta], vec![evaluations.zeta_omega]], + } + }) + .collect(); + + evaluations.push(Evaluation { + commitment: proof.quotient_commitment.clone(), + evaluations: vec![ + quotient_evaluations.zeta.clone(), + 
quotient_evaluations.zeta_omega.clone(), + ], + }); + + let combined_inner_product = { + let es: Vec<_> = evaluations + .iter() + .map(|Evaluation { evaluations, .. }| evaluations.clone()) + .collect(); + + combined_inner_product(&v, &u, es.as_slice()) + }; + + let batch = BatchEvaluationProof { + sponge: fq_sponge_before_commitments_and_evaluations, + evaluations, + evaluation_points: vec![zeta, zeta_omega], + polyscale: v, + evalscale: u, + opening: opening_proof, + combined_inner_product, + }; + + let group_map = G::Map::setup(); + + // Check the actual quotient works. + let (quotient_zeta, _) = quotient_evaluations.zeta.iter().fold( + (G::ScalarField::zero(), G::ScalarField::one()), + |(res, zeta_i_n), chunk| { + let res = res + zeta_i_n * chunk; + let zeta_i_n = zeta_i_n * zeta.pow([domain.d1.size]); + (res, zeta_i_n) + }, + ); + (quotient_zeta == numerator_zeta / (zeta.pow([domain.d1.size]) - G::ScalarField::one())) + && OpeningProof::verify(srs, &group_map, &mut [batch], &mut thread_rng()) +} diff --git a/o1vm/src/preimage_oracle.rs b/o1vm/src/preimage_oracle.rs new file mode 100644 index 0000000000..4ff82106c5 --- /dev/null +++ b/o1vm/src/preimage_oracle.rs @@ -0,0 +1,302 @@ +use crate::cannon::{ + Hint, HostProgram, Preimage, HINT_CLIENT_READ_FD, HINT_CLIENT_WRITE_FD, + PREIMAGE_CLIENT_READ_FD, PREIMAGE_CLIENT_WRITE_FD, +}; +use command_fds::{CommandFdExt, FdMapping}; +use log::debug; +use os_pipe::{PipeReader, PipeWriter}; +use std::{ + io::{Read, Write}, + os::fd::{AsRawFd, FromRawFd, OwnedFd}, + process::{Child, Command}, +}; + +pub struct PreImageOracle { + pub cmd: Command, + pub oracle_client: RW, + pub oracle_server: RW, + pub hint_client: RW, + pub hint_server: RW, +} + +pub trait PreImageOracleT { + fn get_preimage(&mut self, key: [u8; 32]) -> Preimage; + + fn hint(&mut self, hint: Hint); +} + +pub struct ReadWrite { + pub reader: R, + pub writer: W, +} + +pub struct RW(pub ReadWrite); + +// Here, we implement `os_pipe::pipe` in a way that allows us to pass flags. In particular, we +// don't pass the `CLOEXEC` flag, because we want these pipes to survive an exec, and we set +// `DIRECT` to handle writes as single atomic operations (up to splitting at the buffer size). +// This fixes the IPC hangs. This is bad, but the hang is worse. + +#[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "haiku", windows)))] +fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> { + let mut fds: [libc::c_int; 2] = [0; 2]; + let res = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_DIRECT) }; + if res != 0 { + return Err(std::io::Error::last_os_error()); + } + unsafe { + Ok(( + PipeReader::from_raw_fd(fds[0]), + PipeWriter::from_raw_fd(fds[1]), + )) + } +} + +#[cfg(any(target_os = "ios", target_os = "macos", target_os = "haiku"))] +pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> { + let mut fds: [libc::c_int; 2] = [0; 2]; + let res = unsafe { libc::pipe(fds.as_mut_ptr()) }; + if res != 0 { + return Err(std::io::Error::last_os_error()); + } + // It appears that Mac and friends don't have DIRECT. Oh well. Don't use a Mac. 
+ let res = unsafe { libc::fcntl(fds[0], libc::F_SETFD, 0) }; + if res != 0 { + return Err(std::io::Error::last_os_error()); + } + let res = unsafe { libc::fcntl(fds[1], libc::F_SETFD, 0) }; + if res != 0 { + return Err(std::io::Error::last_os_error()); + } + unsafe { + Ok(( + PipeReader::from_raw_fd(fds[0]), + PipeWriter::from_raw_fd(fds[1]), + )) + } +} + +#[cfg(windows)] +pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> { + os_pipe::pipe() +} + +// Create bidirectional channel between A and B +// +// Schematically we create 2 unidirectional pipes and creates 2 structures made +// by taking the writer from one and the reader from the other. +// +// A B +// | ar <---- bw | +// | aw ----> br | +// +pub fn create_bidirectional_channel() -> Option<(RW, RW)> { + let (ar, bw) = create_pipe().ok()?; + let (br, aw) = create_pipe().ok()?; + Some(( + RW(ReadWrite { + reader: ar, + writer: aw, + }), + RW(ReadWrite { + reader: br, + writer: bw, + }), + )) +} + +impl PreImageOracle { + pub fn create(hp_opt: &Option) -> PreImageOracle { + let host_program = hp_opt.as_ref().expect("No host program given"); + + let mut cmd = Command::new(&host_program.name); + cmd.args(&host_program.arguments); + + let (oracle_client, oracle_server) = + create_bidirectional_channel().expect("Could not create bidirectional oracle channel"); + let (hint_client, hint_server) = + create_bidirectional_channel().expect("Could not create bidirectional hint channel"); + + // file descriptors 0, 1, 2 respectively correspond to the inherited stdin, + // stdout, stderr. + // We need to map 3, 4, 5, 6 in the child process + cmd.fd_mappings(vec![ + FdMapping { + parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.writer.as_raw_fd()) }, + child_fd: HINT_CLIENT_WRITE_FD, + }, + FdMapping { + parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.reader.as_raw_fd()) }, + child_fd: HINT_CLIENT_READ_FD, + }, + FdMapping { + parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.writer.as_raw_fd()) }, + child_fd: PREIMAGE_CLIENT_WRITE_FD, + }, + FdMapping { + parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.reader.as_raw_fd()) }, + child_fd: PREIMAGE_CLIENT_READ_FD, + }, + ]) + .unwrap_or_else(|_| panic!("Could not map file descriptors to preimage server process")); + + PreImageOracle { + cmd, + oracle_client, + oracle_server, + hint_client, + hint_server, + } + } + + pub fn start(&mut self) -> Child { + // Spawning inherits the current process's stdin/stdout/stderr descriptors + self.cmd + .spawn() + .expect("Could not spawn pre-image oracle process") + } +} + +impl PreImageOracleT for PreImageOracle { + // The preimage protocol goes as follows + // 1. Ask for data through a key + // 2. Get the answers in the following format + // +------------+--------------------+ + // | length <8> | pre-image | + // +---------------------------------+ + // a. a 64-bit integer indicating the length of the actual data + // b. 
the preimage data, with a size of `length` bytes
+    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage {
+        let RW(ReadWrite { reader, writer }) = &mut self.oracle_client;
+
+        let r = writer.write_all(&key);
+        assert!(r.is_ok());
+        let r = writer.flush();
+        assert!(r.is_ok());
+
+        debug!("Reading response");
+        let mut buf = [0_u8; 8];
+        let resp = reader.read_exact(&mut buf);
+        assert!(resp.is_ok());
+
+        debug!("Extracting contents");
+        let length = u64::from_be_bytes(buf);
+        let mut preimage = vec![0_u8; length as usize];
+        let resp = reader.read_exact(&mut preimage);
+
+        assert!(resp.is_ok());
+
+        debug!(
+            "Got preimage of length {}\n {}",
+            preimage.len(),
+            hex::encode(&preimage)
+        );
+        // We should have read exactly `length` bytes
+        assert_eq!(preimage.len(), length as usize);
+
+        Preimage::create(preimage)
+    }
+
+    // The hint protocol goes as follows:
+    // 1. Write a hint request with the following byte-stream format
+    //    +------------+---------------+
+    //    | length <8> | hint          |
+    //    +----------------------------+
+    //
+    // 2. Get back a single ack byte indicating the hint has been processed.
+    fn hint(&mut self, hint: Hint) {
+        let RW(ReadWrite { reader, writer }) = &mut self.hint_client;
+
+        // Write hint request
+        let mut hint_bytes = hint.get();
+        let hint_length = hint_bytes.len();
+
+        let mut msg: Vec<u8> = vec![];
+        msg.append(&mut u64::to_be_bytes(hint_length as u64).to_vec());
+        msg.append(&mut hint_bytes);
+
+        let _ = writer.write(&msg);
+
+        // Read the single-byte acknowledgment response
+        let mut buf = [0_u8];
+        // And do nothing with it anyway
+        let _ = reader.read_exact(&mut buf);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Test that bidirectional channels work as expected
+    // That is, after creating a pair (c0, c1)
+    // 1. c1's reader can read what c0's writer produces
+    // 2. c0's reader can read what c1's writer produces
+    #[test]
+    fn test_bidir_channels() {
+        let (mut c0, mut c1) = create_bidirectional_channel().unwrap();
+
+        // Send a single byte message
+        let msg = [42_u8];
+        let mut buf = [0_u8; 1];
+
+        let writer_joiner = std::thread::spawn(move || {
+            let r = c0.0.writer.write(&msg);
+            assert!(r.is_ok());
+        });
+
+        let reader_joiner = std::thread::spawn(move || {
+            let r = c1.0.reader.read_exact(&mut buf);
+            assert!(r.is_ok());
+            buf
+        });
+
+        // Retrieve the buffer from the reader
+        let buf = reader_joiner.join().unwrap();
+        // Ensure that the writer has completed
+        writer_joiner.join().unwrap();
+
+        // Check that we correctly read the single byte message
+        assert_eq!(msg, buf);
+
+        // Create a more structured message with the preimage format
+        // +------------+--------------------+
+        // | length <8> | pre-image          |
+        // +---------------------------------+
+        // Here we'll use a length of 2
+        let msg2 = vec![42_u8, 43_u8];
+        let len = msg2.len() as u64;
+        let mut msg = u64::to_be_bytes(len).to_vec();
+        msg.extend_from_slice(&msg2);
+
+        // Write the message
+        let writer_joiner = std::thread::spawn(move || {
+            let r = c1.0.writer.write(&msg);
+            assert!(r.is_ok());
+            msg
+        });
+
+        // Read back the length from the other end of the bidirectional channel
+        let reader_joiner = std::thread::spawn(move || {
+            let mut response_vec = vec![];
+            // We do a single read to mirror go, even though we should *really* do 2. 'Go' figure.
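+            // Frame being parsed (illustrative): for msg2 = [42, 43] the wire
+            // carries 10 bytes: 00 00 00 00 00 00 00 02 | 2a 2b, i.e. a u64
+            // big-endian length followed by the payload.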
+            let r = c0.0.reader.read_to_end(&mut response_vec);
+            assert!(r.is_ok());
+
+            let n = u64::from_be_bytes(response_vec[0..8].try_into().unwrap());
+
+            let data = response_vec[8..(n + 8) as usize].to_vec();
+            (n, data)
+        });
+
+        // Retrieve the data from the reader
+        let (n, data) = reader_joiner.join().unwrap();
+
+        // Ensure that the writer has completed
+        writer_joiner.join().unwrap();
+
+        // Check that the responses are equal
+        assert_eq!(n, len);
+        assert_eq!(data, msg2);
+    }
+}
diff --git a/o1vm/src/ramlookup.rs b/o1vm/src/ramlookup.rs
new file mode 100644
index 0000000000..beafab2551
--- /dev/null
+++ b/o1vm/src/ramlookup.rs
@@ -0,0 +1,117 @@
+use ark_ff::{Field, One, Zero};
+use kimchi_msm::{Logup, LookupTableID};
+
+/// Enum representing the two different modes of a RAMLookup
+#[derive(Copy, Clone, Debug)]
+pub enum LookupMode {
+    Read,
+    Write,
+}
+
+/// Struct containing a RAMLookup
+#[derive(Clone, Debug)]
+pub struct RAMLookup<T, ID: LookupTableID> {
+    /// The table ID corresponding to this lookup
+    pub(crate) table_id: ID,
+    /// Whether it is a read or write lookup
+    pub(crate) mode: LookupMode,
+    /// The number of times that this lookup value should be added to / subtracted from the lookup accumulator.
+    pub(crate) magnitude: T,
+    /// The columns containing the content of this lookup
+    pub(crate) value: Vec<T>,
+}
+
+impl<T, ID> RAMLookup<T, ID>
+where
+    T: Clone
+        + std::ops::Add<T, Output = T>
+        + std::ops::Sub<T, Output = T>
+        + std::ops::Mul<T, Output = T>
+        + std::fmt::Debug
+        + One
+        + Zero,
+    ID: LookupTableID,
+{
+    /// Creates a new RAMLookup from a mode, a table ID, a magnitude, and a value
+    pub fn new(mode: LookupMode, table_id: ID, magnitude: T, value: &[T]) -> Self {
+        Self {
+            mode,
+            table_id,
+            magnitude,
+            value: value.to_vec(),
+        }
+    }
+
+    /// Returns the numerator corresponding to this lookup in the Logup argument
+    pub fn numerator(&self) -> T {
+        match self.mode {
+            LookupMode::Read => T::zero() - self.magnitude.clone(),
+            LookupMode::Write => self.magnitude.clone(),
+        }
+    }
+
+    /// Transforms the current RAMLookup into an equivalent Logup
+    pub fn into_logup(self) -> Logup<T, ID> {
+        Logup::new(self.table_id, self.numerator(), &self.value)
+    }
+
+    /// Reads one value when `if_is_true` is 1.
+    pub fn read_if(if_is_true: T, table_id: ID, value: Vec<T>) -> Self {
+        Self {
+            mode: LookupMode::Read,
+            magnitude: if_is_true,
+            table_id,
+            value,
+        }
+    }
+
+    /// Writes one value when `if_is_true` is 1.
+    pub fn write_if(if_is_true: T, table_id: ID, value: Vec<T>) -> Self {
+        Self {
+            mode: LookupMode::Write,
+            magnitude: if_is_true,
+            table_id,
+            value,
+        }
+    }
+
+    /// Reads one value from a table.
+    pub fn read_one(table_id: ID, value: Vec<T>) -> Self {
+        Self {
+            mode: LookupMode::Read,
+            magnitude: T::one(),
+            table_id,
+            value,
+        }
+    }
+
+    /// Writes one value to a table.
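+    // Sign convention recap (illustrative): reads contribute -magnitude and
+    // writes +magnitude to the Logup numerator, so a value written once and
+    // read once nets out to zero. read_one/write_one are the magnitude-1
+    // special cases of read_if/write_if.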
+ pub fn write_one(table_id: ID, value: Vec) -> Self { + Self { + mode: LookupMode::Write, + magnitude: T::one(), + table_id, + value, + } + } +} + +impl std::fmt::Display for RAMLookup { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let numerator = match self.mode { + LookupMode::Read => -self.magnitude, + LookupMode::Write => self.magnitude, + }; + write!( + formatter, + "numerator: {}\ntable_id: {:?}\nvalue:\n[\n", + numerator, + self.table_id.to_field::() + )?; + for value in self.value.iter() { + writeln!(formatter, "\t{}", value)?; + } + write!(formatter, "]")?; + Ok(()) + } +} diff --git a/o1vm/src/test_preimage_read.rs b/o1vm/src/test_preimage_read.rs new file mode 100644 index 0000000000..094702a8fc --- /dev/null +++ b/o1vm/src/test_preimage_read.rs @@ -0,0 +1,91 @@ +use clap::Arg; +use log::{debug, error}; +use o1vm::{ + cannon::PreimageKey, + cannon_cli::{main_cli, read_configuration}, + preimage_oracle::{PreImageOracle, PreImageOracleT}, +}; +use std::{ + io::{self}, + path::{Path, PathBuf}, + process::ExitCode, + str::FromStr, +}; + +fn main() -> ExitCode { + use rand::Rng; + + env_logger::init(); + + // Add command-line parameter to read the Optimism op-program DB directory + let cli = main_cli().arg( + Arg::new("preimage-db-dir") + .long("preimage-db-dir") + .value_name("PREIMAGE_DB_DIR"), + ); + + // Now read matches with the additional argument(s) + let matches = cli.get_matches(); + + let configuration = read_configuration(&matches); + + // Get DB directory and abort if unset + let preimage_db_dir = matches.get_one::("preimage-db-dir"); + + if let Some(preimage_key_dir) = preimage_db_dir { + let mut po = PreImageOracle::create(&configuration.host); + let _child = po.start(); + debug!("Let server start"); + std::thread::sleep(std::time::Duration::from_secs(5)); + + debug!("Reading from {}", preimage_key_dir); + // Get all files under the preimage db dir + let paths = std::fs::read_dir(preimage_key_dir) + .unwrap() + .map(|res| res.map(|e| e.path())) + .collect::, io::Error>>() + .unwrap(); + + // Just take 10 elements at random to test + let how_many = 10_usize; + let max = paths.len(); + let mut rng = rand::thread_rng(); + + let mut selected: Vec = vec![PathBuf::new(); how_many]; + for pb in selected.iter_mut() { + let idx = rng.gen_range(0..max); + *pb = paths[idx].to_path_buf(); + } + + for (idx, path) in selected.into_iter().enumerate() { + let preimage_key = Path::new(&path) + .file_name() + .unwrap() + .to_str() + .unwrap() + .split('.') + .collect::>()[0]; + + let hash = PreimageKey::from_str(preimage_key).unwrap(); + let mut key = hash.0; + key[0] = 2; // Keccak + + debug!( + "Generating OP Keccak key for {} at index {}", + preimage_key, idx + ); + + let expected = std::fs::read_to_string(path).unwrap(); + + debug!("Asking for preimage"); + let preimage = po.get_preimage(key); + let got = hex::encode(preimage.get()).to_string(); + + assert_eq!(expected, got); + } + ExitCode::SUCCESS + } else { + error!("Unset command-line argument --preimage-db-dir. Cannot run test. Please set parameter and rerun."); + ExitCode::FAILURE + } +} diff --git a/o1vm/src/utils.rs b/o1vm/src/utils.rs new file mode 100644 index 0000000000..64b70ded78 --- /dev/null +++ b/o1vm/src/utils.rs @@ -0,0 +1,50 @@ +const KUNIT: usize = 1024; // a kunit of memory is 1024 things (bytes, kilobytes, ...) +const PREFIXES: &str = "KMGTPE"; // prefixes for memory quantities KiB, MiB, GiB, ... 
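+// Worked example (consistent with the tests below): for
+// total = 2100 * 1024 * 1024, the loop in memory_size runs twice, ending with
+// idx = 2 ('G') and d = 1024^3; value = total / d ≈ 2.05, printed as "2.1 GiB".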
+ +// Create a human-readable string representation of the memory size +pub fn memory_size(total: usize) -> String { + if total < KUNIT { + format!("{total} B") + } else { + // Compute the index in the prefixes string above + let mut idx = 0; + let mut d = KUNIT; + let mut n = total / KUNIT; + + while n >= KUNIT { + d *= KUNIT; + idx += 1; + n /= KUNIT; + } + + let value = total as f64 / d as f64; + + let prefix = + //////////////////////////////////////////////////////////////////////// + // Famous last words: 1023 exabytes ought to be enough for anybody // + // // + // Corollary: // + // unwrap() below shouldn't fail // + // The maximum representation for usize corresponds to 16 exabytes // + // anyway // + //////////////////////////////////////////////////////////////////////// + PREFIXES.chars().nth(idx).unwrap(); + + format!("{:.1} {}iB", value, prefix) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_memory_size() { + assert_eq!(memory_size(1023_usize), "1023 B"); + assert_eq!(memory_size(1024_usize), "1.0 KiB"); + assert_eq!(memory_size(1024 * 1024_usize), "1.0 MiB"); + assert_eq!(memory_size(2100 * 1024 * 1024_usize), "2.1 GiB"); + assert_eq!(memory_size(std::usize::MAX), "16.0 EiB"); + } +} diff --git a/o1vm/test_preimage_read.sh b/o1vm/test_preimage_read.sh new file mode 100755 index 0000000000..de5b68fc7e --- /dev/null +++ b/o1vm/test_preimage_read.sh @@ -0,0 +1,52 @@ +#!/bin/sh +# Test preimage read from VM request to server response +# Usage: test_preimage_read.sh [OP_DB_DIR] [NETWORK_NAME] + +source rpcs.sh + +set +u +if [ -z "${FILENAME}" ]; then + FILENAME="$(./setenv-for-latest-l2-block.sh)" +fi +set -u + +source $FILENAME + +set -x + +# Location of prepopulated OP DB +# OP_PROGRAM_DATA_DIR is generated by ./setenv-for-latest-block.sh +# It makes sense to use it as default value here +DATADIR=${1:-${OP_PROGRAM_DATA_DIR}} +NETWORK=${2:-sepolia} + +# If you actually need to populate the OP DB, run +################################################# +# ${OP_DIR}/op-program/bin/op-program \ # +# --log.level DEBUG \ # +# --l1 ${L1_RPC} \ # +# --l2 ${L2_RPC} \ # +# --network ${NETWORK} \ # +# --datadir ${DATADIR} \ # +# --l1.head ${L1_HEAD} \ # +# --l2.head ${L2_HEAD} \ # +# --l2.outputroot ${STARTING_OUTPUT_ROOT} \ # +# --l2.claim ${L2_CLAIM} \ # +# --l2.blocknumber ${L2_BLOCK_NUMBER} # +################################################# + +# Run test with debug on +RUST_LOG=debug cargo run -r --bin test_optimism_preimage_read -- \ + --preimage-db-dir ${DATADIR} -- \ + op-program \ + --log.level DEBUG \ + --l1 ${L1_RPC} \ + --l2 ${L2_RPC} \ + --network ${NETWORK} \ + --datadir ${DATADIR} \ + --l1.head ${L1_HEAD} \ + --l2.head ${L2_HEAD} \ + --l2.outputroot ${STARTING_OUTPUT_ROOT} \ + --l2.claim ${L2_CLAIM} \ + --l2.blocknumber ${L2_BLOCK_NUMBER} \ + --server diff --git a/o1vm/tests/test_elf_loader.rs b/o1vm/tests/test_elf_loader.rs new file mode 100644 index 0000000000..eba0fdf28e --- /dev/null +++ b/o1vm/tests/test_elf_loader.rs @@ -0,0 +1,19 @@ +#[test] +// This test is used to check that the elf loader is working correctly. +// We must export the code used in this test in a function that can be called by +// the o1vm at load time. 
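+
+// As a sanity check on the constants asserted below (assuming the standard
+// 4 KiB pages used by the VM): the entry point 69932 is 0x1112c, and
+// 69932 / 4096 = 17, which is why the single memory page is expected to have
+// index 17.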
+fn test_correctly_parsing_elf() { + let curr_dir = std::env::current_dir().unwrap(); + let path = curr_dir.join(std::path::PathBuf::from( + "resources/programs/riscv32im/fibonacci", + )); + let state = o1vm::elf_loader::parse_riscv32(&path).unwrap(); + + // This is the output we get by running objdump -d fibonacci + assert_eq!(state.pc, 69932); + + // We do have only one page of memory + assert_eq!(state.memory.len(), 1); + // Which is the 17th + assert_eq!(state.memory[0].index, 17); +} diff --git a/o1vm/tests/test_riscv_elf.rs b/o1vm/tests/test_riscv_elf.rs new file mode 100644 index 0000000000..7b278a3a8e --- /dev/null +++ b/o1vm/tests/test_riscv_elf.rs @@ -0,0 +1,75 @@ +use mina_curves::pasta::Fp; +use o1vm::interpreters::riscv32im::{ + interpreter::{IInstruction, Instruction, RInstruction}, + witness::Env, + PAGE_SIZE, +}; + +#[test] +// Checking an instruction can be converted into a string. +// It is mostly because we would want to use it to debug or write better error +// messages. +fn test_instruction_can_be_converted_into_string() { + let instruction = Instruction::RType(RInstruction::Add); + assert_eq!(instruction.to_string(), "add"); + + let instruction = Instruction::RType(RInstruction::Sub); + assert_eq!(instruction.to_string(), "sub"); + + let instruction = Instruction::IType(IInstruction::LoadByte); + assert_eq!(instruction.to_string(), "lb"); + + let instruction = Instruction::IType(IInstruction::LoadHalf); + assert_eq!(instruction.to_string(), "lh"); +} + +#[test] +fn test_no_action() { + let curr_dir = std::env::current_dir().unwrap(); + let path = curr_dir.join(std::path::PathBuf::from( + "resources/programs/riscv32im/no-action", + )); + let state = o1vm::elf_loader::parse_riscv32(&path).unwrap(); + let mut witness = Env::::create(PAGE_SIZE.try_into().unwrap(), state); + // This is the output we get by running objdump -d no-action + assert_eq!(witness.registers.current_instruction_pointer, 69844); + assert_eq!(witness.registers.next_instruction_pointer, 69848); + + (0..=7).for_each(|_| { + let instr = witness.step(); + // li is addi, li is a pseudo instruction + assert_eq!(instr, Instruction::IType(IInstruction::AddImmediate)) + }); + assert_eq!(witness.registers.general_purpose[10], 0); + assert_eq!(witness.registers.general_purpose[11], 0); + assert_eq!(witness.registers.general_purpose[12], 0); + assert_eq!(witness.registers.general_purpose[13], 0); + assert_eq!(witness.registers.general_purpose[14], 0); + assert_eq!(witness.registers.general_purpose[15], 0); + assert_eq!(witness.registers.general_purpose[16], 0); + assert_eq!(witness.registers.general_purpose[17], 42); +} + +// FIXME: stop ignoring when all the instructions are implemented. 
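+// For reference on the assertion below: x10 (a0) is the RISC-V return-value
+// register, so when the program counter reaches the final instruction the
+// test reads fib(7) = 13 (the sequence being 1 1 2 3 5 8 13) out of a0.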
+#[test]
+#[ignore]
+fn test_fibonacci_7() {
+    let curr_dir = std::env::current_dir().unwrap();
+    let path = curr_dir.join(std::path::PathBuf::from(
+        "resources/programs/riscv32im/fibonacci-7",
+    ));
+    let state = o1vm::elf_loader::parse_riscv32(&path).unwrap();
+    let mut witness = Env::<Fp>::create(PAGE_SIZE.try_into().unwrap(), state);
+    // This is the output we get by running objdump -d fibonacci-7
+    assert_eq!(witness.registers.current_instruction_pointer, 69932);
+    assert_eq!(witness.registers.next_instruction_pointer, 69936);
+
+    while !witness.halt {
+        witness.step();
+        if witness.registers.current_instruction_pointer == 0x1117c {
+            // Fibonacci sequence:
+            // 1 1 2 3 5 8 13
+            assert_eq!(witness.registers.general_purpose[10], 13);
+        }
+    }
+}
diff --git a/poly-commitment/Cargo.toml b/poly-commitment/Cargo.toml
index b5f6cf76b4..e1b9a44151 100644
--- a/poly-commitment/Cargo.toml
+++ b/poly-commitment/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "poly-commitment"
 version = "0.1.0"
-description = "An implementation of an inner-product argument polynomial commitment scheme, as used in kimchi"
+description = "Library implementing different polynomial commitment schemes, such as IPA and KZG10"
 repository = "https://github.com/o1-labs/proof-systems"
 homepage = "https://o1-labs.github.io/proof-systems/"
 documentation = "https://o1-labs.github.io/proof-systems/rustdoc/"
@@ -45,4 +45,8 @@ ocaml_types = ["ocaml", "ocaml-gen"]
 
 [[bench]]
 name = "poly_comm"
+harness = false
+
+[[bench]]
+name = "ipa"
 harness = false
\ No newline at end of file
diff --git a/poly-commitment/README.md b/poly-commitment/README.md
new file mode 100644
index 0000000000..44b195fc4d
--- /dev/null
+++ b/poly-commitment/README.md
@@ -0,0 +1,13 @@
+# poly-commitment: implementations of multiple PCS
+
+This library offers implementations of different polynomial commitment schemes
+(PCS) that can be used in polynomial interactive oracle proofs (PIOP) like PlonK.
+
+Currently, the following polynomial commitment schemes are implemented:
+- [KZG10](./src/kzg.rs)
+- [Inner Product Argument](./src/ipa.rs)
+
+The implementations were initially written to be compatible with Kimchi (a
+Plonk-ish variant with 15 wires and some custom gates) and to be used in the
+Mina protocol. For instance, submodules are provided to convert the types to
+OCaml for use in the Mina protocol codebase.
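+
+As a minimal usage sketch (adapted from [`benches/ipa.rs`](./benches/ipa.rs);
+the sizes and randomness here are illustrative only), committing to a
+polynomial and opening the commitment at a random point with the IPA scheme
+looks like this:
+
+```rust
+use ark_ff::UniformRand;
+use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Radix2EvaluationDomain};
+use groupmap::GroupMap;
+use mina_curves::pasta::{Fp, Vesta, VestaParameters};
+use mina_poseidon::{constants::PlonkSpongeConstantsKimchi, sponge::DefaultFqSponge, FqSponge};
+use poly_commitment::{
+    commitment::CommitmentCurve, ipa::SRS, utils::DensePolynomialOrEvaluations, PolyComm,
+    SRS as _,
+};
+
+let mut rng = o1_utils::tests::make_test_rng(None);
+let srs = SRS::<Vesta>::create(1 << 8);
+let group_map = <Vesta as CommitmentCurve>::Map::setup();
+
+// Commit (with blinding) to a random polynomial fitting in a single chunk.
+let coeffs: Vec<Fp> = (0..1 << 8).map(|_| Fp::rand(&mut rng)).collect();
+let poly = DensePolynomial::from_coefficients_vec(coeffs);
+let commitment = srs.commit(&poly, 1, &mut rng);
+
+// Open the commitment at a random evaluation point.
+let elm = vec![Fp::rand(&mut rng)];
+let (polyscale, evalscale) = (Fp::rand(&mut rng), Fp::rand(&mut rng));
+let sponge = DefaultFqSponge::<VestaParameters, PlonkSpongeConstantsKimchi>::new(
+    mina_poseidon::pasta::fq_kimchi::static_params(),
+);
+let polys: Vec<(
+    DensePolynomialOrEvaluations<_, Radix2EvaluationDomain<_>>,
+    PolyComm<_>,
+)> = vec![(
+    DensePolynomialOrEvaluations::DensePolynomial(&poly),
+    commitment.blinders,
+)];
+let _proof = srs.open(&group_map, &polys, &elm, polyscale, evalscale, sponge, &mut rng);
+```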
diff --git a/poly-commitment/benches/ipa.rs b/poly-commitment/benches/ipa.rs new file mode 100644 index 0000000000..6c2abe5ef2 --- /dev/null +++ b/poly-commitment/benches/ipa.rs @@ -0,0 +1,56 @@ +use ark_ff::UniformRand; +use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Radix2EvaluationDomain}; +use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use groupmap::GroupMap; +use mina_curves::pasta::{Fp, Vesta, VestaParameters}; +use mina_poseidon::{constants::PlonkSpongeConstantsKimchi, sponge::DefaultFqSponge, FqSponge}; +use poly_commitment::{ + commitment::CommitmentCurve, ipa::SRS, utils::DensePolynomialOrEvaluations, PolyComm, SRS as _, +}; + +fn benchmark_ipa_open(c: &mut Criterion) { + let mut group = c.benchmark_group("IPA"); + let group_map = ::Map::setup(); + let mut rng = o1_utils::tests::make_test_rng(None); + + let elm = vec![Fp::rand(&mut rng), Fp::rand(&mut rng)]; + let polyscale = Fp::rand(&mut rng); + let evalscale = Fp::rand(&mut rng); + for log_n in [5, 10].into_iter() { + let n = 1 << log_n; + let srs = SRS::::create(n); + let sponge = DefaultFqSponge::::new( + mina_poseidon::pasta::fq_kimchi::static_params(), + ); + let poly_coefficients: Vec = (0..n).map(|_| Fp::rand(&mut rng)).collect(); + let poly = DensePolynomial::::from_coefficients_vec(poly_coefficients); + let poly_commit = srs.commit(&poly.clone(), 1, &mut rng); + group.bench_with_input(BenchmarkId::new("IPA Vesta open", n), &n, |b, _| { + b.iter_batched( + || (poly.clone(), poly_commit.clone()), + |(poly, poly_commit)| { + let polys: Vec<( + DensePolynomialOrEvaluations<_, Radix2EvaluationDomain<_>>, + PolyComm<_>, + )> = vec![( + DensePolynomialOrEvaluations::DensePolynomial(&poly), + poly_commit.blinders, + )]; + black_box(srs.open( + &group_map, + &polys, + &elm, + polyscale, + evalscale, + sponge.clone(), + &mut rng, + )) + }, + BatchSize::SmallInput, + ); + }); + } +} + +criterion_group!(benches, benchmark_ipa_open); +criterion_main!(benches); diff --git a/poly-commitment/src/chunked.rs b/poly-commitment/src/chunked.rs deleted file mode 100644 index 0538285979..0000000000 --- a/poly-commitment/src/chunked.rs +++ /dev/null @@ -1,46 +0,0 @@ -use ark_ec::CurveGroup; -use ark_ff::{Field, Zero}; -use std::ops::AddAssign; - -use crate::{commitment::CommitmentCurve, PolyComm}; - -impl PolyComm -where - C: CommitmentCurve, -{ - /// Multiplies each commitment chunk of f with powers of zeta^n - // TODO(mimoo): better name for this function - pub fn chunk_commitment(&self, zeta_n: C::ScalarField) -> Self { - let mut res = C::Group::zero(); - // use Horner's to compute chunk[0] + z^n chunk[1] + z^2n chunk[2] + ... - // as ( chunk[-1] * z^n + chunk[-2] ) * z^n + chunk[-3] - // (https://en.wikipedia.org/wiki/Horner%27s_method) - for chunk in self.elems.iter().rev() { - res *= zeta_n; - res.add_assign(chunk); - } - - PolyComm { - elems: vec![res.into_affine()], - } - } -} - -impl PolyComm -where - F: Field, -{ - /// Multiplies each blinding chunk of f with powers of zeta^n - // TODO(mimoo): better name for this function - pub fn chunk_blinding(&self, zeta_n: F) -> F { - let mut res = F::zero(); - // use Horner's to compute chunk[0] + z^n chunk[1] + z^2n chunk[2] + ... 
- // as ( chunk[-1] * z^n + chunk[-2] ) * z^n + chunk[-3] - // (https://en.wikipedia.org/wiki/Horner%27s_method) - for chunk in self.elems.iter().rev() { - res *= zeta_n; - res += chunk - } - res - } -} diff --git a/poly-commitment/src/commitment.rs b/poly-commitment/src/commitment.rs index ee2b96bf5f..f756e5b110 100644 --- a/poly-commitment/src/commitment.rs +++ b/poly-commitment/src/commitment.rs @@ -1,46 +1,99 @@ //! This module implements Dlog-based polynomial commitment schema. -//! The folowing functionality is implemented +//! The following functionality is implemented //! //! 1. Commit to polynomial with its max degree -//! 2. Open polynomial commitment batch at the given evaluation point and scaling factor scalar -//! producing the batched opening proof +//! 2. Open polynomial commitment batch at the given evaluation point and +//! scaling factor scalar producing the batched opening proof //! 3. Verify batch of batched opening proofs -use crate::{ - error::CommitmentError, - srs::{endos, SRS}, - SRS as SRSTrait, -}; use ark_ec::{ models::short_weierstrass::Affine as SWJAffine, short_weierstrass::SWCurveConfig, AffineRepr, CurveGroup, VariableBaseMSM, }; -use ark_ff::{BigInteger, Field, One, PrimeField, UniformRand, Zero}; -use ark_poly::{ - univariate::DensePolynomial, EvaluationDomain, Evaluations, Radix2EvaluationDomain as D, -}; +use ark_ff::{BigInteger, Field, One, PrimeField, Zero}; +use ark_poly::univariate::DensePolynomial; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use core::ops::{Add, AddAssign, Sub}; use groupmap::{BWParameters, GroupMap}; use mina_poseidon::{sponge::ScalarChallenge, FqSponge}; -use o1_utils::{math, ExtendedDensePolynomial as _}; -use rand_core::{CryptoRng, RngCore}; -use rayon::prelude::*; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; -use std::iter::Iterator; - -use super::evaluation_proof::*; +use o1_utils::{field_helpers::product, ExtendedDensePolynomial as _}; +use serde::{de::Visitor, Deserialize, Serialize}; +use serde_with::{ + de::DeserializeAsWrap, ser::SerializeAsWrap, serde_as, DeserializeAs, SerializeAs, +}; +use std::{ + iter::Iterator, + marker::PhantomData, + ops::{Add, AddAssign, Sub}, +}; -/// A polynomial commitment. +/// Represent a polynomial commitment when the type is instantiated with a +/// curve. +/// +/// The structure also handles chunking, i.e. when we aim to handle polynomials +/// whose degree is higher than the SRS size. For this reason, we do use a +/// vector for the field `chunks`. +/// +/// Note that the parameter `C` is not constrained to be a curve, therefore in +/// some places in the code, `C` can refer to a scalar field element. For +/// instance, `PolyComm` is used to represent the evaluation of the +/// polynomial bound by a specific commitment, at a particular evaluation point. #[serde_as] #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] #[serde(bound = "C: CanonicalDeserialize + CanonicalSerialize")] pub struct PolyComm { #[serde_as(as = "Vec")] - pub elems: Vec, + pub chunks: Vec, } +impl PolyComm +where + C: CommitmentCurve, +{ + /// Multiplies each commitment chunk of f with powers of zeta^n + pub fn chunk_commitment(&self, zeta_n: C::ScalarField) -> Self { + let mut res = C::Group::zero(); + // use Horner's to compute chunk[0] + z^n chunk[1] + z^2n chunk[2] + ... 
+ // as ( chunk[-1] * z^n + chunk[-2] ) * z^n + chunk[-3] + // (https://en.wikipedia.org/wiki/Horner%27s_method) + for chunk in self.chunks.iter().rev() { + res *= zeta_n; + res.add_assign(chunk); + } + + PolyComm { + chunks: vec![res.into_affine()], + } + } +} + +impl PolyComm +where + F: Field, +{ + /// Multiplies each blinding chunk of f with powers of zeta^n + pub fn chunk_blinding(&self, zeta_n: F) -> F { + let mut res = F::zero(); + // use Horner's to compute chunk[0] + z^n chunk[1] + z^2n chunk[2] + ... + // as ( chunk[-1] * z^n + chunk[-2] ) * z^n + chunk[-3] + // (https://en.wikipedia.org/wiki/Horner%27s_method) + for chunk in self.chunks.iter().rev() { + res *= zeta_n; + res += chunk + } + res + } +} + +impl<'a, G> IntoIterator for &'a PolyComm { + type Item = &'a G; + type IntoIter = std::slice::Iter<'a, G>; + + fn into_iter(self) -> Self::IntoIter { + self.chunks.iter() + } +} + +/// A commitment to a polynomial with some blinding factors. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct BlindedCommitment where @@ -51,51 +104,118 @@ where } impl PolyComm { - pub fn new(elems: Vec) -> Self { - Self { elems } + pub fn new(chunks: Vec) -> Self { + Self { chunks } } } -impl PolyComm +impl SerializeAs> for PolyComm where - A: CanonicalDeserialize + CanonicalSerialize, + U: SerializeAs, { + fn serialize_as(source: &PolyComm, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.collect_seq( + source + .chunks + .iter() + .map(|e| SerializeAsWrap::::new(e)), + ) + } +} + +impl<'de, T, U> DeserializeAs<'de, PolyComm> for PolyComm +where + U: DeserializeAs<'de, T>, +{ + fn deserialize_as(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + struct SeqVisitor { + marker: PhantomData<(T, U)>, + } + + impl<'de, T, U> Visitor<'de> for SeqVisitor + where + U: DeserializeAs<'de, T>, + { + type Value = PolyComm; + + fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + formatter.write_str("a sequence") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + #[allow(clippy::redundant_closure_call)] + let mut chunks = vec![]; + + while let Some(value) = seq + .next_element()? + .map(|v: DeserializeAsWrap| v.into_inner()) + { + chunks.push(value); + } + + Ok(PolyComm::new(chunks)) + } + } + + let visitor = SeqVisitor:: { + marker: PhantomData, + }; + deserializer.deserialize_seq(visitor) + } +} + +impl PolyComm { pub fn map(&self, mut f: F) -> PolyComm where F: FnMut(A) -> B, B: CanonicalDeserialize + CanonicalSerialize, { - let elems = self.elems.iter().map(|x| f(x.clone())).collect(); - PolyComm { elems } + let chunks = self.chunks.iter().map(|x| f(*x)).collect(); + PolyComm::new(chunks) } - /// Returns the length of the commitment. + /// Returns the number of chunks. pub fn len(&self) -> usize { - self.elems.len() + self.chunks.len() } /// Returns `true` if the commitment is empty. 
pub fn is_empty(&self) -> bool { - self.elems.is_empty() + self.chunks.is_empty() } -} -impl PolyComm { - // TODO: if all callers end up calling unwrap, just call this zip_eq and panic here (and document the panic) + // TODO: if all callers end up calling unwrap, just call this zip_eq and + // panic here (and document the panic) pub fn zip( &self, other: &PolyComm, ) -> Option> { - if self.elems.len() != other.elems.len() { + if self.chunks.len() != other.chunks.len() { return None; } - let elems = self - .elems + let chunks = self + .chunks .iter() - .zip(other.elems.iter()) + .zip(other.chunks.iter()) .map(|(x, y)| (*x, *y)) .collect(); - Some(PolyComm { elems }) + Some(PolyComm::new(chunks)) + } + + /// Return only the first chunk + /// Getting this single value is relatively common in the codebase, even + /// though we should not do this, and abstract the chunks in the structure. + pub fn get_first_chunk(&self) -> A { + self.chunks[0] } } @@ -113,9 +233,10 @@ impl PolyComm { /// |g: G, x: G::ScalarField| g.scale(2*x + 2^n) /// ``` /// -/// otherwise. So, if we want to actually scale by `s`, we need to apply the inverse function -/// of `|x| x + 2^n` (or of `|x| 2*x + 2^n` in the other case), before supplying the scalar -/// to our in-circuit scalar-multiplication function. This computes that inverse function. +/// otherwise. So, if we want to actually scale by `x`, we need to apply the +/// inverse function of `|x| x + 2^n` (or of `|x| 2*x + 2^n` in the other case), +/// before supplying the scalar to our in-circuit scalar-multiplication +/// function. This computes that inverse function. /// Namely, /// /// ```ignore @@ -150,20 +271,20 @@ impl<'a, 'b, C: AffineRepr> Add<&'a PolyComm> for &'b PolyComm { type Output = PolyComm; fn add(self, other: &'a PolyComm) -> PolyComm { - let mut elems = vec![]; - let n1 = self.elems.len(); - let n2 = other.elems.len(); + let mut chunks = vec![]; + let n1 = self.chunks.len(); + let n2 = other.chunks.len(); for i in 0..std::cmp::max(n1, n2) { let pt = if i < n1 && i < n2 { - (self.elems[i] + other.elems[i]).into_affine() + (self.chunks[i] + other.chunks[i]).into_affine() } else if i < n1 { - self.elems[i] + self.chunks[i] } else { - other.elems[i] + other.chunks[i] }; - elems.push(pt); + chunks.push(pt); } - PolyComm { elems } + PolyComm::new(chunks) } } @@ -171,27 +292,27 @@ impl<'a, 'b, C: AffineRepr + Sub> Sub<&'a PolyComm> for &' type Output = PolyComm; fn sub(self, other: &'a PolyComm) -> PolyComm { - let mut elems = vec![]; - let n1 = self.elems.len(); - let n2 = other.elems.len(); + let mut chunks = vec![]; + let n1 = self.chunks.len(); + let n2 = other.chunks.len(); for i in 0..std::cmp::max(n1, n2) { let pt = if i < n1 && i < n2 { - (self.elems[i] - other.elems[i]).into_affine() + (self.chunks[i] - other.chunks[i]).into_affine() } else if i < n1 { - self.elems[i] + self.chunks[i] } else { - other.elems[i] + other.chunks[i] }; - elems.push(pt); + chunks.push(pt); } - PolyComm { elems } + PolyComm::new(chunks) } } impl PolyComm { pub fn scale(&self, c: C::ScalarField) -> PolyComm { PolyComm { - elems: self.elems.iter().map(|g| g.mul(c).into_affine()).collect(), + chunks: self.chunks.iter().map(|g| g.mul(c).into_affine()).collect(), } } @@ -210,33 +331,24 @@ impl PolyComm { let all_scalars: Vec<_> = elm.iter().map(|s| s.into_bigint()).collect(); - let elems_size = Iterator::max(com.iter().map(|c| c.elems.len())).unwrap(); - let mut elems = Vec::with_capacity(elems_size); + let elems_size = Iterator::max(com.iter().map(|c| 
c.chunks.len())).unwrap(); + let mut chunks = Vec::with_capacity(elems_size); for chunk in 0..elems_size { let (points, scalars): (Vec<_>, Vec<_>) = com .iter() .zip(&all_scalars) // get rid of scalars that don't have an associated chunk - .filter_map(|(com, scalar)| com.elems.get(chunk).map(|c| (c, scalar))) + .filter_map(|(com, scalar)| com.chunks.get(chunk).map(|c| (c, scalar))) .unzip(); let chunk_msm = C::Group::msm_bigint(&points, &scalars); - elems.push(chunk_msm.into_affine()); + chunks.push(chunk_msm.into_affine()); } - Self::new(elems) + Self::new(chunks) } } -/// Returns the product of all the field elements belonging to an iterator. -pub fn product(xs: impl Iterator) -> F { - let mut res = F::one(); - for x in xs { - res *= &x; - } - res -} - /// Returns (1 + chal[-1] x)(1 + chal[-2] x^2)(1 + chal[-3] x^4) ... /// It's "step 8: Define the univariate polynomial" of /// appendix A.2 of @@ -266,17 +378,6 @@ pub fn b_poly_coefficients(chals: &[F]) -> Vec { s } -/// `pows(d, x)` returns a vector containing the first `d` powers of the field element `x` (from `1` to `x^(d-1)`). -pub fn pows(d: usize, x: F) -> Vec { - let mut acc = F::one(); - let mut res = vec![]; - for _ in 1..=d { - res.push(acc); - acc *= x; - } - res -} - pub fn squeeze_prechallenge>( sponge: &mut EFqSponge, ) -> ScalarChallenge { @@ -294,7 +395,7 @@ pub fn absorb_commitment, ) { - sponge.absorb_g(&commitment.elems); + sponge.absorb_g(&commitment.chunks); } /// A useful trait extending AffineRepr for commitments. @@ -309,7 +410,7 @@ pub trait CommitmentCurve: AffineRepr + Sub { } /// A trait extending CommitmentCurve for endomorphisms. -/// Unfortunately, we can't specify that `AffineCurve`, +/// Unfortunately, we can't specify that `AffineRepr`, /// so usage of this traits must manually bind `G::BaseField: PrimeField`. pub trait EndoCurve: CommitmentCurve { /// Combine where x1 = one @@ -380,14 +481,26 @@ impl EndoCurve for SWJAffine

{ } } -pub fn to_group(m: &G::Map, t: ::BaseField) -> G { - let (x, y) = m.to_group(t); - G::of_coordinates(x, y) -} - -/// Computes the linearization of the evaluations of a (potentially split) polynomial. -/// Each given `poly` is associated to a matrix where the rows represent the number of evaluated points, -/// and the columns represent potential segments (if a polynomial was split in several parts). +/// Computes the linearization of the evaluations of a (potentially +/// split) polynomial. +/// +/// Each polynomial in `polys` is represented by a matrix where the +/// rows correspond to evaluated points, and the columns represent +/// potential segments (if a polynomial was split in several parts). +/// +/// Elements in `evaluation_points` are several discrete points on which +/// we evaluate polynomials, e.g. `[zeta,zeta*w]`. See `PointEvaluations`. +/// +/// Note that if one of the polynomial comes specified with a degree +/// bound, the evaluation for the last segment is potentially shifted +/// to meet the proof. +/// +/// Returns +/// ```text +/// |polys| |segments[k]| +/// Σ Σ polyscale^{k*n+i} (Σ polys[k][j][i] * evalscale^j) +/// k = 1 i = 1 j +/// ``` #[allow(clippy::type_complexity)] pub fn combined_inner_product( polyscale: &F, @@ -395,21 +508,31 @@ pub fn combined_inner_product( // TODO(mimoo): needs a type that can get you evaluations or segments polys: &[Vec>], ) -> F { + // final combined evaluation result let mut res = F::zero(); - let mut xi_i = F::one(); + // polyscale^i + let mut polyscale_i = F::one(); for evals_tr in polys.iter().filter(|evals_tr| !evals_tr[0].is_empty()) { - // transpose the evaluations - let evals = (0..evals_tr[0].len()) + // Transpose the evaluations. + // evals[i] = {evals_tr[j][i]}_j now corresponds to a column in evals_tr, + // representing a segment. + let evals: Vec<_> = (0..evals_tr[0].len()) .map(|i| evals_tr.iter().map(|v| v[i]).collect::>()) - .collect::>(); + .collect(); - // iterating over the polynomial segments + // Iterating over the polynomial segments. + // Each segment gets its own polyscale^i, each segment element j is multiplied by evalscale^j. + // Given that polyscale_i = polyscale^i0 at this point, after this loop we have: + // + // res += Σ polyscale^{i0+i} ( Σ evals_tr[j][i] * evalscale^j ) + // i j + // for eval in &evals { + // p_i(evalscale) let term = DensePolynomial::::eval_polynomial(eval, *evalscale); - - res += &(xi_i * term); - xi_i *= polyscale; + res += &(polyscale_i * term); + polyscale_i *= polyscale; } } res @@ -420,33 +543,62 @@ pub struct Evaluation where G: AffineRepr, { - /// The commitment of the polynomial being evaluated + /// The commitment of the polynomial being evaluated. + /// Note that PolyComm contains a vector of commitments, which handles the + /// case when chunking is used, i.e. when the polynomial degree is higher + /// than the SRS size. pub commitment: PolyComm, - /// Contains an evaluation table + /// Contains an evaluation table. For instance, for vanilla PlonK, it + /// would be a vector of (chunked) evaluations at ζ and ζω. + /// The outer vector would be the evaluations at the different points (e.g. + /// ζ and ζω for vanilla PlonK) and the inner vector would be the chunks of + /// the polynomial. 
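+    /// For example, for a polynomial split in two chunks `p = p0 + x^n * p1`
+    /// and evaluated at ζ and ζω, this would be
+    /// `vec![vec![p0(ζ), p1(ζ)], vec![p0(ζω), p1(ζω)]]`.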
pub evaluations: Vec>, } /// Contains the batch evaluation -// TODO: I think we should really change this name to something more correct pub struct BatchEvaluationProof<'a, G, EFqSponge, OpeningProof> where G: AffineRepr, EFqSponge: FqSponge, { + /// Sponge used to coin and absorb values and simulate + /// non-interactivity using the Fiat-Shamir transformation. pub sponge: EFqSponge, + /// A list of evaluations, each supposed to correspond to a different + /// polynomial. pub evaluations: Vec>, - /// vector of evaluation points + /// The actual evaluation points. Each field `evaluations` of each structure + /// of `Evaluation` should have the same (outer) length. pub evaluation_points: Vec, - /// scaling factor for evaluation point powers + /// A challenge to combine polynomials. Powers of this point will be used, + /// hence the name. pub polyscale: G::ScalarField, - /// scaling factor for polynomials + /// A challenge to aggregate multiple evaluation points. pub evalscale: G::ScalarField, - /// batched opening proof + /// The opening proof. pub opening: &'a OpeningProof, pub combined_inner_product: G::ScalarField, } +/// This function populates the parameters `scalars` and `points`. +/// It iterates over the evaluations and adds each commitment to the +/// vector `points`. +/// The parameter `scalars` is populated with the values: +/// `rand_base * polyscale^i` for each commitment. +/// For instance, if we have 3 commitments, the `scalars` vector will +/// contain the values +/// ```text +/// [rand_base, rand_base * polyscale, rand_base * polyscale^2]` +/// ``` +/// and the vector `points` will contain the commitments. +/// +/// Note that the function skips the commitments that are empty. +/// +/// If more than one commitment is present in a single evaluation (i.e. if +/// `elems` is larger than one), it means that probably chunking was used (i.e. +/// it is a commitment to a polynomial larger than the SRS). pub fn combine_commitments( evaluations: &[Evaluation], scalars: &mut Vec, @@ -454,571 +606,26 @@ pub fn combine_commitments( polyscale: G::ScalarField, rand_base: G::ScalarField, ) { - let mut xi_i = G::ScalarField::one(); + // will contain the power of polyscale + let mut polyscale_i = G::ScalarField::one(); - for Evaluation { commitment, .. } in evaluations - .iter() - .filter(|x| !x.commitment.elems.is_empty()) - { + for Evaluation { commitment, .. } in evaluations.iter().filter(|x| !x.commitment.is_empty()) { // iterating over the polynomial segments - for comm_ch in &commitment.elems { - scalars.push(rand_base * xi_i); + for comm_ch in &commitment.chunks { + scalars.push(rand_base * polyscale_i); points.push(*comm_ch); - xi_i *= polyscale; - } - } -} - -pub fn combine_evaluations( - evaluations: &Vec>, - polyscale: G::ScalarField, -) -> Vec { - let mut xi_i = G::ScalarField::one(); - let mut acc = { - let num_evals = if !evaluations.is_empty() { - evaluations[0].evaluations.len() - } else { - 0 - }; - vec![G::ScalarField::zero(); num_evals] - }; - - for Evaluation { evaluations, .. 
} in evaluations - .iter() - .filter(|x| !x.commitment.elems.is_empty()) - { - // iterating over the polynomial segments - for j in 0..evaluations[0].len() { - for i in 0..evaluations.len() { - acc[i] += evaluations[i][j] * xi_i; - } - xi_i *= polyscale; + // compute next power of polyscale + polyscale_i *= polyscale; } } - - acc } -impl SRSTrait for SRS { - /// The maximum polynomial degree that can be committed to - fn max_poly_size(&self) -> usize { - self.g.len() - } - - fn get_lagrange_basis(&self, domain_size: usize) -> &Vec> { - self.get_lagrange_basis_from_domain_size(domain_size) - } - - fn blinding_commitment(&self) -> G { - self.h - } - - /// Commits a polynomial, potentially splitting the result in multiple commitments. - fn commit( - &self, - plnm: &DensePolynomial, - num_chunks: usize, - rng: &mut (impl RngCore + CryptoRng), - ) -> BlindedCommitment { - self.mask(self.commit_non_hiding(plnm, num_chunks), rng) - } - - /// Turns a non-hiding polynomial commitment into a hidding polynomial commitment. Transforms each given `` into `( + wH, w)` with a random `w` per commitment. - fn mask( - &self, - comm: PolyComm, - rng: &mut (impl RngCore + CryptoRng), - ) -> BlindedCommitment { - let blinders = comm.map(|_| G::ScalarField::rand(rng)); - self.mask_custom(comm, &blinders).unwrap() - } - - /// Same as [SRS::mask] except that you can pass the blinders manually. - fn mask_custom( - &self, - com: PolyComm, - blinders: &PolyComm, - ) -> Result, CommitmentError> { - let commitment = com - .zip(blinders) - .ok_or_else(|| CommitmentError::BlindersDontMatch(blinders.len(), com.len()))? - .map(|(g, b)| { - let mut g_masked = self.h.mul(b); - g_masked.add_assign(&g); - g_masked.into_affine() - }); - Ok(BlindedCommitment { - commitment, - blinders: blinders.clone(), - }) - } - - /// This function commits a polynomial using the SRS' basis of size `n`. - /// - `plnm`: polynomial to commit to with max size of sections - /// - `num_chunks`: the number of commitments to be included in the output polynomial commitment - /// The function returns an unbounded commitment vector - /// (which splits the commitment into several commitments of size at most `n`). 
- fn commit_non_hiding( - &self, - plnm: &DensePolynomial, - num_chunks: usize, - ) -> PolyComm { - let is_zero = plnm.is_zero(); - - let coeffs: Vec<_> = plnm.iter().map(|c| c.into_bigint()).collect(); - - // chunk while commiting - let mut elems = vec![]; - if is_zero { - elems.push(G::zero()); - } else { - coeffs.chunks(self.g.len()).for_each(|coeffs_chunk| { - let chunk = G::Group::msm_bigint(&self.g, coeffs_chunk); - elems.push(chunk.into_affine()); - }); - } - - for _ in elems.len()..num_chunks { - elems.push(G::zero()); - } - - PolyComm:: { elems } - } - - fn commit_evaluations_non_hiding( - &self, - domain: D, - plnm: &Evaluations>, - ) -> PolyComm { - let basis = self.get_lagrange_basis(domain); - let commit_evaluations = |evals: &Vec, basis: &Vec>| { - PolyComm::::multi_scalar_mul(&basis.iter().collect::>()[..], &evals[..]) - }; - match domain.size.cmp(&plnm.domain().size) { - std::cmp::Ordering::Less => { - let s = (plnm.domain().size / domain.size) as usize; - let v: Vec<_> = (0..(domain.size())).map(|i| plnm.evals[s * i]).collect(); - commit_evaluations(&v, basis) - } - std::cmp::Ordering::Equal => commit_evaluations(&plnm.evals, basis), - std::cmp::Ordering::Greater => { - panic!("desired commitment domain size greater than evaluations' domain size") - } - } - } - - fn commit_evaluations( - &self, - domain: D, - plnm: &Evaluations>, - rng: &mut (impl RngCore + CryptoRng), - ) -> BlindedCommitment { - self.mask(self.commit_evaluations_non_hiding(domain, plnm), rng) - } -} - -impl SRS { - /// This function verifies batch of batched polynomial commitment opening proofs - /// batch: batch of batched polynomial commitment opening proofs - /// vector of evaluation points - /// polynomial scaling factor for this batched openinig proof - /// eval scaling factor for this batched openinig proof - /// batch/vector of polycommitments (opened in this batch), evaluation vectors and, optionally, max degrees - /// opening proof for this batched opening - /// oracle_params: parameters for the random oracle argument - /// randomness source context - /// RETURN: verification status - pub fn verify( - &self, - group_map: &G::Map, - batch: &mut [BatchEvaluationProof>], - rng: &mut RNG, - ) -> bool - where - EFqSponge: FqSponge, - RNG: RngCore + CryptoRng, - G::BaseField: PrimeField, - { - // Verifier checks for all i, - // c_i Q_i + delta_i = z1_i (G_i + b_i U_i) + z2_i H - // - // if we sample evalscale at random, it suffices to check - // - // 0 == sum_i evalscale^i (c_i Q_i + delta_i - ( z1_i (G_i + b_i U_i) + z2_i H )) - // - // and because each G_i is a multiexp on the same array self.g, we - // can batch the multiexp across proofs. 
- // - // So for each proof in the batch, we add onto our big multiexp the following terms - // evalscale^i c_i Q_i - // evalscale^i delta_i - // - (evalscale^i z1_i) G_i - // - (evalscale^i z2_i) H - // - (evalscale^i z1_i b_i) U_i - - // We also check that the sg component of the proof is equal to the polynomial commitment - // to the "s" array - - let nonzero_length = self.g.len(); - - let max_rounds = math::ceil_log2(nonzero_length); - - let padded_length = 1 << max_rounds; - - let (_, endo_r) = endos::(); - - // TODO: This will need adjusting - let padding = padded_length - nonzero_length; - let mut points = vec![self.h]; - points.extend(self.g.clone()); - points.extend(vec![G::zero(); padding]); - - let mut scalars = vec![G::ScalarField::zero(); padded_length + 1]; - assert_eq!(scalars.len(), points.len()); - - // sample randomiser to scale the proofs with - let rand_base = G::ScalarField::rand(rng); - let sg_rand_base = G::ScalarField::rand(rng); - - let mut rand_base_i = G::ScalarField::one(); - let mut sg_rand_base_i = G::ScalarField::one(); - - for BatchEvaluationProof { - sponge, - evaluation_points, - polyscale, - evalscale, - evaluations, - opening, - combined_inner_product, - } in batch.iter_mut() - { - sponge.absorb_fr(&[shift_scalar::(*combined_inner_product)]); - - let t = sponge.challenge_fq(); - let u: G = to_group(group_map, t); - - let Challenges { chal, chal_inv } = opening.challenges::(&endo_r, sponge); - - sponge.absorb_g(&[opening.delta]); - let c = ScalarChallenge(sponge.challenge()).to_field(&endo_r); - - // < s, sum_i evalscale^i pows(evaluation_point[i]) > - // == - // sum_i evalscale^i < s, pows(evaluation_point[i]) > - let b0 = { - let mut scale = G::ScalarField::one(); - let mut res = G::ScalarField::zero(); - for &e in evaluation_points.iter() { - let term = b_poly(&chal, e); - res += &(scale * term); - scale *= *evalscale; - } - res - }; - - let s = b_poly_coefficients(&chal); - - let neg_rand_base_i = -rand_base_i; - - // TERM - // - rand_base_i z1 G - // - // we also add -sg_rand_base_i * G to check correctness of sg. - points.push(opening.sg); - scalars.push(neg_rand_base_i * opening.z1 - sg_rand_base_i); - - // Here we add - // sg_rand_base_i * ( < s, self.g > ) - // = - // < sg_rand_base_i s, self.g > - // - // to check correctness of the sg component. 
- { - let terms: Vec<_> = s.par_iter().map(|s| sg_rand_base_i * s).collect(); - - for (i, term) in terms.iter().enumerate() { - scalars[i + 1] += term; - } - } - - // TERM - // - rand_base_i * z2 * H - scalars[0] -= &(rand_base_i * opening.z2); - - // TERM - // -rand_base_i * (z1 * b0 * U) - scalars.push(neg_rand_base_i * (opening.z1 * b0)); - points.push(u); - - // TERM - // rand_base_i c_i Q_i - // = rand_base_i c_i - // (sum_j (chal_invs[j] L_j + chals[j] R_j) + P_prime) - // where P_prime = combined commitment + combined_inner_product * U - let rand_base_i_c_i = c * rand_base_i; - for ((l, r), (u_inv, u)) in opening.lr.iter().zip(chal_inv.iter().zip(chal.iter())) { - points.push(*l); - scalars.push(rand_base_i_c_i * u_inv); - - points.push(*r); - scalars.push(rand_base_i_c_i * u); - } - - // TERM - // sum_j evalscale^j (sum_i polyscale^i f_i) (elm_j) - // == sum_j sum_i evalscale^j polyscale^i f_i(elm_j) - // == sum_i polyscale^i sum_j evalscale^j f_i(elm_j) - combine_commitments( - evaluations, - &mut scalars, - &mut points, - *polyscale, - rand_base_i_c_i, - ); - - scalars.push(rand_base_i_c_i * *combined_inner_product); - points.push(u); - - scalars.push(rand_base_i); - points.push(opening.delta); - - rand_base_i *= &rand_base; - sg_rand_base_i *= &sg_rand_base; - } - - // verify the equation - let scalars: Vec<_> = scalars.iter().map(|x| x.into_bigint()).collect(); - G::Group::msm_bigint(&points, &scalars) == G::Group::zero() - } -} - -pub fn inner_prod(xs: &[F], ys: &[F]) -> F { - let mut res = F::zero(); - for (&x, y) in xs.iter().zip(ys) { - res += &(x * y); - } - res -} - -// -// Tests -// - -#[cfg(test)] -mod tests { - use super::*; - - use crate::srs::SRS; - use ark_poly::{DenseUVPolynomial, Polynomial, Radix2EvaluationDomain}; - use mina_curves::pasta::{Fp, Vesta as VestaG}; - use mina_poseidon::{constants::PlonkSpongeConstantsKimchi as SC, sponge::DefaultFqSponge}; - use std::array; - - #[test] - fn test_lagrange_commitments() { - let n = 64; - let domain = D::::new(n).unwrap(); - - let srs = SRS::::create(n); - srs.get_lagrange_basis(domain); - - let num_chunks = domain.size() / srs.g.len(); - - let expected_lagrange_commitments: Vec<_> = (0..n) - .map(|i| { - let mut e = vec![Fp::zero(); n]; - e[i] = Fp::one(); - let p = Evaluations::>::from_vec_and_domain(e, domain).interpolate(); - srs.commit_non_hiding(&p, num_chunks) - }) - .collect(); - - let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size()); - for i in 0..n { - assert_eq!( - computed_lagrange_commitments[i], - expected_lagrange_commitments[i], - ); - } - } - - #[test] - // This tests with two chunks. - fn test_chunked_lagrange_commitments() { - let n = 64; - let divisor = 4; - let domain = D::::new(n).unwrap(); - - let srs = SRS::::create(n / divisor); - srs.get_lagrange_basis(domain); - - let num_chunks = domain.size() / srs.g.len(); - assert!(num_chunks == divisor); - - let expected_lagrange_commitments: Vec<_> = (0..n) - .map(|i| { - let mut e = vec![Fp::zero(); n]; - e[i] = Fp::one(); - let p = Evaluations::>::from_vec_and_domain(e, domain).interpolate(); - srs.commit_non_hiding(&p, num_chunks) - }) - .collect(); - - let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size()); - for i in 0..n { - assert_eq!( - computed_lagrange_commitments[i], - expected_lagrange_commitments[i], - ); - } - } - - #[test] - // TODO @volhovm I don't understand what this test does and - // whether it is worth leaving. 
- /// Same as test_chunked_lagrange_commitments, but with a slight - /// offset in the SRS - fn test_offset_chunked_lagrange_commitments() { - let n = 64; - let domain = D::::new(n).unwrap(); - - let srs = SRS::::create(n / 2 + 1); - srs.get_lagrange_basis(domain); - - // Is this even taken into account?... - let num_chunks = (domain.size() + srs.g.len() - 1) / srs.g.len(); - assert!(num_chunks == 2); - - let expected_lagrange_commitments: Vec<_> = (0..n) - .map(|i| { - let mut e = vec![Fp::zero(); n]; - e[i] = Fp::one(); - let p = Evaluations::>::from_vec_and_domain(e, domain).interpolate(); - srs.commit_non_hiding(&p, num_chunks) // this requires max = Some(64) - }) - .collect(); - - let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size()); - for i in 0..n { - assert_eq!( - computed_lagrange_commitments[i], - expected_lagrange_commitments[i], - ); - } - } - - #[test] - fn test_opening_proof() { - // create two polynomials - let coeffs: [Fp; 10] = array::from_fn(|i| Fp::from(i as u32)); - let poly1 = DensePolynomial::::from_coefficients_slice(&coeffs); - let poly2 = DensePolynomial::::from_coefficients_slice(&coeffs[..5]); - - // create an SRS - let srs = SRS::::create(20); - let rng = &mut o1_utils::tests::make_test_rng(None); - - // commit the two polynomials - let commitment1 = srs.commit(&poly1, 1, rng); - let commitment2 = srs.commit(&poly2, 1, rng); - - // create an aggregated opening proof - let (u, v) = (Fp::rand(rng), Fp::rand(rng)); - let group_map = ::Map::setup(); - let sponge = - DefaultFqSponge::<_, SC>::new(mina_poseidon::pasta::fq_kimchi::static_params()); - - let polys: Vec<( - DensePolynomialOrEvaluations<_, Radix2EvaluationDomain<_>>, - PolyComm<_>, - )> = vec![ - ( - DensePolynomialOrEvaluations::DensePolynomial(&poly1), - commitment1.blinders, - ), - ( - DensePolynomialOrEvaluations::DensePolynomial(&poly2), - commitment2.blinders, - ), - ]; - let elm = vec![Fp::rand(rng), Fp::rand(rng)]; - - let opening_proof = srs.open(&group_map, &polys, &elm, v, u, sponge.clone(), rng); - - // evaluate the polynomials at these two points - let poly1_chunked_evals = vec![ - poly1 - .to_chunked_polynomial(1, srs.g.len()) - .evaluate_chunks(elm[0]), - poly1 - .to_chunked_polynomial(1, srs.g.len()) - .evaluate_chunks(elm[1]), - ]; - - fn sum(c: &[Fp]) -> Fp { - c.iter().fold(Fp::zero(), |a, &b| a + b) - } - - assert_eq!(sum(&poly1_chunked_evals[0]), poly1.evaluate(&elm[0])); - assert_eq!(sum(&poly1_chunked_evals[1]), poly1.evaluate(&elm[1])); - - let poly2_chunked_evals = vec![ - poly2 - .to_chunked_polynomial(1, srs.g.len()) - .evaluate_chunks(elm[0]), - poly2 - .to_chunked_polynomial(1, srs.g.len()) - .evaluate_chunks(elm[1]), - ]; - - assert_eq!(sum(&poly2_chunked_evals[0]), poly2.evaluate(&elm[0])); - assert_eq!(sum(&poly2_chunked_evals[1]), poly2.evaluate(&elm[1])); - - let evaluations = vec![ - Evaluation { - commitment: commitment1.commitment, - evaluations: poly1_chunked_evals, - }, - Evaluation { - commitment: commitment2.commitment, - evaluations: poly2_chunked_evals, - }, - ]; - - let combined_inner_product = { - let es: Vec<_> = evaluations - .iter() - .map(|Evaluation { evaluations, .. 
}| evaluations.clone()) - .collect(); - combined_inner_product(&v, &u, &es) - }; - - // verify the proof - let mut batch = vec![BatchEvaluationProof { - sponge, - evaluation_points: elm.clone(), - polyscale: v, - evalscale: u, - evaluations, - opening: &opening_proof, - combined_inner_product, - }]; - - assert!(srs.verify(&group_map, &mut batch, rng)); - } -} - -// -// OCaml types -// - #[cfg(feature = "ocaml_types")] pub mod caml { - use super::*; - // polynomial commitment + use super::PolyComm; + use ark_ec::AffineRepr; #[derive(Clone, Debug, ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)] pub struct CamlPolyComm { @@ -1035,7 +642,7 @@ pub mod caml { { fn from(polycomm: PolyComm) -> Self { Self { - unshifted: polycomm.elems.into_iter().map(CamlG::from).collect(), + unshifted: polycomm.chunks.into_iter().map(CamlG::from).collect(), shifted: None, } } @@ -1048,7 +655,7 @@ pub mod caml { { fn from(polycomm: &'a PolyComm) -> Self { Self { - unshifted: polycomm.elems.iter().map(Into::::into).collect(), + unshifted: polycomm.chunks.iter().map(Into::::into).collect(), shifted: None, } } @@ -1064,7 +671,7 @@ pub mod caml { "mina#14628: Shifted commitments are deprecated and must not be used" ); PolyComm { - elems: camlpolycomm + chunks: camlpolycomm .unshifted .into_iter() .map(Into::::into) @@ -1084,61 +691,7 @@ pub mod caml { ); PolyComm { //FIXME something with as_ref() - elems: camlpolycomm.unshifted.iter().map(Into::into).collect(), - } - } - } - - // opening proof - - #[derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)] - pub struct CamlOpeningProof { - /// vector of rounds of L & R commitments - pub lr: Vec<(G, G)>, - pub delta: G, - pub z1: F, - pub z2: F, - pub sg: G, - } - - impl From> for CamlOpeningProof - where - G: AffineRepr, - CamlG: From, - CamlF: From, - { - fn from(opening_proof: OpeningProof) -> Self { - Self { - lr: opening_proof - .lr - .into_iter() - .map(|(g1, g2)| (CamlG::from(g1), CamlG::from(g2))) - .collect(), - delta: CamlG::from(opening_proof.delta), - z1: opening_proof.z1.into(), - z2: opening_proof.z2.into(), - sg: CamlG::from(opening_proof.sg), - } - } - } - - impl From> for OpeningProof - where - G: AffineRepr, - CamlG: Into, - CamlF: Into, - { - fn from(caml: CamlOpeningProof) -> Self { - Self { - lr: caml - .lr - .into_iter() - .map(|(g1, g2)| (g1.into(), g2.into())) - .collect(), - delta: caml.delta.into(), - z1: caml.z1.into(), - z2: caml.z2.into(), - sg: caml.sg.into(), + chunks: camlpolycomm.unshifted.iter().map(Into::into).collect(), } } } diff --git a/poly-commitment/src/evaluation_proof.rs b/poly-commitment/src/evaluation_proof.rs deleted file mode 100644 index 71cf472aeb..0000000000 --- a/poly-commitment/src/evaluation_proof.rs +++ /dev/null @@ -1,462 +0,0 @@ -use crate::{ - commitment::*, - srs::{endos, SRS}, - PolynomialsToCombine, SRS as _, -}; -use ark_ec::{AffineRepr, CurveGroup, VariableBaseMSM}; -use ark_ff::{FftField, Field, One, PrimeField, UniformRand, Zero}; -use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Evaluations}; -use mina_poseidon::{sponge::ScalarChallenge, FqSponge}; -use o1_utils::{math, ExtendedDensePolynomial}; -use rand_core::{CryptoRng, RngCore}; -use rayon::prelude::*; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; -use std::iter::Iterator; - -// A formal sum of the form -// `s_0 * p_0 + ... s_n * p_n` -// where each `s_i` is a scalar and each `p_i` is a polynomial. 
-#[derive(Default)] -struct ScaledChunkedPolynomial(Vec<(F, P)>); - -pub enum DensePolynomialOrEvaluations<'a, F: FftField, D: EvaluationDomain> { - DensePolynomial(&'a DensePolynomial), - Evaluations(&'a Evaluations, D), -} - -impl ScaledChunkedPolynomial { - fn add_poly(&mut self, scale: F, p: P) { - self.0.push((scale, p)) - } -} - -impl<'a, F: Field> ScaledChunkedPolynomial { - fn to_dense_polynomial(&self) -> DensePolynomial { - let mut res = DensePolynomial::::zero(); - - let scaled: Vec<_> = self - .0 - .par_iter() - .map(|(scale, segment)| { - let scale = *scale; - let v = segment.par_iter().map(|x| scale * *x).collect(); - DensePolynomial::from_coefficients_vec(v) - }) - .collect(); - - for p in scaled { - res += &p; - } - - res - } -} - -/// Combine the polynomials using `polyscale`, creating a single unified polynomial to open. -pub fn combine_polys>( - plnms: PolynomialsToCombine, // vector of polynomial with optional degree bound and commitment randomness - polyscale: G::ScalarField, // scaling factor for polynoms - srs_length: usize, -) -> (DensePolynomial, G::ScalarField) { - let mut plnm = ScaledChunkedPolynomial::::default(); - let mut plnm_evals_part = { - // For now just check that all the evaluation polynomials are the same degree so that we - // can do just a single FFT. - // Furthermore we check they have size less than the SRS size so we don't have to do chunking. - // If/when we change this, we can add more complicated code to handle different degrees. - let degree = plnms - .iter() - .fold(None, |acc, (p, _)| match p { - DensePolynomialOrEvaluations::DensePolynomial(_) => acc, - DensePolynomialOrEvaluations::Evaluations(_, d) => { - if let Some(n) = acc { - assert_eq!(n, d.size()); - } - Some(d.size()) - } - }) - .unwrap_or(0); - vec![G::ScalarField::zero(); degree] - }; - - let mut omega = G::ScalarField::zero(); - let mut scale = G::ScalarField::one(); - - // iterating over polynomials in the batch - for (p_i, omegas) in plnms { - match p_i { - DensePolynomialOrEvaluations::Evaluations(evals_i, sub_domain) => { - let stride = evals_i.evals.len() / sub_domain.size(); - let evals = &evals_i.evals; - plnm_evals_part - .par_iter_mut() - .enumerate() - .for_each(|(i, x)| { - *x += scale * evals[i * stride]; - }); - for j in 0..omegas.elems.len() { - omega += &(omegas.elems[j] * scale); - scale *= &polyscale; - } - } - - DensePolynomialOrEvaluations::DensePolynomial(p_i) => { - let mut offset = 0; - // iterating over chunks of the polynomial - for j in 0..omegas.elems.len() { - let segment = &p_i.coeffs[std::cmp::min(offset, p_i.coeffs.len()) - ..std::cmp::min(offset + srs_length, p_i.coeffs.len())]; - plnm.add_poly(scale, segment); - - omega += &(omegas.elems[j] * scale); - scale *= &polyscale; - offset += srs_length; - } - } - } - } - - let mut plnm = plnm.to_dense_polynomial(); - if !plnm_evals_part.is_empty() { - let n = plnm_evals_part.len(); - let max_poly_size = srs_length; - let num_chunks = if n == 0 { - 1 - } else { - n / max_poly_size + if n % max_poly_size == 0 { 0 } else { 1 } - }; - plnm += &Evaluations::from_vec_and_domain(plnm_evals_part, D::new(n).unwrap()) - .interpolate() - .to_chunked_polynomial(num_chunks, max_poly_size) - .linearize(polyscale); - } - - (plnm, omega) -} - -impl SRS { - /// This function opens polynomial commitments in batch - /// plnms: batch of polynomials to open commitments for with, optionally, max degrees - /// elm: evaluation point vector to open the commitments at - /// polyscale: polynomial scaling factor for opening 
commitments in batch - /// evalscale: eval scaling factor for opening commitments in batch - /// oracle_params: parameters for the random oracle argument - /// RETURN: commitment opening proof - #[allow(clippy::too_many_arguments)] - #[allow(clippy::type_complexity)] - #[allow(clippy::many_single_char_names)] - pub fn open>( - &self, - group_map: &G::Map, - // TODO(mimoo): create a type for that entry - plnms: PolynomialsToCombine, // vector of polynomial with commitment randomness - elm: &[G::ScalarField], // vector of evaluation points - polyscale: G::ScalarField, // scaling factor for polynoms - evalscale: G::ScalarField, // scaling factor for evaluation point powers - mut sponge: EFqSponge, // sponge - rng: &mut RNG, - ) -> OpeningProof - where - EFqSponge: Clone + FqSponge, - RNG: RngCore + CryptoRng, - G::BaseField: PrimeField, - G: EndoCurve, - { - let (endo_q, endo_r) = endos::(); - - let rounds = math::ceil_log2(self.g.len()); - let padded_length = 1 << rounds; - - // TODO: Trim this to the degree of the largest polynomial - let padding = padded_length - self.g.len(); - let mut g = self.g.clone(); - g.extend(vec![G::zero(); padding]); - - let (p, blinding_factor) = combine_polys::(plnms, polyscale, self.g.len()); - - let rounds = math::ceil_log2(self.g.len()); - - // b_j = sum_i r^i elm_i^j - let b_init = { - // randomise/scale the eval powers - let mut scale = G::ScalarField::one(); - let mut res: Vec = - (0..padded_length).map(|_| G::ScalarField::zero()).collect(); - for e in elm { - for (i, t) in pows(padded_length, *e).iter().enumerate() { - res[i] += &(scale * t); - } - scale *= &evalscale; - } - res - }; - - let combined_inner_product = p - .coeffs - .iter() - .zip(b_init.iter()) - .map(|(a, b)| *a * b) - .fold(G::ScalarField::zero(), |acc, x| acc + x); - - sponge.absorb_fr(&[shift_scalar::(combined_inner_product)]); - - let t = sponge.challenge_fq(); - let u: G = to_group(group_map, t); - - let mut a = p.coeffs; - assert!(padded_length >= a.len()); - a.extend(vec![G::ScalarField::zero(); padded_length - a.len()]); - - let mut b = b_init; - - let mut lr = vec![]; - - let mut blinders = vec![]; - - let mut chals = vec![]; - let mut chal_invs = vec![]; - - for _ in 0..rounds { - let n = g.len() / 2; - let (g_lo, g_hi) = (g[0..n].to_vec(), g[n..].to_vec()); - let (a_lo, a_hi) = (&a[0..n], &a[n..]); - let (b_lo, b_hi) = (&b[0..n], &b[n..]); - - let rand_l = ::rand(rng); - let rand_r = ::rand(rng); - - let l = G::Group::msm_bigint( - &[&g[0..n], &[self.h, u]].concat(), - &[&a[n..], &[rand_l, inner_prod(a_hi, b_lo)]] - .concat() - .iter() - .map(|x| x.into_bigint()) - .collect::>(), - ) - .into_affine(); - - let r = G::Group::msm_bigint( - &[&g[n..], &[self.h, u]].concat(), - &[&a[0..n], &[rand_r, inner_prod(a_lo, b_hi)]] - .concat() - .iter() - .map(|x| x.into_bigint()) - .collect::>(), - ) - .into_affine(); - - lr.push((l, r)); - blinders.push((rand_l, rand_r)); - - sponge.absorb_g(&[l]); - sponge.absorb_g(&[r]); - - let u_pre = squeeze_prechallenge(&mut sponge); - let u = u_pre.to_field(&endo_r); - let u_inv = u.inverse().unwrap(); - - chals.push(u); - chal_invs.push(u_inv); - - a = a_hi - .par_iter() - .zip(a_lo) - .map(|(&hi, &lo)| { - // lo + u_inv * hi - let mut res = hi; - res *= u_inv; - res += &lo; - res - }) - .collect(); - - b = b_lo - .par_iter() - .zip(b_hi) - .map(|(&lo, &hi)| { - // lo + u * hi - let mut res = hi; - res *= u; - res += &lo; - res - }) - .collect(); - - g = G::combine_one_endo(endo_r, endo_q, &g_lo, &g_hi, u_pre); - } - - assert!(g.len() == 1); - let 
a0 = a[0]; - let b0 = b[0]; - let g0 = g[0]; - - let r_prime = blinders - .iter() - .zip(chals.iter().zip(chal_invs.iter())) - .map(|((l, r), (u, u_inv))| ((*l) * u_inv) + (*r * u)) - .fold(blinding_factor, |acc, x| acc + x); - - let d = ::rand(rng); - let r_delta = ::rand(rng); - - let delta = ((g0.into_group() + (u.mul(b0))).into_affine().mul(d) + self.h.mul(r_delta)) - .into_affine(); - - sponge.absorb_g(&[delta]); - let c = ScalarChallenge(sponge.challenge()).to_field(&endo_r); - - let z1 = a0 * c + d; - let z2 = c * r_prime + r_delta; - - OpeningProof { - delta, - lr, - z1, - z2, - sg: g0, - } - } - - /// This function is a debugging helper. - #[allow(clippy::too_many_arguments)] - #[allow(clippy::type_complexity)] - #[allow(clippy::many_single_char_names)] - pub fn prover_polynomials_to_verifier_evaluations>( - &self, - plnms: PolynomialsToCombine, - elm: &[G::ScalarField], // vector of evaluation points - ) -> Vec> - where - G::BaseField: PrimeField, - { - plnms - .iter() - .enumerate() - .map(|(i, (poly_or_evals, blinders))| { - let poly = match poly_or_evals { - DensePolynomialOrEvaluations::DensePolynomial(poly) => (*poly).clone(), - DensePolynomialOrEvaluations::Evaluations(evals, _) => { - (*evals).clone().interpolate() - } - }; - let chunked_polynomial = - poly.to_chunked_polynomial(blinders.elems.len(), self.g.len()); - let chunked_commitment = { self.commit_non_hiding(&poly, blinders.elems.len()) }; - let masked_commitment = match self.mask_custom(chunked_commitment, blinders) { - Ok(comm) => comm, - Err(err) => panic!("Error at index {i}: {err}"), - }; - let chunked_evals = elm - .iter() - .map(|elm| chunked_polynomial.evaluate_chunks(*elm)) - .collect(); - Evaluation { - commitment: masked_commitment.commitment, - - evaluations: chunked_evals, - } - }) - .collect() - } -} - -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)] -#[serde(bound = "G: ark_serialize::CanonicalDeserialize + ark_serialize::CanonicalSerialize")] -pub struct OpeningProof { - /// vector of rounds of L & R commitments - #[serde_as(as = "Vec<(o1_utils::serialization::SerdeAs, o1_utils::serialization::SerdeAs)>")] - pub lr: Vec<(G, G)>, - #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub delta: G, - #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub z1: G::ScalarField, - #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub z2: G::ScalarField, - #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub sg: G, -} - -impl + CommitmentCurve + EndoCurve> - crate::OpenProof for OpeningProof -{ - type SRS = SRS; - - fn open::ScalarField>>( - srs: &Self::SRS, - group_map: &::Map, - plnms: PolynomialsToCombine, - elm: &[::ScalarField], // vector of evaluation points - polyscale: ::ScalarField, // scaling factor for polynoms - evalscale: ::ScalarField, // scaling factor for evaluation point powers - sponge: EFqSponge, // sponge - rng: &mut RNG, - ) -> Self - where - EFqSponge: - Clone + FqSponge<::BaseField, G, ::ScalarField>, - RNG: RngCore + CryptoRng, - { - srs.open(group_map, plnms, elm, polyscale, evalscale, sponge, rng) - } - - fn verify( - srs: &Self::SRS, - group_map: &G::Map, - batch: &mut [BatchEvaluationProof], - rng: &mut RNG, - ) -> bool - where - EFqSponge: FqSponge<::BaseField, G, ::ScalarField>, - RNG: RngCore + CryptoRng, - { - srs.verify(group_map, batch, rng) - } -} - -pub struct Challenges { - pub chal: Vec, - pub chal_inv: Vec, -} - -impl OpeningProof { - pub fn prechallenges>( - &self, - sponge: &mut EFqSponge, - ) -> Vec> { - let _t = 
sponge.challenge_fq(); - self.lr - .iter() - .map(|(l, r)| { - sponge.absorb_g(&[*l]); - sponge.absorb_g(&[*r]); - squeeze_prechallenge(sponge) - }) - .collect() - } - - pub fn challenges>( - &self, - endo_r: &G::ScalarField, - sponge: &mut EFqSponge, - ) -> Challenges { - let chal: Vec<_> = self - .lr - .iter() - .map(|(l, r)| { - sponge.absorb_g(&[*l]); - sponge.absorb_g(&[*r]); - squeeze_challenge(endo_r, sponge) - }) - .collect(); - - let chal_inv = { - let mut cs = chal.clone(); - ark_ff::batch_inversion(&mut cs); - cs - }; - - Challenges { chal, chal_inv } - } -} diff --git a/poly-commitment/src/hash_map_cache.rs b/poly-commitment/src/hash_map_cache.rs index f0ed7f5628..08f91b94a1 100644 --- a/poly-commitment/src/hash_map_cache.rs +++ b/poly-commitment/src/hash_map_cache.rs @@ -1,4 +1,5 @@ use std::{ + cmp::Eq, collections::{hash_map::Entry, HashMap}, hash::Hash, sync::{Arc, Mutex}, @@ -9,13 +10,19 @@ pub struct HashMapCache { contents: Arc>>, } -impl HashMapCache { +impl HashMapCache { pub fn new() -> Self { HashMapCache { contents: Arc::new(Mutex::new(HashMap::new())), } } + pub fn new_from_hashmap(hashmap: HashMap) -> Self { + HashMapCache { + contents: Arc::new(Mutex::new(hashmap)), + } + } + pub fn get_or_generate Value>(&self, key: Key, generator: F) -> &Value { let mut hashmap = self.contents.lock().unwrap(); let entry = (*hashmap).entry(key); @@ -39,3 +46,9 @@ impl HashMapCache { self.contents.lock().unwrap().contains_key(key) } } + +impl From> for HashMap { + fn from(cache: HashMapCache) -> HashMap { + cache.contents.lock().unwrap().clone() + } +} diff --git a/poly-commitment/src/ipa.rs b/poly-commitment/src/ipa.rs new file mode 100644 index 0000000000..17c561bb3d --- /dev/null +++ b/poly-commitment/src/ipa.rs @@ -0,0 +1,1036 @@ +//! This module contains the implementation of the polynomial commitment scheme +//! called the Inner Product Argument (IPA) as described in [Efficient +//! Zero-Knowledge Arguments for Arithmetic Circuits in the Discrete Log +//! Setting](https://eprint.iacr.org/2016/263) + +use crate::{ + commitment::{ + b_poly, b_poly_coefficients, combine_commitments, shift_scalar, squeeze_challenge, + squeeze_prechallenge, BatchEvaluationProof, CommitmentCurve, EndoCurve, + }, + error::CommitmentError, + hash_map_cache::HashMapCache, + utils::combine_polys, + BlindedCommitment, PolyComm, PolynomialsToCombine, SRS as SRSTrait, +}; +use ark_ec::{AffineRepr, CurveGroup, VariableBaseMSM}; +use ark_ff::{BigInteger, Field, One, PrimeField, UniformRand, Zero}; +use ark_poly::{ + univariate::DensePolynomial, EvaluationDomain, Evaluations, Radix2EvaluationDomain as D, +}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use blake2::{Blake2b512, Digest}; +use groupmap::GroupMap; +use mina_poseidon::{sponge::ScalarChallenge, FqSponge}; +use o1_utils::{ + field_helpers::{inner_prod, pows}, + math, +}; +use rand::{CryptoRng, RngCore}; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use serde_with::serde_as; +use std::{cmp::min, iter::Iterator, ops::AddAssign}; + +#[serde_as] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(bound = "G: CanonicalDeserialize + CanonicalSerialize")] +pub struct SRS { + /// The vector of group elements for committing to polynomials in + /// coefficient form. 
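+ /// The length of this vector bounds the number of coefficients a single + /// chunk can commit to (cf. `max_poly_size` below); larger polynomials are + /// committed to chunk by chunk.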
+ #[serde_as(as = "Vec")] + pub g: Vec, + + /// A group element used for blinding commitments + #[serde_as(as = "o1_utils::serialization::SerdeAs")] + pub h: G, + + // TODO: the following field should be separated, as they are optimization + // values + /// Commitments to Lagrange bases, per domain size + #[serde(skip)] + pub lagrange_bases: HashMapCache>>, +} + +impl PartialEq for SRS +where + G: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.g == other.g && self.h == other.h + } +} + +pub fn endos() -> (G::BaseField, G::ScalarField) +where + G::BaseField: PrimeField, +{ + let endo_q: G::BaseField = mina_poseidon::sponge::endo_coefficient(); + let endo_r = { + let potential_endo_r: G::ScalarField = mina_poseidon::sponge::endo_coefficient(); + let t = G::generator(); + let (x, y) = t.to_coordinates().unwrap(); + let phi_t = G::of_coordinates(x * endo_q, y); + if t.mul(potential_endo_r) == phi_t.into_group() { + potential_endo_r + } else { + potential_endo_r * potential_endo_r + } + }; + (endo_q, endo_r) +} + +fn point_of_random_bytes(map: &G::Map, random_bytes: &[u8]) -> G +where + G::BaseField: Field, +{ + // packing in bit-representation + const N: usize = 31; + let extension_degree = G::BaseField::extension_degree() as usize; + + let mut base_fields = Vec::with_capacity(N * extension_degree); + + for base_count in 0..extension_degree { + let mut bits = [false; 8 * N]; + let offset = base_count * N; + for i in 0..N { + for j in 0..8 { + bits[8 * i + j] = (random_bytes[offset + i] >> j) & 1 == 1; + } + } + + let n = + <::BasePrimeField as PrimeField>::BigInt::from_bits_be(&bits); + let t = <::BasePrimeField as PrimeField>::from_bigint(n) + .expect("packing code has a bug"); + base_fields.push(t) + } + + let t = G::BaseField::from_base_prime_field_elems(&base_fields).unwrap(); + + let (x, y) = map.to_group(t); + G::of_coordinates(x, y).mul_by_cofactor() +} + +/// Additional methods for the SRS structure +impl SRS { + /// This function verifies a batch of polynomial commitment opening proofs. + /// Return `true` if the verification is successful, `false` otherwise. + pub fn verify( + &self, + group_map: &G::Map, + batch: &mut [BatchEvaluationProof>], + rng: &mut RNG, + ) -> bool + where + EFqSponge: FqSponge, + RNG: RngCore + CryptoRng, + G::BaseField: PrimeField, + { + // Verifier checks for all i, + // c_i Q_i + delta_i = z1_i (G_i + b_i U_i) + z2_i H + // + // if we sample evalscale at random, it suffices to check + // + // 0 == sum_i evalscale^i (c_i Q_i + delta_i - ( z1_i (G_i + b_i U_i) + z2_i H )) + // + // and because each G_i is a multiexp on the same array self.g, we + // can batch the multiexp across proofs. 
+ // + // So for each proof in the batch, we add onto our big multiexp the following terms + // evalscale^i c_i Q_i + // evalscale^i delta_i + // - (evalscale^i z1_i) G_i + // - (evalscale^i z2_i) H + // - (evalscale^i z1_i b_i) U_i + + // We also check that the sg component of the proof is equal to the polynomial commitment + // to the "s" array + + let nonzero_length = self.g.len(); + + let max_rounds = math::ceil_log2(nonzero_length); + + let padded_length = 1 << max_rounds; + + let (_, endo_r) = endos::(); + + // TODO: This will need adjusting + let padding = padded_length - nonzero_length; + let mut points = vec![self.h]; + points.extend(self.g.clone()); + points.extend(vec![G::zero(); padding]); + + let mut scalars = vec![G::ScalarField::zero(); padded_length + 1]; + assert_eq!(scalars.len(), points.len()); + + // sample randomiser to scale the proofs with + let rand_base = G::ScalarField::rand(rng); + let sg_rand_base = G::ScalarField::rand(rng); + + let mut rand_base_i = G::ScalarField::one(); + let mut sg_rand_base_i = G::ScalarField::one(); + + for BatchEvaluationProof { + sponge, + evaluation_points, + polyscale, + evalscale, + evaluations, + opening, + combined_inner_product, + } in batch.iter_mut() + { + sponge.absorb_fr(&[shift_scalar::(*combined_inner_product)]); + + let u_base: G = { + let t = sponge.challenge_fq(); + let (x, y) = group_map.to_group(t); + G::of_coordinates(x, y) + }; + + let Challenges { chal, chal_inv } = opening.challenges::(&endo_r, sponge); + + sponge.absorb_g(&[opening.delta]); + let c = ScalarChallenge(sponge.challenge()).to_field(&endo_r); + + // < s, sum_i evalscale^i pows(evaluation_point[i]) > + // == + // sum_i evalscale^i < s, pows(evaluation_point[i]) > + let b0 = { + let mut scale = G::ScalarField::one(); + let mut res = G::ScalarField::zero(); + for &e in evaluation_points.iter() { + let term = b_poly(&chal, e); + res += &(scale * term); + scale *= *evalscale; + } + res + }; + + let s = b_poly_coefficients(&chal); + + let neg_rand_base_i = -rand_base_i; + + // TERM + // - rand_base_i z1 G + // + // we also add -sg_rand_base_i * G to check correctness of sg. + points.push(opening.sg); + scalars.push(neg_rand_base_i * opening.z1 - sg_rand_base_i); + + // Here we add + // sg_rand_base_i * ( < s, self.g > ) + // = + // < sg_rand_base_i s, self.g > + // + // to check correctness of the sg component. 
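+ // Note that `s = b_poly_coefficients(&chal)` (computed above) is the + // coefficient vector of the fully folded polynomial, so < s, self.g > + // recomputes exactly the point the prover claims as `sg`.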
+ { + let terms: Vec<_> = s.par_iter().map(|s| sg_rand_base_i * s).collect(); + + for (i, term) in terms.iter().enumerate() { + scalars[i + 1] += term; + } + } + + // TERM + // - rand_base_i * z2 * H + scalars[0] -= &(rand_base_i * opening.z2); + + // TERM + // -rand_base_i * (z1 * b0 * U) + scalars.push(neg_rand_base_i * (opening.z1 * b0)); + points.push(u_base); + + // TERM + // rand_base_i c_i Q_i + // = rand_base_i c_i + // (sum_j (chal_invs[j] L_j + chals[j] R_j) + P_prime) + // where P_prime = combined commitment + combined_inner_product * U + let rand_base_i_c_i = c * rand_base_i; + for ((l, r), (u_inv, u)) in opening.lr.iter().zip(chal_inv.iter().zip(chal.iter())) { + points.push(*l); + scalars.push(rand_base_i_c_i * u_inv); + + points.push(*r); + scalars.push(rand_base_i_c_i * u); + } + + // TERM + // sum_j evalscale^j (sum_i polyscale^i f_i) (elm_j) + // == sum_j sum_i evalscale^j polyscale^i f_i(elm_j) + // == sum_i polyscale^i sum_j evalscale^j f_i(elm_j) + combine_commitments( + evaluations, + &mut scalars, + &mut points, + *polyscale, + rand_base_i_c_i, + ); + + scalars.push(rand_base_i_c_i * *combined_inner_product); + points.push(u_base); + + scalars.push(rand_base_i); + points.push(opening.delta); + + rand_base_i *= &rand_base; + sg_rand_base_i *= &sg_rand_base; + } + + // verify the equation + let scalars: Vec<_> = scalars.iter().map(|x| x.into_bigint()).collect(); + G::Group::msm_bigint(&points, &scalars) == G::Group::zero() + } + + /// This function creates a trusted-setup SRS instance for circuits with + /// number of rows up to `depth`. + /// + /// # Safety + /// + /// This function is unsafe because it creates a trusted setup and the toxic + /// waste is passed as a parameter. + pub unsafe fn create_trusted_setup(x: G::ScalarField, depth: usize) -> Self { + let m = G::Map::setup(); + + let mut x_pow = G::ScalarField::one(); + let g: Vec<_> = (0..depth) + .map(|_| { + let res = G::generator().mul(x_pow); + x_pow *= x; + res.into_affine() + }) + .collect(); + + // Compute a blinder + let h = { + let mut h = Blake2b512::new(); + h.update("srs_misc".as_bytes()); + // FIXME: This is for backward compatibility with a previous version + // that used a list initialisation. It is not necessary. + h.update(0_u32.to_be_bytes()); + point_of_random_bytes(&m, &h.finalize()) + }; + + Self { + g, + h, + lagrange_bases: HashMapCache::new(), + } + } +} + +impl SRS +where + ::Map: Sync, + G::BaseField: PrimeField, +{ + /// This function creates an SRS instance for circuits with a number of rows + /// up to `depth`. + pub fn create_parallel(depth: usize) -> Self { + let m = G::Map::setup(); + + let g: Vec<_> = (0..depth) + .into_par_iter() + .map(|i| { + let mut h = Blake2b512::new(); + h.update((i as u32).to_be_bytes()); + point_of_random_bytes(&m, &h.finalize()) + }) + .collect(); + + // Compute a blinder + let h = { + let mut h = Blake2b512::new(); + h.update("srs_misc".as_bytes()); + // FIXME: This is for backward compatibility with a previous version + // that used a list initialisation. It is not necessary. + h.update(0_u32.to_be_bytes()); + point_of_random_bytes(&m, &h.finalize()) + }; + + Self { + g, + h, + lagrange_bases: HashMapCache::new(), + } + } +} + +impl SRSTrait for SRS +where + G: CommitmentCurve, +{ + /// The maximum polynomial degree that can be committed to + fn max_poly_size(&self) -> usize { + self.g.len() + } + + fn blinding_commitment(&self) -> G { + self.h + } + + /// Turns a non-hiding polynomial commitment into a hiding polynomial + /// commitment. Transforms each given `<comm>` into `(<comm> + wH, w)` with + /// a random `w` per commitment. + fn mask( + &self, + comm: PolyComm, + rng: &mut (impl RngCore + CryptoRng), + ) -> BlindedCommitment { + let blinders = comm.map(|_| G::ScalarField::rand(rng)); + self.mask_custom(comm, &blinders).unwrap() + } + + fn mask_custom( + &self, + com: PolyComm, + blinders: &PolyComm, + ) -> Result, CommitmentError> { + let commitment = com + .zip(blinders) + .ok_or_else(|| CommitmentError::BlindersDontMatch(blinders.len(), com.len()))? + .map(|(g, b)| { + let mut g_masked = self.h.mul(b); + g_masked.add_assign(&g); + g_masked.into_affine() + }); + Ok(BlindedCommitment { + commitment, + blinders: blinders.clone(), + }) + } + + fn commit_non_hiding( + &self, + plnm: &DensePolynomial, + num_chunks: usize, + ) -> PolyComm { + let is_zero = plnm.is_zero(); + + let coeffs: Vec<_> = plnm.iter().map(|c| c.into_bigint()).collect(); + + // chunk while committing + let mut chunks = vec![]; + if is_zero { + chunks.push(G::zero()); + } else { + coeffs.chunks(self.g.len()).for_each(|coeffs_chunk| { + let chunk = G::Group::msm_bigint(&self.g, coeffs_chunk); + chunks.push(chunk.into_affine()); + }); + } + + for _ in chunks.len()..num_chunks { + chunks.push(G::zero()); + } + + PolyComm::::new(chunks) + } + + fn commit( + &self, + plnm: &DensePolynomial, + num_chunks: usize, + rng: &mut (impl RngCore + CryptoRng), + ) -> BlindedCommitment { + self.mask(self.commit_non_hiding(plnm, num_chunks), rng) + } + + fn commit_custom( + &self, + plnm: &DensePolynomial, + num_chunks: usize, + blinders: &PolyComm, + ) -> Result, CommitmentError> { + self.mask_custom(self.commit_non_hiding(plnm, num_chunks), blinders) + } + + fn commit_evaluations_non_hiding( + &self, + domain: D, + plnm: &Evaluations>, + ) -> PolyComm { + let basis = self.get_lagrange_basis(domain); + let commit_evaluations = |evals: &Vec, basis: &Vec>| { + PolyComm::::multi_scalar_mul(&basis.iter().collect::>()[..], &evals[..]) + }; + match domain.size.cmp(&plnm.domain().size) { + std::cmp::Ordering::Less => { + let s = (plnm.domain().size / domain.size) as usize; + let v: Vec<_> = (0..(domain.size())).map(|i| plnm.evals[s * i]).collect(); + commit_evaluations(&v, basis) + } + std::cmp::Ordering::Equal => commit_evaluations(&plnm.evals, basis), + std::cmp::Ordering::Greater => { + panic!("desired commitment domain size ({}) greater than evaluations' domain size ({})", domain.size, plnm.domain().size) + } + } + } + + fn commit_evaluations( + &self, + domain: D, + plnm: &Evaluations>, + rng: &mut (impl RngCore + CryptoRng), + ) -> BlindedCommitment { + self.mask(self.commit_evaluations_non_hiding(domain, plnm), rng) + } + + fn commit_evaluations_custom( + &self, + domain: D, + plnm: &Evaluations>, + blinders: &PolyComm, + ) -> Result, CommitmentError> { + self.mask_custom(self.commit_evaluations_non_hiding(domain, plnm), blinders) + } + + fn create(depth: usize) -> Self { + let m = G::Map::setup(); + + let g: Vec<_> = (0..depth) + .map(|i| { + let mut h = Blake2b512::new(); + h.update((i as u32).to_be_bytes()); + point_of_random_bytes(&m, &h.finalize()) + }) + .collect(); + + // Compute a blinder + let h = { + let mut h = Blake2b512::new(); + h.update("srs_misc".as_bytes()); + // FIXME: This is for backward compatibility with a previous version + // that used a list initialisation. It is not necessary.
+ h.update(0_u32.to_be_bytes()); + point_of_random_bytes(&m, &h.finalize()) + }; + + Self { + g, + h, + lagrange_bases: HashMapCache::new(), + } + } + + fn get_lagrange_basis_from_domain_size(&self, domain_size: usize) -> &Vec> { + self.lagrange_bases.get_or_generate(domain_size, || { + self.lagrange_basis(D::new(domain_size).unwrap()) + }) + } + + fn get_lagrange_basis(&self, domain: D) -> &Vec> { + self.lagrange_bases + .get_or_generate(domain.size(), || self.lagrange_basis(domain)) + } + + fn size(&self) -> usize { + self.g.len() + } +} + +impl SRS { + #[allow(clippy::type_complexity)] + #[allow(clippy::many_single_char_names)] + // NB: a slight modification to the original protocol is done when absorbing + // the first prover message to improve the efficiency in a recursive + // setting. + pub fn open>( + &self, + group_map: &G::Map, + plnms: PolynomialsToCombine, + elm: &[G::ScalarField], + polyscale: G::ScalarField, + evalscale: G::ScalarField, + mut sponge: EFqSponge, + rng: &mut RNG, + ) -> OpeningProof + where + EFqSponge: Clone + FqSponge, + RNG: RngCore + CryptoRng, + G::BaseField: PrimeField, + G: EndoCurve, + { + let (endo_q, endo_r) = endos::(); + + let rounds = math::ceil_log2(self.g.len()); + let padded_length = 1 << rounds; + + // TODO: Trim this to the degree of the largest polynomial + // TODO: In practice we always assume the SRS size is a power of 2, so the + // padding equals zero and this code could be removed. Only a current test + // case uses an SRS with a non-power-of-2 size. + let padding = padded_length - self.g.len(); + let mut g = self.g.clone(); + g.extend(vec![G::zero(); padding]); + + // Combines polynomials roughly as follows: p(X) := ∑_i polyscale^i p_i(X) + // + // `blinding_factor` is a combined set of commitments that are + // paired with polynomials in `plnms`. In kimchi, these input commitments + // are poly com blinders, so often `[G::ScalarField::one(); num_chunks]` or zeroes. + let (p, blinding_factor) = combine_polys::(plnms, polyscale, self.g.len()); + + // The initial evaluation vector `b_init` for the polynomial commitment is + // not just the powers of a single point, as in the original IPA + // (1, ζ, ζ^2, ...), but rather a vector of linearly combined powers, with + // `evalscale` as the recombiner. + // + // b_init[j] = Σ_i evalscale^i elm_i^j + // = ζ^j + evalscale * ζ^j ω^j (in the specific case of challenges (ζ,ζω)) + // + // So in our case b_init is the following vector: + // 1 + evalscale + // ζ + evalscale * ζ ω + // ζ^2 + evalscale * (ζ ω)^2 + // ζ^3 + evalscale * (ζ ω)^3 + // ... + let b_init = { + // randomise/scale the eval powers + let mut scale = G::ScalarField::one(); + let mut res: Vec = + (0..padded_length).map(|_| G::ScalarField::zero()).collect(); + for e in elm { + for (i, t) in pows(padded_length, *e).iter().enumerate() { + res[i] += &(scale * t); + } + scale *= &evalscale; + } + res + }; + + // The combined inner product ⟨p.coeffs, b_init⟩ of the combined polynomial + // with the combined evaluation vector. + let combined_inner_product = p + .coeffs + .iter() + .zip(b_init.iter()) + .map(|(a, b)| *a * b) + .fold(G::ScalarField::zero(), |acc, x| acc + x); + + // Usually, the prover sends `combined_inner_product` to the verifier, so + // we should absorb `combined_inner_product` here. + // However, it is more efficient in the recursion circuit + // to absorb a slightly modified version of it. + // As a reminder, in a recursive setting, the challenges are given as a public input + // and verified in the next iteration. + // See the `shift_scalar` doc.
+ sponge.absorb_fr(&[shift_scalar::(combined_inner_product)]); + + // Generate another randomisation base U; our commitments will be w.r.t. + // the bases {G_i}, H, U. + let u_base: G = { + let t = sponge.challenge_fq(); + let (x, y) = group_map.to_group(t); + G::of_coordinates(x, y) + }; + + let mut a = p.coeffs; + assert!(padded_length >= a.len()); + a.extend(vec![G::ScalarField::zero(); padded_length - a.len()]); + + let mut b = b_init; + + let mut lr = vec![]; + + let mut blinders = vec![]; + + let mut chals = vec![]; + let mut chal_invs = vec![]; + + // The main IPA folding loop, which runs for `rounds` (logarithmically + // many) iterations. + for _ in 0..rounds { + let n = g.len() / 2; + // Pedersen bases + let (g_lo, g_hi) = (&g[0..n], &g[n..]); + // Polynomial coefficients + let (a_lo, a_hi) = (&a[0..n], &a[n..]); + // Evaluation points + let (b_lo, b_hi) = (&b[0..n], &b[n..]); + + // Blinders for L/R + let rand_l = ::rand(rng); + let rand_r = ::rand(rng); + + // Pedersen-style cross commitments: + // l = ⟨a_hi, g_lo⟩ + [rand_l] H + [⟨a_hi, b_lo⟩] U + // r = ⟨a_lo, g_hi⟩ + [rand_r] H + [⟨a_lo, b_hi⟩] U + let l = G::Group::msm_bigint( + &[g_lo, &[self.h, u_base]].concat(), + &[a_hi, &[rand_l, inner_prod(a_hi, b_lo)]] + .concat() + .iter() + .map(|x| x.into_bigint()) + .collect::>(), + ) + .into_affine(); + + let r = G::Group::msm_bigint( + &[g_hi, &[self.h, u_base]].concat(), + &[a_lo, &[rand_r, inner_prod(a_lo, b_hi)]] + .concat() + .iter() + .map(|x| x.into_bigint()) + .collect::>(), + ) + .into_affine(); + + lr.push((l, r)); + blinders.push((rand_l, rand_r)); + + sponge.absorb_g(&[l]); + sponge.absorb_g(&[r]); + + // Round #i challenges: + // - not to be confused with "u_base" + // - not to be confused with the "u" used as "polyscale" elsewhere + let u_pre = squeeze_prechallenge(&mut sponge); + let u = u_pre.to_field(&endo_r); + let u_inv = u.inverse().unwrap(); + + chals.push(u); + chal_invs.push(u_inv); + + // IPA-folding polynomial coefficients + a = a_hi + .par_iter() + .zip(a_lo) + .map(|(&hi, &lo)| { + // lo + u_inv * hi + let mut res = hi; + res *= u_inv; + res += &lo; + res + }) + .collect(); + + // IPA-folding evaluation points + b = b_lo + .par_iter() + .zip(b_hi) + .map(|(&lo, &hi)| { + // lo + u * hi + let mut res = hi; + res *= u; + res += &lo; + res + }) + .collect(); + + // IPA-folding bases + g = G::combine_one_endo(endo_r, endo_q, g_lo, g_hi, u_pre); + } + + assert!( + g.len() == 1 && a.len() == 1 && b.len() == 1, + "IPA commitment folding must produce single elements after log rounds" + ); + let a0 = a[0]; + let b0 = b[0]; + let g0 = g[0]; + + // Compute r_prime, a folded blinder. It combines blinders on + // each individual step of the IPA folding process together + // with the final blinding_factor of the polynomial. + // + // r_prime := ∑_i (rand_l[i] * u[i]^{-1} + rand_r[i] * u[i]) + // + blinding_factor + // + // where u is a vector of folding challenges, and rand_l/rand_r are + // intermediate L/R blinders. + let r_prime = blinders + .iter() + .zip(chals.iter().zip(chal_invs.iter())) + .map(|((rand_l, rand_r), (u, u_inv))| ((*rand_l) * u_inv) + (*rand_r * u)) + .fold(blinding_factor, |acc, x| acc + x); + + let d = ::rand(rng); + let r_delta = ::rand(rng); + + // Compute delta, the commitment + // delta = [d] (G0 + [b0] U_base) + [r_delta] H (in additive notation) + let delta = ((g0.into_group() + (u_base.mul(b0))).into_affine().mul(d) + + self.h.mul(r_delta)) + .into_affine(); + + sponge.absorb_g(&[delta]); + let c = ScalarChallenge(sponge.challenge()).to_field(&endo_r); + + // Schnorr-like responses showing the knowledge of a0 and r_prime.
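+ // The verifier will recheck them (see `verify` above) via the equation + // c Q + delta == z1 (G0 + b0 U) + z2 H, + // with Q recombined from the L/R pairs and the combined commitment.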
+ let z1 = a0 * c + d; + let z2 = r_prime * c + r_delta; + + OpeningProof { + delta, + lr, + z1, + z2, + sg: g0, + } + } + + fn lagrange_basis(&self, domain: D) -> Vec> { + let n = domain.size(); + + // Let V be a vector space over the field F. + // + // Given + // - a domain [ 1, w, w^2, ..., w^{n - 1} ] + // - a vector v := [ v_0, ..., v_{n - 1} ] in V^n + // + // the FFT algorithm computes the matrix application + // + // u = M(w) * v + // + // where + // M(w) = + // 1 1 1 ... 1 + // 1 w w^2 ... w^{n-1} + // ... + // 1 w^{n-1} (w^2)^{n-1} ... (w^{n-1})^{n-1} + // + // The IFFT algorithm computes + // + // v = M(w)^{-1} * u + // + // Let's see how we can use this algorithm to compute the Lagrange basis + // commitments. + // + // Let V be the vector space F[x] of polynomials in x over F. + // Let v in V be the vector [ L_0, ..., L_{n - 1} ] where L_i is the i^{th} + // normalized Lagrange polynomial (where L_i(w^j) = j == i ? 1 : 0). + // + // Consider the rows of M(w) * v. Let me write out the matrix and vector so you + // can see more easily. + // + // | 1 1 1 ... 1 | | L_0 | + // | 1 w w^2 ... w^{n-1} | * | L_1 | + // | ... | | ... | + // | 1 w^{n-1} (w^2)^{n-1} ... (w^{n-1})^{n-1} | | L_{n-1} | + // + // The 0th row is L_0 + L_1 + ... + L_{n - 1}. So, it's the polynomial + // that has the value 1 on every element of the domain. + // In other words, it's the polynomial 1. + // + // The 1st row is L_0 + w L_1 + ... + w^{n - 1} L_{n - 1}. So, it's the + // polynomial which has value w^i on w^i. + // In other words, it's the polynomial x. + // + // In general, you can see that row i is in fact the polynomial x^i. + // + // Thus, M(w) * v is the vector u, where u = [ 1, x, x^2, ..., x^{n-1} ] + // + // Therefore, the IFFT algorithm, when applied to the vector u (the standard + // monomial basis), will yield the vector v of the (normalized) Lagrange polynomials. + // + // Now, because the polynomial commitment scheme is additively homomorphic, and + // because the commitment to the polynomial x^i is just self.g[i], we can obtain + // commitments to the normalized Lagrange polynomials by applying IFFT to the + // vector self.g[0..n]. + // + // Further still, we can do the same trick for 'chunked' polynomials. + // + // Recall that a chunked polynomial is some f of degree k*n - 1 with + // f(x) = f_0(x) + x^n f_1(x) + ... + x^{(k-1) n} f_{k-1}(x) + // where each f_i has degree n-1. + // + // In the above, if we set u = [ 1, x, x^2, ..., x^{n-1}, 0, 0, ..., 0 ] + // then we effectively 'zero out' any polynomial terms higher than x^{n-1}, leaving + // us with the 'partial Lagrange polynomials' that contribute to f_0. + // + // Similarly, u = [ 0, 0, ..., 0, 1, x, x^2, ..., x^{n-1}, 0, 0, ..., 0] with n leading + // zeros 'zeroes out' all terms except the 'partial Lagrange polynomials' that + // contribute to f_1, and likewise for each f_i. + // + // By computing each of these, and recollecting the terms as a vector of polynomial + // commitments, we obtain a chunked commitment to the L_i polynomials.
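+ // + // For example, in the single-chunk case (n <= srs_size), this reduces to + // + // [ com(L_0), ..., com(L_{n-1}) ] = IFFT([ self.g[0], ..., self.g[n-1] ])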
+ let srs_size = self.g.len(); + let num_elems = (n + srs_size - 1) / srs_size; + let mut chunks = Vec::with_capacity(num_elems); + + // For each chunk + for i in 0..num_elems { + // Initialize the vector with zero curve points + let mut lg: Vec<::Group> = vec![::Group::zero(); n]; + // Overwrite the terms corresponding to that chunk with the SRS curve points + let start_offset = i * srs_size; + let num_terms = min((i + 1) * srs_size, n) - start_offset; + for j in 0..num_terms { + lg[start_offset + j] = self.g[j].into_group() + } + // Apply the IFFT + domain.ifft_in_place(&mut lg); + // Append the 'partial Lagrange polynomials' to the vector of chunks + chunks.push(::Group::normalize_batch(lg.as_mut_slice())); + } + + (0..n) + .map(|i| PolyComm { + chunks: chunks.iter().map(|v| v[i]).collect(), + }) + .collect() + } +} + +#[serde_as] +#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq)] +#[serde(bound = "G: ark_serialize::CanonicalDeserialize + ark_serialize::CanonicalSerialize")] +pub struct OpeningProof { + /// Vector of rounds of L & R commitments + #[serde_as(as = "Vec<(o1_utils::serialization::SerdeAs, o1_utils::serialization::SerdeAs)>")] + pub lr: Vec<(G, G)>, + #[serde_as(as = "o1_utils::serialization::SerdeAs")] + pub delta: G, + #[serde_as(as = "o1_utils::serialization::SerdeAs")] + pub z1: G::ScalarField, + #[serde_as(as = "o1_utils::serialization::SerdeAs")] + pub z2: G::ScalarField, + /// A final folded commitment base + #[serde_as(as = "o1_utils::serialization::SerdeAs")] + pub sg: G, +} + +impl + CommitmentCurve + EndoCurve> + crate::OpenProof for OpeningProof +{ + type SRS = SRS; + + fn open::ScalarField>>( + srs: &Self::SRS, + group_map: &::Map, + plnms: PolynomialsToCombine, + elm: &[::ScalarField], // vector of evaluation points + polyscale: ::ScalarField, // scaling factor for polynomials + evalscale: ::ScalarField, // scaling factor for evaluation point powers + sponge: EFqSponge, // sponge + rng: &mut RNG, + ) -> Self + where + EFqSponge: + Clone + FqSponge<::BaseField, G, ::ScalarField>, + RNG: RngCore + CryptoRng, + { + srs.open(group_map, plnms, elm, polyscale, evalscale, sponge, rng) + } + + fn verify( + srs: &Self::SRS, + group_map: &G::Map, + batch: &mut [BatchEvaluationProof], + rng: &mut RNG, + ) -> bool + where + EFqSponge: FqSponge<::BaseField, G, ::ScalarField>, + RNG: RngCore + CryptoRng, + { + srs.verify(group_map, batch, rng) + } +} + +/// Commitment round challenges (endo-mapped) and their inverses. +pub struct Challenges { + pub chal: Vec, + pub chal_inv: Vec, +} + +impl OpeningProof { + /// Computes a log-sized vector of scalar challenges for + /// recombining elements inside the IPA. + pub fn prechallenges>( + &self, + sponge: &mut EFqSponge, + ) -> Vec> { + let _t = sponge.challenge_fq(); + self.lr + .iter() + .map(|(l, r)| { + sponge.absorb_g(&[*l]); + sponge.absorb_g(&[*r]); + squeeze_prechallenge(sponge) + }) + .collect() + } + + /// Same as `prechallenges`, but maps scalar challenges using the + /// provided endomorphism, and computes their inverses.
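+ /// This is what the verifier uses (see `verify` above) to recompute the + /// folding challenges from the `lr` pairs recorded in the proof.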
+ pub fn challenges>( + &self, + endo_r: &G::ScalarField, + sponge: &mut EFqSponge, + ) -> Challenges { + let chal: Vec<_> = self + .lr + .iter() + .map(|(l, r)| { + sponge.absorb_g(&[*l]); + sponge.absorb_g(&[*r]); + squeeze_challenge(endo_r, sponge) + }) + .collect(); + + let chal_inv = { + let mut cs = chal.clone(); + ark_ff::batch_inversion(&mut cs); + cs + }; + + Challenges { chal, chal_inv } + } +} + +#[cfg(feature = "ocaml_types")] +pub mod caml { + use super::OpeningProof; + use ark_ec::AffineRepr; + use ocaml; + + #[derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)] + pub struct CamlOpeningProof { + /// vector of rounds of L & R commitments + pub lr: Vec<(G, G)>, + pub delta: G, + pub z1: F, + pub z2: F, + pub sg: G, + } + + impl From> for CamlOpeningProof + where + G: AffineRepr, + CamlG: From, + CamlF: From, + { + fn from(opening_proof: OpeningProof) -> Self { + Self { + lr: opening_proof + .lr + .into_iter() + .map(|(g1, g2)| (CamlG::from(g1), CamlG::from(g2))) + .collect(), + delta: CamlG::from(opening_proof.delta), + z1: opening_proof.z1.into(), + z2: opening_proof.z2.into(), + sg: CamlG::from(opening_proof.sg), + } + } + } + + impl From> for OpeningProof + where + G: AffineRepr, + CamlG: Into, + CamlF: Into, + { + fn from(caml: CamlOpeningProof) -> Self { + Self { + lr: caml + .lr + .into_iter() + .map(|(g1, g2)| (g1.into(), g2.into())) + .collect(), + delta: caml.delta.into(), + z1: caml.z1.into(), + z2: caml.z2.into(), + sg: caml.sg.into(), + } + } + } +} diff --git a/poly-commitment/src/kzg.rs b/poly-commitment/src/kzg.rs new file mode 100644 index 0000000000..958b2b494c --- /dev/null +++ b/poly-commitment/src/kzg.rs @@ -0,0 +1,487 @@ +//! This module implements the KZG protocol described in the paper +//! [Constant-Size Commitments to Polynomials and Their +//! Applications](https://www.iacr.org/archive/asiacrypt2010/6477178/6477178.pdf) +//! by Kate, Zaverucha and Goldberg, often referred to as the KZG10 paper. +//! +//! The protocol requires a structured reference string (SRS) that contains +//! powers of a generator of a group, and a pairing-friendly curve. +//! +//! The pairing-friendly curve requirement is hidden in the Pairing trait +//! parameter. + +use crate::{ + commitment::*, ipa::SRS, utils::combine_polys, CommitmentError, PolynomialsToCombine, + SRS as SRSTrait, +}; + +use ark_ec::{pairing::Pairing, AffineRepr, VariableBaseMSM}; +use ark_ff::{One, PrimeField, Zero}; +use ark_poly::{ + univariate::{DenseOrSparsePolynomial, DensePolynomial}, + DenseUVPolynomial, EvaluationDomain, Evaluations, Polynomial, Radix2EvaluationDomain as D, +}; +use mina_poseidon::FqSponge; +use rand::thread_rng; +use rand_core::{CryptoRng, RngCore}; +use serde::{Deserialize, Serialize}; +use serde_with::serde_as; +use std::ops::Neg; + +/// Combine the (chunked) evaluations of multiple polynomials. +/// This function returns the accumulation of the evaluations, scaled by +/// powers of `polyscale`. +/// If no evaluation is given, the function returns an empty vector. +/// It also assumes that the number of evaluation points is the same for each +/// evaluation. This is not constrained yet in the interface, but it should be. +/// If one list does not have the same size, it will be shrunk to the size of +/// the first element of the list. +/// For instance, if we have 3 polynomials P1, P2, P3 evaluated at the points +/// ζ and ζω (like in vanilla PlonK), and for each polynomial, we have two +/// chunks, i.e.
we have +/// ```text +/// 2 chunks of P1 +/// /---------------\ +/// E1 = [(P1_1(ζ), P1_2(ζ)), (P1_1(ζω), P1_2(ζω))] +/// E2 = [(P2_1(ζ), P2_2(ζ)), (P2_1(ζω), P2_2(ζω))] +/// E3 = [(P3_1(ζ), P3_2(ζ)), (P3_1(ζω), P3_2(ζω))] +/// ``` +/// The output will be a list of 2 elements, one per evaluation point, equal to: +/// ```text +/// P1_1(ζ) + P1_2(ζ) * polyscale + P2_1(ζ) * polyscale^2 + P2_2(ζ) * polyscale^3 + P3_1(ζ) * polyscale^4 + P3_2(ζ) * polyscale^5 +/// P1_1(ζω) + P1_2(ζω) * polyscale + P2_1(ζω) * polyscale^2 + P2_2(ζω) * polyscale^3 + P3_1(ζω) * polyscale^4 + P3_2(ζω) * polyscale^5 +/// ``` +pub fn combine_evaluations( + evaluations: &Vec>, + polyscale: G::ScalarField, +) -> Vec { + let mut polyscale_i = G::ScalarField::one(); + let mut acc = { + let num_evals = if !evaluations.is_empty() { + evaluations[0].evaluations.len() + } else { + 0 + }; + vec![G::ScalarField::zero(); num_evals] + }; + + for Evaluation { evaluations, .. } in evaluations.iter().filter(|x| !x.commitment.is_empty()) { + // IMPROVEME: we could have a single flat array containing all the + // evaluations and all the chunks; it would avoid fetching the memory + // and avoid indirection into RAM. + // iterating over the polynomial segments + for chunk_idx in 0..evaluations[0].len() { + // assumes that all evaluations have the same number of evaluation points + for eval_pt_idx in 0..evaluations.len() { + acc[eval_pt_idx] += evaluations[eval_pt_idx][chunk_idx] * polyscale_i; + } + polyscale_i *= polyscale; + } + } + + acc +} + +#[serde_as] +#[derive(Debug, Serialize, Deserialize)] +#[serde( + bound = "Pair::G1Affine: ark_serialize::CanonicalDeserialize + ark_serialize::CanonicalSerialize" +)] +pub struct KZGProof { + #[serde_as(as = "o1_utils::serialization::SerdeAs")] + pub quotient: Pair::G1Affine, + #[serde_as(as = "o1_utils::serialization::SerdeAs")] + /// A blinding factor used to hide the polynomial, if necessary + pub blinding: ::ScalarField, +} + +impl Default for KZGProof { + fn default() -> Self { + Self { + quotient: Pair::G1Affine::generator(), + blinding: ::ScalarField::zero(), + } + } +} + +impl Clone for KZGProof { + fn clone(&self) -> Self { + Self { + quotient: self.quotient, + blinding: self.blinding, + } + } +} + +#[derive(Debug, PartialEq, Serialize, Deserialize)] +/// Define a structured reference string (i.e. SRS) for the KZG protocol. +/// The SRS consists of the powers `g^{x^i}` of a generator `g`, for some toxic +/// waste `x`. +/// +/// The SRS is formed using what we call a "trusted setup". For now, the setup +/// is created using the method `create_trusted_setup`. +pub struct PairingSRS { + /// The full SRS is the one used by the prover. Can be seen as the "proving + /// key"/"secret key" + pub full_srs: SRS, + /// SRS to be used by the verifier. Can be seen as the "verification + /// key"/"public key". + pub verifier_srs: SRS, +} + +impl< + F: PrimeField, + G: CommitmentCurve, + G2: CommitmentCurve, + Pair: Pairing, + > PairingSRS +{ + /// Create a trusted setup for the KZG protocol. + /// The setup is created using a toxic waste `toxic_waste` and a depth + /// `depth`.
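+ /// + /// A minimal usage sketch (the BN254 type names are illustrative only, + /// mirroring the test that used to live in `pairing_proof.rs`): + /// ```ignore + /// let rng = &mut o1_utils::tests::make_test_rng(None); + /// let toxic_waste = ScalarField::rand(rng); + /// let srs: PairingSRS<Bn<Config>> = + /// PairingSRS::create_trusted_setup(toxic_waste, 1 << 6); + /// ```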
+ pub fn create_trusted_setup(toxic_waste: F, depth: usize) -> Self { + let full_srs = unsafe { SRS::create_trusted_setup(toxic_waste, depth) }; + let verifier_srs = unsafe { SRS::create_trusted_setup(toxic_waste, 3) }; + Self { + full_srs, + verifier_srs, + } + } +} + +impl Default for PairingSRS { + fn default() -> Self { + Self { + full_srs: SRS::default(), + verifier_srs: SRS::default(), + } + } +} + +impl Clone for PairingSRS { + fn clone(&self) -> Self { + Self { + full_srs: self.full_srs.clone(), + verifier_srs: self.verifier_srs.clone(), + } + } +} + +impl< + F: PrimeField, + G: CommitmentCurve, + G2: CommitmentCurve, + Pair: Pairing, + > crate::OpenProof for KZGProof +{ + type SRS = PairingSRS; + + /// Parameters: + /// - `srs`: the structured reference string + /// - `plnms`: vector of polynomials with commitment randomness + /// - `elm`: vector of evaluation points + /// - `polyscale`: scaling factor for polynomials + /// `group_map`, `sponge`, `rng` and `evalscale` are not used. The parameters + /// are kept to fit the trait and to be used generically. + fn open>( + srs: &Self::SRS, + _group_map: &::Map, + plnms: PolynomialsToCombine, + elm: &[::ScalarField], + polyscale: ::ScalarField, + _evalscale: ::ScalarField, + _sponge: EFqSponge, + _rng: &mut RNG, + ) -> Self + where + EFqSponge: Clone + FqSponge<::BaseField, G, F>, + RNG: RngCore + CryptoRng, + { + KZGProof::create(srs, plnms, elm, polyscale).unwrap() + } + + fn verify( + srs: &Self::SRS, + _group_map: &G::Map, + batch: &mut [BatchEvaluationProof], + _rng: &mut RNG, + ) -> bool + where + EFqSponge: FqSponge, + RNG: RngCore + CryptoRng, + { + for BatchEvaluationProof { + sponge: _, + evaluations, + evaluation_points, + polyscale, + evalscale: _, + opening, + combined_inner_product: _, + } in batch.iter() + { + if !opening.verify(srs, evaluations, *polyscale, evaluation_points) { + return false; + } + } + true + } +} + +impl< + F: PrimeField, + G: CommitmentCurve, + G2: CommitmentCurve, + Pair: Pairing, + > SRSTrait for PairingSRS +{ + fn max_poly_size(&self) -> usize { + self.full_srs.max_poly_size() + } + + fn get_lagrange_basis(&self, domain: D) -> &Vec> { + self.full_srs.get_lagrange_basis(domain) + } + + fn get_lagrange_basis_from_domain_size(&self, domain_size: usize) -> &Vec> { + self.full_srs + .get_lagrange_basis_from_domain_size(domain_size) + } + + fn blinding_commitment(&self) -> G { + self.full_srs.blinding_commitment() + } + + fn mask_custom( + &self, + com: PolyComm, + blinders: &PolyComm, + ) -> Result, CommitmentError> { + self.full_srs.mask_custom(com, blinders) + } + + fn mask( + &self, + comm: PolyComm, + rng: &mut (impl RngCore + CryptoRng), + ) -> BlindedCommitment { + self.full_srs.mask(comm, rng) + } + + fn commit( + &self, + plnm: &DensePolynomial, + num_chunks: usize, + rng: &mut (impl RngCore + CryptoRng), + ) -> BlindedCommitment { + self.full_srs.commit(plnm, num_chunks, rng) + } + + fn commit_non_hiding( + &self, + plnm: &DensePolynomial, + num_chunks: usize, + ) -> PolyComm { + self.full_srs.commit_non_hiding(plnm, num_chunks) + } + + fn commit_custom( + &self, + plnm: &DensePolynomial<::ScalarField>, + num_chunks: usize, + blinders: &PolyComm<::ScalarField>, + ) -> Result, CommitmentError> { + self.full_srs.commit_custom(plnm, num_chunks, blinders) + } + + fn commit_evaluations_non_hiding( + &self, + domain: D, + plnm: &Evaluations>, + ) -> PolyComm { + self.full_srs.commit_evaluations_non_hiding(domain, plnm) + } + + fn commit_evaluations( + &self, + domain: D,
plnm: &Evaluations>, + rng: &mut (impl RngCore + CryptoRng), + ) -> BlindedCommitment { + self.full_srs.commit_evaluations(domain, plnm, rng) + } + + fn commit_evaluations_custom( + &self, + domain: D<::ScalarField>, + plnm: &Evaluations<::ScalarField, D<::ScalarField>>, + blinders: &PolyComm<::ScalarField>, + ) -> Result, CommitmentError> { + self.full_srs + .commit_evaluations_custom(domain, plnm, blinders) + } + + fn create(depth: usize) -> Self { + let mut rng = thread_rng(); + let toxic_waste = G::ScalarField::rand(&mut rng); + Self::create_trusted_setup(toxic_waste, depth) + } + + fn size(&self) -> usize { + self.full_srs.g.len() + } +} + +/// The polynomial that evaluates to each of `evals` for the respective `elm`s. +/// For now, it only works for two evaluation points. +/// `elm` is the vector of evaluation points and `evals` is the vector of +/// evaluations at those points. +fn eval_polynomial(elm: &[F], evals: &[F]) -> DensePolynomial { + assert_eq!(elm.len(), evals.len()); + let (zeta, zeta_omega) = if elm.len() == 2 { + (elm[0], elm[1]) + } else { + todo!() + }; + let (eval_zeta, eval_zeta_omega) = if evals.len() == 2 { + (evals[0], evals[1]) + } else { + todo!() + }; + + // The polynomial that evaluates to `p(ζ)` at `ζ` and `p(ζω)` at + // `ζω`. + // We write `p(x) = a + bx`, which gives + // ```text + // p(ζ) = a + b * ζ + // p(ζω) = a + b * ζω + // ``` + // and so + // ```text + // b = (p(ζω) - p(ζ)) / (ζω - ζ) + // a = p(ζ) - b * ζ + // ``` + let b = (eval_zeta_omega - eval_zeta) / (zeta_omega - zeta); + let a = eval_zeta - b * zeta; + DensePolynomial::from_coefficients_slice(&[a, b]) +} + +/// The polynomial that evaluates to `0` at the evaluation points. +fn divisor_polynomial(elm: &[F]) -> DensePolynomial { + elm.iter() + .map(|value| DensePolynomial::from_coefficients_slice(&[-(*value), F::one()])) + .reduce(|poly1, poly2| &poly1 * &poly2) + .unwrap() +} + +impl< + F: PrimeField, + G: CommitmentCurve, + G2: CommitmentCurve, + Pair: Pairing, + > KZGProof +{ + /// Create a KZG proof. + /// Parameters: + /// - `srs`: the structured reference string used to commit + /// to the polynomials + /// - `plnms`: the list of polynomials to open. + /// The type is simply an alias to handle the polynomials in evaluations or + /// coefficients forms. + /// - `elm`: vector of evaluation points. Note that it only works for two + /// elements for now. + /// - `polyscale`: a challenge to batch the polynomials. + pub fn create>( + srs: &PairingSRS, + plnms: PolynomialsToCombine, + elm: &[F], + polyscale: F, + ) -> Option { + let (p, blinding_factor) = combine_polys::(plnms, polyscale, srs.full_srs.g.len()); + let evals: Vec<_> = elm.iter().map(|pt| p.evaluate(pt)).collect(); + + let quotient_poly = { + // This is where the condition on two points is enforced. + let eval_polynomial = eval_polynomial(elm, &evals); + let divisor_polynomial = divisor_polynomial(elm); + let numerator_polynomial = &p - &eval_polynomial; + let (quotient, remainder) = DenseOrSparsePolynomial::divide_with_q_and_r( + &numerator_polynomial.into(), + &divisor_polynomial.into(), + )?; + if !remainder.is_zero() { + return None; + } + quotient + }; + + let quotient = srs + .full_srs + .commit_non_hiding(&quotient_poly, 1) + .get_first_chunk(); + + Some(KZGProof { + quotient, + blinding: blinding_factor, + }) + } + + /// Verify a proof. Note that it only works for two elements for now, i.e. + /// elm must be of size 2. + /// Also, chunking is not supported.
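+ /// + /// Concretely, with C the `polyscale`-combined commitment, r(X) the + /// low-degree polynomial interpolating the combined evaluations, w the + /// blinding factor and d(X) the divisor polynomial vanishing on the + /// evaluation points, this checks the pairing equation + /// ```text + /// e(C - [r(x)]_1 - [w] H, [1]_2) == e(quotient, [d(x)]_2) + /// ```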
+ pub fn verify( + &self, + srs: &PairingSRS, // SRS + evaluations: &Vec>, // commitments to the polynomials + polyscale: F, // scaling factor for polynomials + elm: &[F], // vector of evaluation points + ) -> bool { + let poly_commitment: G::Group = { + let mut scalars: Vec = Vec::new(); + let mut points = Vec::new(); + combine_commitments( + evaluations, + &mut scalars, + &mut points, + polyscale, + F::one(), /* TODO: This is inefficient */ + ); + let scalars: Vec<_> = scalars.iter().map(|x| x.into_bigint()).collect(); + + G::Group::msm_bigint(&points, &scalars) + }; + + // IMPROVEME: we could have a single flat array for all evaluations, see + // same comment in combine_evaluations + let evals = combine_evaluations(evaluations, polyscale); + let blinding_commitment = srs.full_srs.h.mul(self.blinding); + // Taking the first element of the commitment, i.e. no support for chunking. + let divisor_commitment = srs + .verifier_srs + .commit_non_hiding(&divisor_polynomial(elm), 1) + .get_first_chunk(); + // Taking the first element of the commitment, i.e. no support for chunking. + let eval_commitment = srs + .full_srs + .commit_non_hiding(&eval_polynomial(elm, &evals), 1) + .get_first_chunk() + .into_group(); + let numerator_commitment = { poly_commitment - eval_commitment - blinding_commitment }; + // We compute the product of two Miller loops, so as to apply only one + // final exponentiation + let to_loop_left = [ + ark_ec::pairing::prepare_g1::(numerator_commitment), + // Note that we do a negation here, to put everything on the same side + ark_ec::pairing::prepare_g1::(self.quotient.into_group().neg()), + ]; + let to_loop_right = [ + ark_ec::pairing::prepare_g2::(Pair::G2Affine::generator()), + ark_ec::pairing::prepare_g2::(divisor_commitment), + ]; + // The result is e(numerator_commitment, 1) * e(quotient, divisor_commitment)^{-1}, + // which must be the identity element for a valid proof. + let res = Pair::multi_pairing(to_loop_left, to_loop_right); + + res.is_zero() + } +} diff --git a/poly-commitment/src/lib.rs b/poly-commitment/src/lib.rs index 21fb779d59..06e546a9e7 100644 --- a/poly-commitment/src/lib.rs +++ b/poly-commitment/src/lib.rs @@ -1,21 +1,20 @@ -pub mod chunked; mod combine; pub mod commitment; pub mod error; -pub mod evaluation_proof; pub mod hash_map_cache; -pub mod pairing_proof; -pub mod srs; +pub mod ipa; +pub mod kzg; +pub mod utils; -#[cfg(test)] -mod tests; +// Exposing property-based tests for the SRS trait +pub mod pbt_srs; pub use commitment::PolyComm; use crate::{ commitment::{BatchEvaluationProof, BlindedCommitment, CommitmentCurve}, error::CommitmentError, - evaluation_proof::DensePolynomialOrEvaluations, + utils::DensePolynomialOrEvaluations, }; use ark_ec::AffineRepr; use ark_ff::UniformRand; @@ -25,32 +24,31 @@ use ark_poly::{ use mina_poseidon::FqSponge; use rand_core::{CryptoRng, RngCore}; -pub trait SRS { +pub trait SRS: Clone + Sized { /// The maximum polynomial degree that can be committed to fn max_poly_size(&self) -> usize; - /// Retrieve the precomputed Lagrange basis for the given domain size - fn get_lagrange_basis(&self, domain_size: usize) -> &Vec>; - /// Get the group element used for blinding commitments fn blinding_commitment(&self) -> G; - /// Commits a polynomial, potentially splitting the result in multiple commitments.
- fn commit( - &self, - plnm: &DensePolynomial, - num_chunks: usize, - rng: &mut (impl RngCore + CryptoRng), - ) -> BlindedCommitment; - /// Same as [SRS::mask] except that you can pass the blinders manually. + /// A [BlindedCommitment] object is returned instead of a PolyComm object to + /// keep the blinding factors and the commitment together. The blinded + /// commitment is saved in the commitment field of the output. + /// The output is wrapped into a [Result] to handle the case where the + /// blinders do not have the same length as the number of chunks of the + /// commitment. fn mask_custom( &self, com: PolyComm, blinders: &PolyComm, ) -> Result, CommitmentError>; - /// Turns a non-hiding polynomial commitment into a hidding polynomial commitment. Transforms each given `` into `( + wH, w)` with a random `w` per commitment. + /// Turns a non-hiding polynomial commitment into a hiding polynomial + /// commitment. Transforms each given `<comm>` into `(<comm> + wH, w)` with + /// a random `w` per commitment. + /// A [BlindedCommitment] object is returned instead of a PolyComm object to + /// keep the blinding factors and the commitment together. The blinded + /// commitment is saved in the commitment field of the output. fn mask( &self, comm: PolyComm, @@ -61,49 +59,151 @@ pub trait SRS { } /// This function commits a polynomial using the SRS' basis of size `n`. - /// - `plnm`: polynomial to commit to with max size of sections - /// - `num_chunks`: the number of commitments to be included in the output polynomial commitment - /// The function returns an unbounded commitment vector - /// (which splits the commitment into several commitments of size at most `n`). + /// - `plnm`: polynomial to commit to. The polynomial can be of any degree, + /// including higher than `n`. + /// - `num_chunks`: the minimal number of commitments to be included in the + /// output polynomial commitment. + /// + /// The function returns the commitments to the chunks (of size at most `n`) of + /// the polynomial. + /// + /// The function will also pad with zeroes if the polynomial has a degree + /// smaller than `n * num_chunks`. + /// + /// Note that if the polynomial has a degree higher than `n * num_chunks`, + /// the output will contain more than `num_chunks` commitments as it will + /// also contain the additional chunks. + /// + /// See the test + /// [crate::pbt_srs::test_regression_commit_non_hiding_expected_number_of_chunks] + /// for an example of the number of chunks returned. fn commit_non_hiding( &self, plnm: &DensePolynomial, num_chunks: usize, ) -> PolyComm; + /// Commits a polynomial, potentially splitting the result in multiple + /// commitments. + /// It is analogous to [SRS::commit_evaluations] but for polynomials. + /// A [BlindedCommitment] object is returned instead of a PolyComm object to + /// keep the blinding factors and the commitment together. The blinded + /// commitment is saved in the commitment field of the output. + fn commit( + &self, + plnm: &DensePolynomial, + num_chunks: usize, + rng: &mut (impl RngCore + CryptoRng), + ) -> BlindedCommitment; + + /// Commit to a polynomial, with custom blinding factors. + /// It is a combination of [SRS::commit] and [SRS::mask_custom]. + /// It is analogous to [SRS::commit_evaluations_custom] but for polynomials. + /// A [BlindedCommitment] object is returned instead of a PolyComm object to + /// keep the blinding factors and the commitment together. The blinded + /// commitment is saved in the commitment field of the output. + /// The output is wrapped into a [Result] to handle the case where the + /// blinders do not have the same length as the number of chunks of the + /// commitment. + fn commit_custom( + &self, + plnm: &DensePolynomial, + num_chunks: usize, + blinders: &PolyComm, + ) -> Result, CommitmentError>; + + /// Commit to evaluations, without blinding factors. + /// It is analogous to [SRS::commit_non_hiding] but for evaluations. fn commit_evaluations_non_hiding( + &self, + domain: D, + plnm: &Evaluations>, + ) -> PolyComm; + /// Commit to evaluations with blinding factors, generated using the random + /// number generator `rng`. + /// It is analogous to [SRS::commit] but for evaluations. + /// A [BlindedCommitment] object is returned instead of a PolyComm object to + /// keep the blinding factors and the commitment together. The blinded + /// commitment is saved in the commitment field of the output. fn commit_evaluations( &self, domain: D, plnm: &Evaluations>, rng: &mut (impl RngCore + CryptoRng), ) -> BlindedCommitment; + + /// Commit to evaluations with custom blinding factors. + /// It is a combination of [SRS::commit_evaluations] and [SRS::mask_custom]. + /// It is analogous to [SRS::commit_custom] but for evaluations. + /// A [BlindedCommitment] object is returned instead of a PolyComm object to + /// keep the blinding factors and the commitment together. The blinded + /// commitment is saved in the commitment field of the output. + /// The output is wrapped into a [Result] to handle the case where the + /// blinders do not have the same length as the number of chunks of the + /// commitment. + fn commit_evaluations_custom( + &self, + domain: D, + plnm: &Evaluations>, + blinders: &PolyComm, + ) -> Result, CommitmentError>; + + /// Create an SRS of size `depth`. + /// + /// Warning: in the case of a trusted setup, as it is required for a + /// polynomial commitment scheme like KZG, a toxic waste is generated using + /// `rand::thread_rng()`. This is not the behavior you would expect in a + /// production environment. + /// However, we accept this behavior for the sake of simplicity in the + /// interface, and in that case this method is only supposed to be used in + /// tests. + fn create(depth: usize) -> Self; + + /// Compute commitments to the Lagrange basis corresponding to the given domain and + /// cache them in the SRS + fn get_lagrange_basis(&self, domain: D) -> &Vec>; + + /// Same as `get_lagrange_basis` but only using the domain size. + fn get_lagrange_basis_from_domain_size(&self, domain_size: usize) -> &Vec>; + + fn size(&self) -> usize; } #[allow(type_alias_bounds)] -/// Vector of triples (polynomial itself, degree bound, omegas). +/// An alias to represent a polynomial (in either coefficient or +/// evaluation form), together with the *scalar field* elements that +/// act as the blinders of its chunk commitments. +// TODO: add a string to name the polynomial type PolynomialsToCombine<'a, G: CommitmentCurve, D: EvaluationDomain> = &'a [( DensePolynomialOrEvaluations<'a, G::ScalarField, D>, PolyComm, )]; -pub trait OpenProof: Sized { +pub trait OpenProof: Sized + Clone { type SRS: SRS + std::fmt::Debug; + /// Create an opening proof for a batch of polynomials. The parameters are + /// the following: + /// - `srs`: the structured reference string used to commit + /// to the polynomials + /// - `group_map`: the group map + /// - `plnms`: the list of polynomials to open, with possible blinders. + /// The type is simply an alias to handle the polynomials in evaluations or + /// coefficients forms.
+ /// - `elm`: the evaluation points + /// - `polyscale`: a challenge to batch the polynomials. + /// - `evalscale`: a challenge to batch the evaluation points + /// - `sponge`: Sponge used to coin and absorb values and simulate + /// non-interactivity using the Fiat-Shamir transformation. + /// - `rng`: a pseudo-random number generator used for zero-knowledge #[allow(clippy::too_many_arguments)] fn open::ScalarField>>( srs: &Self::SRS, group_map: &::Map, - plnms: PolynomialsToCombine, // vector of polynomial with optional degree bound and commitment randomness - elm: &[::ScalarField], // vector of evaluation points - polyscale: ::ScalarField, // scaling factor for polynoms - evalscale: ::ScalarField, // scaling factor for evaluation point powers - sponge: EFqSponge, // sponge + plnms: PolynomialsToCombine, + elm: &[::ScalarField], + polyscale: ::ScalarField, + evalscale: ::ScalarField, + sponge: EFqSponge, rng: &mut RNG, ) -> Self where @@ -111,6 +211,7 @@ pub trait OpenProof: Sized { Clone + FqSponge<::BaseField, G, ::ScalarField>, RNG: RngCore + CryptoRng; + /// Verify the opening proof fn verify( srs: &Self::SRS, group_map: &G::Map, diff --git a/poly-commitment/src/pairing_proof.rs b/poly-commitment/src/pairing_proof.rs deleted file mode 100644 index 82bc23e160..0000000000 --- a/poly-commitment/src/pairing_proof.rs +++ /dev/null @@ -1,416 +0,0 @@ -use crate::{ - commitment::*, evaluation_proof::combine_polys, srs::SRS, CommitmentError, - PolynomialsToCombine, SRS as SRSTrait, -}; -use ark_ec::{pairing::Pairing, AffineRepr, VariableBaseMSM}; -use ark_ff::{PrimeField, Zero}; -use ark_poly::{ - univariate::{DenseOrSparsePolynomial, DensePolynomial}, - DenseUVPolynomial, EvaluationDomain, Evaluations, Polynomial, Radix2EvaluationDomain as D, -}; -use mina_poseidon::FqSponge; -use rand_core::{CryptoRng, RngCore}; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; - -#[serde_as] -#[derive(Debug, Serialize, Deserialize)] -#[serde( - bound = "Pair::G1Affine: ark_serialize::CanonicalDeserialize + ark_serialize::CanonicalSerialize" -)] -pub struct PairingProof { - #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub quotient: Pair::G1Affine, - #[serde_as(as = "o1_utils::serialization::SerdeAs")] - pub blinding: ::ScalarField, -} - -impl Default for PairingProof { - fn default() -> Self { - Self { - quotient: Pair::G1Affine::generator(), - blinding: ::ScalarField::zero(), - } - } -} - -impl Clone for PairingProof { - fn clone(&self) -> Self { - Self { - quotient: self.quotient, - blinding: self.blinding, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct PairingSRS { - pub full_srs: SRS, - pub verifier_srs: SRS, -} - -impl Default for PairingSRS { - fn default() -> Self { - Self { - full_srs: SRS::default(), - verifier_srs: SRS::default(), - } - } -} - -impl Clone for PairingSRS { - fn clone(&self) -> Self { - Self { - full_srs: self.full_srs.clone(), - verifier_srs: self.verifier_srs.clone(), - } - } -} - -impl< - F: PrimeField, - G: CommitmentCurve, - G2: CommitmentCurve, - Pair: Pairing, - > PairingSRS -{ - pub fn create(x: F, n: usize) -> Self { - PairingSRS { - full_srs: SRS::create_trusted_setup(x, n), - verifier_srs: SRS::create_trusted_setup(x, 3), - } - } -} - -impl< - F: PrimeField, - G: CommitmentCurve, - G2: CommitmentCurve, - Pair: Pairing, - > crate::OpenProof for PairingProof -{ - type SRS = PairingSRS; - - fn open::ScalarField>>( - srs: &Self::SRS, - _group_map: &::Map, - plnms: PolynomialsToCombine, - elm: &[::ScalarField], // vector of
evaluation points - polyscale: ::ScalarField, // scaling factor for polynoms - _evalscale: ::ScalarField, // scaling factor for evaluation point powers - _sponge: EFqSponge, // sponge - _rng: &mut RNG, - ) -> Self - where - EFqSponge: - Clone + FqSponge<::BaseField, G, ::ScalarField>, - RNG: RngCore + CryptoRng, - { - PairingProof::create(srs, plnms, elm, polyscale).unwrap() - } - - fn verify( - srs: &Self::SRS, - _group_map: &G::Map, - batch: &mut [BatchEvaluationProof], - _rng: &mut RNG, - ) -> bool - where - EFqSponge: FqSponge, - RNG: RngCore + CryptoRng, - { - for BatchEvaluationProof { - sponge: _, - evaluations, - evaluation_points, - polyscale, - evalscale: _, - opening, - combined_inner_product: _, - } in batch.iter() - { - if !opening.verify(srs, evaluations, *polyscale, evaluation_points) { - return false; - } - } - true - } -} - -impl< - F: PrimeField, - G: CommitmentCurve, - G2: CommitmentCurve, - Pair: Pairing, - > SRSTrait for PairingSRS -{ - fn max_poly_size(&self) -> usize { - self.full_srs.max_poly_size() - } - - fn get_lagrange_basis(&self, domain_size: usize) -> &Vec> { - self.full_srs - .get_lagrange_basis_from_domain_size(domain_size) - } - - fn blinding_commitment(&self) -> G { - self.full_srs.blinding_commitment() - } - - fn commit( - &self, - plnm: &DensePolynomial, - num_chunks: usize, - rng: &mut (impl RngCore + CryptoRng), - ) -> BlindedCommitment { - self.full_srs.commit(plnm, num_chunks, rng) - } - - fn mask_custom( - &self, - com: PolyComm, - blinders: &PolyComm, - ) -> Result, CommitmentError> { - self.full_srs.mask_custom(com, blinders) - } - - fn mask( - &self, - comm: PolyComm, - rng: &mut (impl RngCore + CryptoRng), - ) -> BlindedCommitment { - self.full_srs.mask(comm, rng) - } - - fn commit_non_hiding( - &self, - plnm: &DensePolynomial, - num_chunks: usize, - ) -> PolyComm { - self.full_srs.commit_non_hiding(plnm, num_chunks) - } - - fn commit_evaluations_non_hiding( - &self, - domain: D, - plnm: &Evaluations>, - ) -> PolyComm { - self.full_srs.commit_evaluations_non_hiding(domain, plnm) - } - - fn commit_evaluations( - &self, - domain: D, - plnm: &Evaluations>, - rng: &mut (impl RngCore + CryptoRng), - ) -> BlindedCommitment { - self.full_srs.commit_evaluations(domain, plnm, rng) - } -} - -/// The polynomial that evaluates to each of `evals` for the respective `elm`s. -fn eval_polynomial(elm: &[F], evals: &[F]) -> DensePolynomial { - assert_eq!(elm.len(), evals.len()); - let (zeta, zeta_omega) = if elm.len() == 2 { - (elm[0], elm[1]) - } else { - todo!() - }; - let (eval_zeta, eval_zeta_omega) = if evals.len() == 2 { - (evals[0], evals[1]) - } else { - todo!() - }; - - // The polynomial that evaluates to `p(zeta)` at `zeta` and `p(zeta_omega)` at - // `zeta_omega`. - // We write `p(x) = a + bx`, which gives - // ```text - // p(zeta) = a + b * zeta - // p(zeta_omega) = a + b * zeta_omega - // ``` - // and so - // ```text - // b = (p(zeta_omega) - p(zeta)) / (zeta_omega - zeta) - // a = p(zeta) - b * zeta - // ``` - let b = (eval_zeta_omega - eval_zeta) / (zeta_omega - zeta); - let a = eval_zeta - b * zeta; - DensePolynomial::from_coefficients_slice(&[a, b]) -} - -/// The polynomial that evaluates to `0` at the evaluation points. 
-fn divisor_polynomial(elm: &[F]) -> DensePolynomial { - elm.iter() - .map(|value| DensePolynomial::from_coefficients_slice(&[-(*value), F::one()])) - .reduce(|poly1, poly2| &poly1 * &poly2) - .unwrap() -} - -impl< - F: PrimeField, - G: CommitmentCurve, - G2: CommitmentCurve, - Pair: Pairing, - > PairingProof -{ - pub fn create>( - srs: &PairingSRS, - plnms: PolynomialsToCombine, // vector of polynomial with optional degree bound and commitment randomness - elm: &[G::ScalarField], // vector of evaluation points - polyscale: G::ScalarField, // scaling factor for polynoms - ) -> Option { - let (p, blinding_factor) = combine_polys::(plnms, polyscale, srs.full_srs.g.len()); - let evals: Vec<_> = elm.iter().map(|pt| p.evaluate(pt)).collect(); - - let quotient_poly = { - let eval_polynomial = eval_polynomial(elm, &evals); - let divisor_polynomial = divisor_polynomial(elm); - let numerator_polynomial = &p - &eval_polynomial; - let (quotient, remainder) = DenseOrSparsePolynomial::divide_with_q_and_r( - &numerator_polynomial.into(), - &divisor_polynomial.into(), - )?; - if !remainder.is_zero() { - return None; - } - quotient - }; - - let quotient = srs.full_srs.commit_non_hiding("ient_poly, 1).elems[0]; - - Some(PairingProof { - quotient, - blinding: blinding_factor, - }) - } - - pub fn verify( - &self, - srs: &PairingSRS, // SRS - evaluations: &Vec>, // commitments to the polynomials - polyscale: G::ScalarField, // scaling factor for polynoms - elm: &[G::ScalarField], // vector of evaluation points - ) -> bool { - let poly_commitment: G::Group = { - let mut scalars: Vec = Vec::new(); - let mut points = Vec::new(); - combine_commitments( - evaluations, - &mut scalars, - &mut points, - polyscale, - F::one(), /* TODO: This is inefficient */ - ); - let scalars: Vec<_> = scalars.iter().map(|x| x.into_bigint()).collect(); - - G::Group::msm_bigint(&points, &scalars) - }; - let evals = combine_evaluations(evaluations, polyscale); - let blinding_commitment = srs.full_srs.h.mul(self.blinding); - let divisor_commitment = srs - .verifier_srs - .commit_non_hiding(&divisor_polynomial(elm), 1) - .elems[0]; - let eval_commitment = srs - .full_srs - .commit_non_hiding(&eval_polynomial(elm, &evals), 1) - .elems[0] - .into_group(); - let numerator_commitment_proj: ::Group = - { poly_commitment - eval_commitment - blinding_commitment }; - let numerator_commitment_affine: Pair::G1Affine = From::from(numerator_commitment_proj); - - let numerator = Pair::pairing(numerator_commitment_affine, Pair::G2Affine::generator()); - let scaled_quotient = Pair::pairing(self.quotient, divisor_commitment); - numerator == scaled_quotient - } -} - -#[cfg(test)] -mod tests { - use super::{PairingProof, PairingSRS}; - use crate::{ - commitment::Evaluation, evaluation_proof::DensePolynomialOrEvaluations, srs::SRS, SRS as _, - }; - use ark_bn254::{Config, Fr as ScalarField, G1Affine as G1, G2Affine as G2}; - use ark_ec::bn::Bn; - use ark_ff::UniformRand; - use ark_poly::{ - univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Polynomial, - Radix2EvaluationDomain as D, - }; - - #[test] - fn test_pairing_proof() { - let n = 64; - let domain = D::::new(n).unwrap(); - - let rng = &mut o1_utils::tests::make_test_rng(None); - - let x = ScalarField::rand(rng); - - let srs = SRS::::create_trusted_setup(x, n); - let verifier_srs = SRS::::create_trusted_setup(x, 3); - srs.get_lagrange_basis(domain); - - let srs = PairingSRS { - full_srs: srs, - verifier_srs, - }; - - let polynomials: Vec<_> = (0..4) - .map(|_| { - let coeffs = 
(0..63).map(|_| ScalarField::rand(rng)).collect();
-                DensePolynomial::from_coefficients_vec(coeffs)
-            })
-            .collect();
-
-        let comms: Vec<_> = polynomials
-            .iter()
-            .map(|p| srs.full_srs.commit(p, 1, rng))
-            .collect();
-
-        let polynomials_and_blinders: Vec<(DensePolynomialOrEvaluations<_, D<_>>, _)> = polynomials
-            .iter()
-            .zip(comms.iter())
-            .map(|(p, comm)| {
-                let p = DensePolynomialOrEvaluations::DensePolynomial(p);
-                (p, comm.blinders.clone())
-            })
-            .collect();
-
-        let evaluation_points = vec![ScalarField::rand(rng), ScalarField::rand(rng)];
-
-        let evaluations: Vec<_> = polynomials
-            .iter()
-            .zip(comms)
-            .map(|(p, commitment)| {
-                let evaluations = evaluation_points
-                    .iter()
-                    .map(|x| {
-                        // Inputs are chosen to use only 1 chunk
-                        vec![p.evaluate(x)]
-                    })
-                    .collect();
-                Evaluation {
-                    commitment: commitment.commitment,
-                    evaluations,
-                }
-            })
-            .collect();
-
-        let polyscale = ScalarField::rand(rng);
-
-        let pairing_proof = PairingProof::>::create(
-            &srs,
-            polynomials_and_blinders.as_slice(),
-            &evaluation_points,
-            polyscale,
-        )
-        .unwrap();
-
-        let res = pairing_proof.verify(&srs, &evaluations, polyscale, &evaluation_points);
-        assert!(res);
-    }
-}
diff --git a/poly-commitment/src/pbt_srs.rs b/poly-commitment/src/pbt_srs.rs
new file mode 100644
index 0000000000..c21074d0d9
--- /dev/null
+++ b/poly-commitment/src/pbt_srs.rs
@@ -0,0 +1,77 @@
+//! This module defines property-based tests for the SRS trait.
+//! It includes tests for the methods an SRS implementation should provide,
+//! and aims to verify that implementations respect the properties described
+//! in the methods' documentation.
+
+use ark_ff::Zero;
+use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial};
+use rand::Rng;
+
+use crate::{commitment::CommitmentCurve, SRS};
+
+// Testing how many chunks are generated with different polynomial sizes and
+// different numbers of chunks requested.
+pub fn test_regression_commit_non_hiding_expected_number_of_chunks<
+    G: CommitmentCurve,
+    Srs: SRS<G>,
+>() {
+    let mut rng = &mut o1_utils::tests::make_test_rng(None);
+    // maximum random SRS size is 32, i.e. 2^5
+    let log2_srs_size = rng.gen_range(1..6);
+    let srs_size = 1 << log2_srs_size;
+    let srs = Srs::create(srs_size);
+
+    // If we have a polynomial of the size of the SRS (i.e. of degree `srs_size -
+    // 1`), and we request 1 chunk, we should get 1 chunk.
+    {
+        // a polynomial of degree srs_size - 1 has exactly srs_size coefficients
+        let poly_degree = srs_size - 1;
+        let poly = DensePolynomial::<G::ScalarField>::rand(poly_degree, &mut rng);
+        let commitment = srs.commit_non_hiding(&poly, 1);
+        assert_eq!(commitment.len(), 1);
+    }
+
+    // If we have a polynomial of the size of the SRS (i.e. of degree `srs_size -
+    // 1`), and we request k chunks (k > 1), we should get k chunks.
+    {
+        // a polynomial of degree srs_size - 1 has exactly srs_size coefficients
+        let poly_degree = srs_size - 1;
+        // maximum 10 chunks for the test
+        let k = rng.gen_range(2..10);
+        let poly = DensePolynomial::<G::ScalarField>::rand(poly_degree, &mut rng);
+        let commitment = srs.commit_non_hiding(&poly, k);
+        assert_eq!(commitment.len(), k);
+    }
+
+    // Same as the two previous cases, but with the special polynomial equal to zero.
+    {
+        let k = rng.gen_range(1..10);
+        let poly = DensePolynomial::<G::ScalarField>::zero();
+        let commitment = srs.commit_non_hiding(&poly, k);
+        assert_eq!(commitment.len(), k);
+    }
+
+    // Polynomial of exactly a multiple of the SRS size, i.e. degree is k *
+    // srs_size - 1.
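+    // For example (illustrative numbers): with srs_size = 8 and k = 3, the
+    // polynomial has degree 23, i.e. 24 coefficients, which are committed as
+    // 3 chunks of 8 coefficients each.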
+    {
+        let k = rng.gen_range(2..5);
+        let poly_degree = k * srs_size - 1;
+        let poly = DensePolynomial::<G::ScalarField>::rand(poly_degree, &mut rng);
+        // if we request a number of chunks smaller than the multiple, we will
+        // still get a number of chunks equal to the multiple.
+        let requested_num_chunks = rng.gen_range(1..k);
+        let commitment = srs.commit_non_hiding(&poly, requested_num_chunks);
+        assert_eq!(commitment.len(), k);
+
+        // if we request a number of chunks equal to the multiple, we will get
+        // the exact number of chunks.
+        let commitment = srs.commit_non_hiding(&poly, k);
+        assert_eq!(commitment.len(), k);
+
+        // if we request a number of chunks greater than the multiple, we will
+        // get the exact number of chunks requested.
+        let requested_num_chunks = rng.gen_range(k + 1..10);
+        let commitment = srs.commit_non_hiding(&poly, requested_num_chunks);
+        assert_eq!(commitment.len(), requested_num_chunks);
+    }
+}
diff --git a/poly-commitment/src/srs.rs b/poly-commitment/src/srs.rs
deleted file mode 100644
index 7c8eb3b517..0000000000
--- a/poly-commitment/src/srs.rs
+++ /dev/null
@@ -1,305 +0,0 @@
-//! This module implements the Marlin structured reference string primitive
-
-use crate::{commitment::CommitmentCurve, hash_map_cache::HashMapCache, PolyComm};
-use ark_ec::{AffineRepr, CurveGroup};
-use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
-use ark_poly::{EvaluationDomain, Radix2EvaluationDomain as D};
-use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
-use blake2::{Blake2b512, Digest};
-use groupmap::GroupMap;
-use rayon::prelude::*;
-use serde::{Deserialize, Serialize};
-use serde_with::serde_as;
-use std::{array, cmp::min};
-
-#[serde_as]
-#[derive(Debug, Clone, Default, Serialize, Deserialize)]
-#[serde(bound = "G: CanonicalDeserialize + CanonicalSerialize")]
-pub struct SRS {
-    /// The vector of group elements for committing to polynomials in coefficient form
-    #[serde_as(as = "Vec")]
-    pub g: Vec,
-    /// A group element used for blinding commitments
-    #[serde_as(as = "o1_utils::serialization::SerdeAs")]
-    pub h: G,
-
-    // TODO: the following field should be separated, as they are optimization values
-    /// Commitments to Lagrange bases, per domain size
-    #[serde(skip)]
-    pub lagrange_bases: HashMapCache>>,
-}
-
-impl PartialEq for SRS
-where
-    G: PartialEq,
-{
-    fn eq(&self, other: &Self) -> bool {
-        self.g == other.g && self.h == other.h
-    }
-}
-
-pub fn endos() -> (G::BaseField, G::ScalarField)
-where
-    G::BaseField: PrimeField,
-{
-    let endo_q: G::BaseField = mina_poseidon::sponge::endo_coefficient();
-    let endo_r = {
-        let potential_endo_r: G::ScalarField = mina_poseidon::sponge::endo_coefficient();
-        let t = G::generator();
-        let (x, y) = t.to_coordinates().unwrap();
-        let phi_t = G::of_coordinates(x * endo_q, y);
-        if t.mul(potential_endo_r) == phi_t.into_group() {
-            potential_endo_r
-        } else {
-            potential_endo_r * potential_endo_r
-        }
-    };
-    (endo_q, endo_r)
-}
-
-fn point_of_random_bytes(map: &G::Map, random_bytes: &[u8]) -> G
-where
-    G::BaseField: Field,
-{
-    // packing in bit-representation
-    const N: usize = 31;
-    let extension_degree = G::BaseField::extension_degree() as usize;
-
-    let mut base_fields = Vec::with_capacity(N * extension_degree);
-
-    for base_count in 0..extension_degree {
-        let mut bits = [false; 8 * N];
-        let offset = base_count * N;
-        for i in 0..N {
-            for j in 0..8 {
-                bits[8 * i + j] = (random_bytes[offset + i] >> j) & 1 == 1;
-            }
-        }
-
-        let n =
-            <::BasePrimeField as PrimeField>::BigInt::from_bits_be(&bits);
-        let t =
<::BasePrimeField as PrimeField>::from_bigint(n) - .expect("packing code has a bug"); - base_fields.push(t) - } - let t = G::BaseField::from_base_prime_field_elems(&base_fields).unwrap(); - - let (x, y) = map.to_group(t); - G::of_coordinates(x, y) -} - -impl SRS { - pub fn max_degree(&self) -> usize { - self.g.len() - } - - fn lagrange_basis(&self, domain: D) -> Vec> { - let n = domain.size(); - - // Let V be a vector space over the field F. - // - // Given - // - a domain [ 1, w, w^2, ..., w^{n - 1} ] - // - a vector v := [ v_0, ..., v_{n - 1} ] in V^n - // - // the FFT algorithm computes the matrix application - // - // u = M(w) * v - // - // where - // M(w) = - // 1 1 1 ... 1 - // 1 w w^2 ... w^{n-1} - // ... - // 1 w^{n-1} (w^2)^{n-1} ... (w^{n-1})^{n-1} - // - // The IFFT algorithm computes - // - // v = M(w)^{-1} * u - // - // Let's see how we can use this algorithm to compute the lagrange basis - // commitments. - // - // Let V be the vector space F[x] of polynomials in x over F. - // Let v in V be the vector [ L_0, ..., L_{n - 1} ] where L_i is the i^{th} - // normalized Lagrange polynomial (where L_i(w^j) = j == i ? 1 : 0). - // - // Consider the rows of M(w) * v. Let me write out the matrix and vector so you - // can see more easily. - // - // | 1 1 1 ... 1 | | L_0 | - // | 1 w w^2 ... w^{n-1} | * | L_1 | - // | ... | | ... | - // | 1 w^{n-1} (w^2)^{n-1} ... (w^{n-1})^{n-1} | | L_{n-1} | - // - // The 0th row is L_0 + L1 + ... + L_{n - 1}. So, it's the polynomial - // that has the value 1 on every element of the domain. - // In other words, it's the polynomial 1. - // - // The 1st row is L_0 + w L_1 + ... + w^{n - 1} L_{n - 1}. So, it's the - // polynomial which has value w^i on w^i. - // In other words, it's the polynomial x. - // - // In general, you can see that row i is in fact the polynomial x^i. - // - // Thus, M(w) * v is the vector u, where u = [ 1, x, x^2, ..., x^n ] - // - // Therefore, the IFFT algorithm, when applied to the vector u (the standard - // monomial basis) will yield the vector v of the (normalized) Lagrange polynomials. - // - // Now, because the polynomial commitment scheme is additively homomorphic, and - // because the commitment to the polynomial x^i is just self.g[i], we can obtain - // commitments to the normalized Lagrange polynomials by applying IFFT to the - // vector self.g[0..n]. - // - // - // Further still, we can do the same trick for 'chunked' polynomials. - // - // Recall that a chunked polynomial is some f of degree k*n - 1 with - // f(x) = f_0(x) + x^n f_1(x) + ... + x^{(k-1) n} f_{k-1}(x) - // where each f_i has degree n-1. - // - // In the above, if we set u = [ 1, x^2, ... x^{n-1}, 0, 0, .., 0 ] - // then we effectively 'zero out' any polynomial terms higher than x^{n-1}, leaving - // us with the 'partial Lagrange polynomials' that contribute to f_0. - // - // Similarly, u = [ 0, 0, ..., 0, 1, x^2, ..., x^{n-1}, 0, 0, ..., 0] with n leading - // zeros 'zeroes out' all terms except the 'partial Lagrange polynomials' that - // contribute to f_1, and likewise for each f_i. - // - // By computing each of these, and recollecting the terms as a vector of polynomial - // commitments, we obtain a chunked commitment to the L_i polynomials. 
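-        // As a small worked example of the above: take n = 2, with domain
-        // { 1, w } and w = -1. The IFFT of [ g[0], g[1] ] is
-        //   [ (g[0] + g[1]) / 2, (g[0] - g[1]) / 2 ],
-        // which are exactly the commitments to L_0(x) = (1 + x) / 2 and
-        // L_1(x) = (1 - x) / 2.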
- let srs_size = self.g.len(); - let num_elems = (n + srs_size - 1) / srs_size; - let mut elems = Vec::with_capacity(num_elems); - - // For each chunk - for i in 0..num_elems { - // Initialize the vector with zero curve points - let mut lg: Vec<::Group> = vec![::Group::zero(); n]; - // Overwrite the terms corresponding to that chunk with the SRS curve points - let start_offset = i * srs_size; - let num_terms = min((i + 1) * srs_size, n) - start_offset; - for j in 0..num_terms { - lg[start_offset + j] = self.g[j].into_group() - } - // Apply the IFFT - domain.ifft_in_place(&mut lg); - // Append the 'partial Langrange polynomials' to the vector of elems chunks - elems.push(::Group::normalize_batch(lg.as_mut_slice())); - } - - (0..n) - .map(|i| PolyComm { - elems: elems.iter().map(|v| v[i]).collect(), - }) - .collect() - } - - pub fn get_lagrange_basis_from_domain_size(&self, domain_size: usize) -> &Vec> { - self.lagrange_bases.get_or_generate(domain_size, || { - self.lagrange_basis(D::new(domain_size).unwrap()) - }) - } - - /// Compute commitments to the lagrange basis corresponding to the given domain and - /// cache them in the SRS - pub fn get_lagrange_basis(&self, domain: D) -> &Vec> { - self.lagrange_bases - .get_or_generate(domain.size(), || self.lagrange_basis(domain)) - } - - /// This function creates a trusted-setup SRS instance for circuits with number of rows up to `depth`. - pub fn create_trusted_setup(x: G::ScalarField, depth: usize) -> Self { - let m = G::Map::setup(); - - let mut x_pow = G::ScalarField::one(); - let g: Vec<_> = (0..depth) - .map(|_| { - let res = G::generator().mul(x_pow); - x_pow *= x; - res.into_affine() - }) - .collect(); - - const MISC: usize = 1; - let [h]: [G; MISC] = array::from_fn(|i| { - let mut h = Blake2b512::new(); - h.update("srs_misc".as_bytes()); - h.update((i as u32).to_be_bytes()); - point_of_random_bytes(&m, &h.finalize()) - }); - - SRS { - g, - h, - lagrange_bases: HashMapCache::new(), - } - } -} - -impl SRS -where - G::BaseField: PrimeField, -{ - /// This function creates SRS instance for circuits with number of rows up to `depth`. - pub fn create(depth: usize) -> Self { - let m = G::Map::setup(); - - let g: Vec<_> = (0..depth) - .map(|i| { - let mut h = Blake2b512::new(); - h.update((i as u32).to_be_bytes()); - point_of_random_bytes(&m, &h.finalize()) - }) - .collect(); - - const MISC: usize = 1; - let [h]: [G; MISC] = array::from_fn(|i| { - let mut h = Blake2b512::new(); - h.update("srs_misc".as_bytes()); - h.update((i as u32).to_be_bytes()); - point_of_random_bytes(&m, &h.finalize()) - }); - - SRS { - g, - h, - lagrange_bases: HashMapCache::new(), - } - } -} - -impl SRS -where - ::Map: Sync, - G::BaseField: PrimeField, -{ - /// This function creates SRS instance for circuits with number of rows up to `depth`. 
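-    /// The output is identical to that of `create` (each basis point is
-    /// derived by hashing its index to the curve); this version only
-    /// parallelises the derivation with rayon.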
- pub fn create_parallel(depth: usize) -> Self { - let m = G::Map::setup(); - - let g: Vec<_> = (0..depth) - .into_par_iter() - .map(|i| { - let mut h = Blake2b512::new(); - h.update((i as u32).to_be_bytes()); - point_of_random_bytes(&m, &h.finalize()) - }) - .collect(); - - const MISC: usize = 1; - let [h]: [G; MISC] = array::from_fn(|i| { - let mut h = Blake2b512::new(); - h.update("srs_misc".as_bytes()); - h.update((i as u32).to_be_bytes()); - point_of_random_bytes(&m, &h.finalize()) - }); - - SRS { - g, - h, - lagrange_bases: HashMapCache::new(), - } - } -} diff --git a/poly-commitment/src/tests/commitment.rs b/poly-commitment/src/tests/commitment.rs deleted file mode 100644 index fac066d8a7..0000000000 --- a/poly-commitment/src/tests/commitment.rs +++ /dev/null @@ -1,279 +0,0 @@ -use crate::{ - commitment::{ - combined_inner_product, BatchEvaluationProof, BlindedCommitment, CommitmentCurve, - Evaluation, PolyComm, - }, - evaluation_proof::{DensePolynomialOrEvaluations, OpeningProof}, - srs::SRS, - SRS as _, -}; -use ark_ff::{UniformRand, Zero}; -use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Radix2EvaluationDomain}; -use colored::Colorize; -use groupmap::GroupMap; -use mina_curves::pasta::{Fp, Vesta, VestaParameters}; -use mina_poseidon::{ - constants::PlonkSpongeConstantsKimchi as SC, sponge::DefaultFqSponge, FqSponge as _, -}; -use o1_utils::ExtendedDensePolynomial as _; -use rand::{CryptoRng, Rng, SeedableRng}; -use std::time::{Duration, Instant}; - -// Note: Because the current API uses large tuples of types, I re-create types -// in this test to facilitate aggregated proofs and batch verification of proofs. -// TODO: improve the polynomial commitment API - -/// A commitment -pub struct Commitment { - /// the commitment itself, potentially in chunks - chunked_commitment: PolyComm, -} - -/// An evaluated commitment (given a number of evaluation points) -pub struct EvaluatedCommitment { - /// the commitment - commit: Commitment, - /// the chunked evaluations given in the same order as the evaluation points - chunked_evals: Vec, -} - -/// A polynomial commitment evaluated at a point. Since a commitment can be chunked, the evaluations can also be chunked. -pub type ChunkedCommitmentEvaluation = Vec; - -mod prover { - use super::*; - - /// This struct represents a commitment with associated secret information - pub struct CommitmentAndSecrets { - /// the commitment evaluated at some points - pub eval_commit: EvaluatedCommitment, - /// the polynomial - pub poly: DensePolynomial, - /// the blinding part - pub chunked_blinding: PolyComm, - } -} - -/// This struct represents an aggregated evaluation proof for a number of polynomial commitments, as well as a number of evaluation points. 
-pub struct AggregatedEvaluationProof { - /// a number of evaluation points - eval_points: Vec, - /// a number of commitments evaluated at these evaluation points - eval_commitments: Vec, - /// the random value used to separate polynomials - polymask: Fp, - /// the random value used to separate evaluations - evalmask: Fp, - /// an Fq-sponge - fq_sponge: DefaultFqSponge, - /// the actual evaluation proof - pub proof: OpeningProof, -} - -impl AggregatedEvaluationProof { - /// This function converts an aggregated evaluation proof into something the verify API understands - pub fn verify_type( - &self, - ) -> BatchEvaluationProof, OpeningProof> - { - let mut coms = vec![]; - for eval_com in &self.eval_commitments { - assert_eq!(self.eval_points.len(), eval_com.chunked_evals.len()); - coms.push(Evaluation { - commitment: eval_com.commit.chunked_commitment.clone(), - evaluations: eval_com.chunked_evals.clone(), - }); - } - - let combined_inner_product = { - let es: Vec<_> = coms - .iter() - .map(|Evaluation { evaluations, .. }| evaluations.clone()) - .collect(); - combined_inner_product(&self.polymask, &self.evalmask, &es) - }; - - BatchEvaluationProof { - sponge: self.fq_sponge.clone(), - evaluation_points: self.eval_points.clone(), - polyscale: self.polymask, - evalscale: self.evalmask, - evaluations: coms, - opening: &self.proof, - combined_inner_product, - } - } -} - -pub fn generate_random_opening_proof( - mut rng: &mut RNG, - group_map: &::Map, - srs: &SRS, -) -> (Vec, Duration, Duration) { - let num_chunks = 1; - - let fq_sponge = DefaultFqSponge::::new( - mina_poseidon::pasta::fq_kimchi::static_params(), - ); - - let mut time_commit = Duration::new(0, 0); - let mut time_open = Duration::new(0, 0); - - // create 7 distinct "aggregated evaluation proofs" - let mut proofs = vec![]; - for _ in 0..7 { - // generate 7 random evaluation points - let eval_points: Vec = (0..7).map(|_| Fp::rand(&mut rng)).collect(); - - // create 11 polynomials of random degree (of at most 500) - // and commit to them - let mut commitments = vec![]; - for _ in 0..11 { - let len: usize = rng.gen(); - let len = len % 500; - // TODO @volhovm maybe remove the second case. 
- // every other polynomial is upperbounded - let poly = if len == 0 { - DensePolynomial::::zero() - } else { - DensePolynomial::::rand(len, &mut rng) - }; - - // create commitments for each polynomial, and evaluate each polynomial at the 7 random points - let timer = Instant::now(); - let BlindedCommitment { - commitment: chunked_commitment, - blinders: chunked_blinding, - } = srs.commit(&poly, num_chunks, &mut rng); - time_commit += timer.elapsed(); - - let mut chunked_evals = vec![]; - for point in eval_points.clone() { - let n = poly.len(); - let num_chunks = if n == 0 { - 1 - } else { - n / srs.g.len() + if n % srs.g.len() == 0 { 0 } else { 1 } - }; - chunked_evals.push( - poly.to_chunked_polynomial(num_chunks, srs.g.len()) - .evaluate_chunks(point), - ); - } - - let commit = Commitment { chunked_commitment }; - - let eval_commit = EvaluatedCommitment { - commit, - chunked_evals, - }; - - commitments.push(prover::CommitmentAndSecrets { - eval_commit, - poly, - chunked_blinding, - }); - } - - // create aggregated evaluation proof - #[allow(clippy::type_complexity)] - let mut polynomials: Vec<( - DensePolynomialOrEvaluations>, - PolyComm<_>, - )> = vec![]; - for c in &commitments { - polynomials.push(( - DensePolynomialOrEvaluations::DensePolynomial(&c.poly), - c.chunked_blinding.clone(), - )); - } - - let polymask = Fp::rand(&mut rng); - let evalmask = Fp::rand(&mut rng); - - let timer = Instant::now(); - let proof = srs.open::, _, _>( - group_map, - &polynomials, - &eval_points.clone(), - polymask, - evalmask, - fq_sponge.clone(), - &mut rng, - ); - time_open += timer.elapsed(); - - // prepare for batch verification - let eval_commitments = commitments.into_iter().map(|c| c.eval_commit).collect(); - proofs.push(AggregatedEvaluationProof { - eval_points, - eval_commitments, - polymask, - evalmask, - fq_sponge: fq_sponge.clone(), - proof, - }); - } - - (proofs, time_commit, time_open) -} - -fn test_randomised(mut rng: &mut RNG) { - let group_map = ::Map::setup(); - // create an SRS optimized for polynomials of degree 2^7 - 1 - let srs = SRS::::create(1 << 7); - - // TODO: move to bench - - let (proofs, time_commit, time_open) = - generate_random_opening_proof(&mut rng, &group_map, &srs); - - println!("{} {:?}", "total commitment time:".yellow(), time_commit); - println!( - "{} {:?}", - "total evaluation proof creation time:".magenta(), - time_open - ); - - let timer = Instant::now(); - - // batch verify all the proofs - let mut batch: Vec<_> = proofs.iter().map(|p| p.verify_type()).collect(); - assert!(srs.verify::, _>(&group_map, &mut batch, &mut rng)); - - // TODO: move to bench - println!( - "{} {:?}", - "batch verification time:".green(), - timer.elapsed() - ); -} - -#[test] -/// Tests polynomial commitments, batched openings and -/// verification of a batch of batched opening proofs of polynomial commitments -fn test_commit() -where - ::Err: std::fmt::Debug, -{ - // setup - let mut rng = o1_utils::tests::make_test_rng(None); - test_randomised(&mut rng) -} - -#[test] -/// Deterministic tests of polynomial commitments, batched openings and -/// verification of a batch of batched opening proofs of polynomial commitments -fn test_commit_deterministic() -where - ::Err: std::fmt::Debug, -{ - // Seed deliberately chosen to exercise zero commitments - let seed = [ - 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, - ]; - - let mut rng = ::from_seed(seed); - test_randomised(&mut rng) -} diff --git a/poly-commitment/src/tests/mod.rs 
b/poly-commitment/src/tests/mod.rs deleted file mode 100644 index f65c6f90b4..0000000000 --- a/poly-commitment/src/tests/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod batch_15_wires; -mod commitment; -mod serialization; diff --git a/poly-commitment/src/tests/serialization.rs b/poly-commitment/src/tests/serialization.rs deleted file mode 100644 index 2553076f51..0000000000 --- a/poly-commitment/src/tests/serialization.rs +++ /dev/null @@ -1,159 +0,0 @@ -//! This module checks serialization/regression for types defined within the crate. - -use crate::{commitment::CommitmentCurve, srs::SRS, SRS as _}; -use ark_ff::UniformRand; -use o1_utils::serialization::test_generic_serialization_regression_serde; - -#[test] -pub fn ser_regression_canonical_srs() { - use mina_curves::pasta::{Fp, Fq, Pallas, Vesta}; - - let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32])); - - let td1 = Fp::rand(rng); - let data_expected = SRS::::create_trusted_setup(td1, 1 << 3); - // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc - let buf_expected: Vec = vec![ - 146, 152, 196, 33, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 196, 33, 2, 208, 196, 173, 72, 216, 133, 169, 56, 56, 33, 35, - 130, 78, 164, 182, 235, 78, 233, 153, 95, 96, 113, 78, 110, 205, 98, 59, 183, 156, 34, 26, - 128, 196, 33, 6, 150, 178, 230, 252, 128, 97, 30, 248, 199, 147, 159, 227, 118, 248, 138, - 60, 143, 178, 158, 37, 232, 110, 15, 134, 143, 127, 109, 206, 204, 155, 27, 128, 196, 33, - 64, 90, 121, 139, 173, 254, 255, 108, 129, 22, 165, 14, 110, 147, 48, 189, 183, 210, 237, - 108, 189, 170, 107, 238, 149, 155, 227, 211, 89, 63, 121, 46, 0, 196, 33, 90, 220, 159, - 218, 37, 222, 219, 32, 63, 233, 183, 226, 174, 205, 38, 189, 143, 33, 160, 169, 226, 235, - 216, 43, 17, 29, 215, 31, 150, 233, 163, 36, 0, 196, 33, 98, 225, 126, 245, 162, 255, 249, - 60, 120, 105, 186, 96, 169, 208, 83, 62, 19, 64, 187, 79, 11, 120, 130, 242, 249, 79, 249, - 99, 210, 225, 25, 36, 0, 196, 33, 49, 147, 222, 224, 242, 240, 198, 119, 133, 90, 152, 19, - 122, 52, 255, 181, 14, 55, 81, 250, 47, 167, 47, 195, 36, 7, 187, 103, 225, 0, 169, 26, 0, - 196, 33, 160, 235, 225, 204, 186, 77, 26, 177, 237, 210, 27, 246, 174, 136, 126, 204, 93, - 48, 11, 220, 178, 86, 174, 156, 6, 6, 86, 90, 105, 215, 117, 16, 128, 196, 33, 1, 34, 38, - 38, 91, 206, 178, 229, 168, 199, 139, 226, 117, 121, 162, 156, 54, 54, 120, 117, 99, 242, - 180, 170, 153, 201, 1, 99, 56, 96, 32, 9, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]; - - test_generic_serialization_regression_serde(data_expected, buf_expected); - - let td2 = Fq::rand(rng); - let data_expected = SRS::::create_trusted_setup(td2, 1 << 3); - // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc - let buf_expected: Vec = vec![ - 146, 152, 196, 33, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 196, 33, 144, 162, 168, 123, 25, 245, 211, 151, 234, 53, 230, - 184, 254, 200, 193, 156, 214, 207, 155, 171, 186, 143, 221, 21, 24, 49, 159, 187, 221, 215, - 92, 61, 0, 196, 33, 172, 18, 76, 72, 86, 115, 105, 59, 32, 148, 37, 128, 195, 62, 60, 165, - 201, 121, 67, 175, 73, 51, 78, 160, 25, 215, 242, 243, 71, 93, 49, 29, 128, 196, 33, 130, - 19, 51, 124, 82, 66, 228, 1, 66, 62, 240, 98, 240, 241, 101, 177, 252, 8, 42, 36, 76, 215, - 244, 24, 170, 221, 102, 204, 51, 183, 231, 52, 128, 196, 33, 121, 187, 189, 65, 178, 95, - 164, 135, 161, 57, 194, 93, 
76, 12, 253, 165, 236, 21, 171, 199, 162, 16, 185, 75, 0, 120, - 171, 4, 1, 21, 184, 1, 0, 196, 33, 165, 208, 157, 8, 127, 129, 67, 81, 56, 223, 87, 125, - 139, 239, 36, 18, 139, 239, 53, 114, 116, 81, 1, 174, 76, 50, 97, 213, 108, 193, 107, 46, - 128, 196, 33, 83, 43, 226, 38, 146, 122, 97, 205, 114, 214, 23, 21, 165, 138, 211, 222, - 224, 190, 130, 70, 142, 203, 203, 89, 49, 138, 144, 104, 8, 247, 147, 11, 128, 196, 33, 73, - 128, 98, 223, 249, 164, 221, 198, 148, 190, 44, 37, 20, 106, 24, 112, 49, 72, 64, 157, 99, - 100, 170, 222, 105, 160, 160, 92, 194, 154, 93, 19, 128, 196, 33, 1, 130, 119, 215, 95, - 139, 130, 47, 90, 13, 171, 187, 79, 106, 134, 121, 50, 181, 54, 202, 63, 25, 38, 174, 42, - 5, 210, 172, 157, 149, 27, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]; - test_generic_serialization_regression_serde(data_expected, buf_expected); -} - -#[test] -pub fn ser_regression_canonical_polycomm() { - use crate::{commitment::BlindedCommitment, srs::SRS}; - use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial}; - use mina_curves::pasta::{Fp, Vesta}; - - let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32])); - - let srs = SRS::::create(1 << 7); - - let com_length = 300; - let num_chunks = 6; - - let poly = DensePolynomial::::rand(com_length, rng); - - let BlindedCommitment { - commitment: chunked_commitment, - blinders: _blinders, - } = srs.commit(&poly, num_chunks, rng); - - let data_expected = chunked_commitment; - // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc - let buf_expected: Vec = vec![ - 145, 150, 196, 33, 36, 158, 70, 161, 147, 233, 138, 19, 54, 52, 87, 58, 158, 154, 255, 197, - 219, 225, 79, 25, 41, 193, 232, 64, 250, 71, 230, 154, 34, 145, 81, 23, 0, 196, 33, 37, - 227, 246, 88, 42, 44, 53, 244, 102, 92, 197, 246, 56, 56, 135, 155, 248, 155, 243, 23, 76, - 44, 94, 125, 60, 209, 195, 190, 73, 158, 97, 42, 128, 196, 33, 101, 254, 242, 100, 238, - 214, 56, 151, 94, 170, 69, 219, 239, 135, 253, 151, 207, 217, 47, 229, 75, 7, 41, 9, 131, - 205, 85, 171, 166, 213, 96, 52, 128, 196, 33, 53, 166, 32, 175, 166, 196, 121, 1, 25, 236, - 34, 226, 31, 145, 70, 96, 89, 179, 65, 35, 253, 161, 10, 211, 170, 116, 247, 40, 225, 104, - 155, 34, 0, 196, 33, 114, 151, 94, 73, 26, 234, 37, 98, 188, 142, 161, 165, 62, 238, 58, - 76, 200, 16, 62, 210, 124, 127, 229, 81, 119, 145, 43, 157, 254, 237, 154, 57, 128, 196, - 33, 212, 213, 38, 1, 17, 84, 147, 102, 31, 103, 242, 177, 110, 64, 239, 33, 211, 216, 40, - 103, 51, 55, 85, 96, 133, 20, 194, 6, 87, 180, 212, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, - ]; - - test_generic_serialization_regression_serde(data_expected, buf_expected); -} - -#[test] -pub fn ser_regression_canonical_opening_proof() { - use crate::evaluation_proof::OpeningProof; - use groupmap::GroupMap; - use mina_curves::pasta::Vesta; - - let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32])); - - let group_map = ::Map::setup(); - let srs = SRS::::create(1 << 7); - - let data_expected: OpeningProof = - crate::tests::commitment::generate_random_opening_proof(rng, &group_map, &srs).0[0] - .proof - .clone(); - - // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc - let buf_expected: Vec = vec![ - 149, 151, 146, 196, 33, 166, 183, 76, 209, 121, 62, 56, 194, 134, 133, 238, 52, 190, 163, - 234, 149, 245, 202, 234, 70, 56, 187, 130, 25, 73, 135, 136, 95, 248, 80, 124, 
39, 128, - 196, 33, 78, 148, 92, 114, 104, 136, 170, 61, 102, 158, 80, 42, 184, 109, 110, 33, 228, - 116, 55, 98, 16, 192, 10, 159, 156, 72, 87, 129, 68, 66, 248, 26, 0, 146, 196, 33, 251, 20, - 242, 183, 87, 51, 106, 226, 80, 224, 139, 186, 33, 52, 203, 117, 6, 129, 167, 88, 252, 193, - 163, 38, 21, 37, 63, 254, 106, 136, 63, 21, 0, 196, 33, 4, 215, 169, 8, 207, 56, 209, 41, - 107, 189, 92, 110, 124, 186, 112, 193, 204, 173, 82, 46, 110, 8, 194, 193, 93, 130, 12, - 216, 24, 151, 94, 61, 128, 146, 196, 33, 55, 26, 135, 155, 211, 30, 67, 184, 93, 78, 146, - 166, 31, 11, 120, 93, 17, 24, 164, 39, 177, 98, 25, 156, 33, 5, 179, 64, 237, 69, 199, 11, - 128, 196, 33, 249, 229, 29, 113, 38, 26, 30, 205, 26, 217, 71, 248, 199, 157, 244, 196, 1, - 108, 74, 39, 173, 55, 118, 216, 191, 232, 27, 95, 190, 38, 96, 35, 128, 146, 196, 33, 109, - 160, 19, 169, 187, 111, 247, 152, 101, 20, 161, 251, 150, 61, 204, 78, 118, 171, 81, 1, - 253, 83, 64, 170, 93, 114, 216, 224, 33, 250, 202, 14, 0, 196, 33, 247, 34, 197, 187, 28, - 46, 42, 6, 126, 129, 132, 151, 39, 150, 115, 138, 229, 50, 220, 8, 170, 81, 173, 13, 54, - 57, 90, 169, 201, 212, 128, 50, 128, 146, 196, 33, 73, 206, 252, 115, 3, 45, 100, 75, 98, - 139, 35, 227, 181, 241, 175, 2, 175, 14, 132, 86, 0, 174, 64, 84, 24, 88, 18, 163, 82, 102, - 164, 19, 0, 196, 33, 128, 62, 85, 80, 76, 1, 35, 195, 197, 48, 46, 1, 13, 183, 105, 91, - 243, 109, 124, 68, 41, 21, 42, 228, 124, 28, 193, 188, 85, 6, 180, 24, 128, 146, 196, 33, - 21, 245, 240, 236, 40, 30, 75, 91, 87, 50, 153, 173, 88, 231, 34, 227, 241, 146, 134, 156, - 217, 161, 155, 165, 76, 142, 69, 82, 45, 49, 163, 56, 0, 196, 33, 247, 220, 178, 199, 227, - 182, 213, 3, 75, 71, 188, 175, 31, 189, 247, 209, 217, 210, 21, 56, 246, 86, 89, 71, 165, - 90, 75, 150, 236, 68, 240, 37, 0, 146, 196, 33, 18, 70, 242, 33, 232, 246, 235, 191, 213, - 96, 68, 75, 173, 43, 125, 25, 239, 71, 93, 71, 74, 159, 50, 34, 251, 139, 133, 228, 241, - 49, 166, 36, 0, 196, 33, 28, 99, 102, 90, 22, 105, 167, 195, 200, 126, 132, 202, 178, 50, - 197, 41, 204, 7, 108, 2, 12, 7, 221, 0, 61, 120, 23, 112, 11, 47, 104, 60, 128, 196, 33, - 34, 178, 82, 48, 155, 153, 34, 99, 173, 221, 221, 236, 235, 5, 135, 165, 18, 39, 120, 175, - 253, 216, 48, 255, 8, 67, 160, 96, 72, 49, 99, 21, 0, 196, 32, 148, 153, 156, 103, 116, 92, - 72, 80, 249, 8, 110, 104, 44, 231, 231, 1, 62, 3, 189, 77, 153, 74, 89, 74, 191, 185, 236, - 20, 209, 93, 77, 51, 196, 32, 33, 224, 176, 185, 62, 76, 18, 58, 48, 219, 106, 206, 35, - 153, 234, 54, 21, 29, 87, 57, 147, 84, 219, 194, 208, 170, 158, 105, 241, 63, 76, 44, 196, - 33, 236, 186, 22, 30, 113, 33, 148, 99, 242, 146, 146, 41, 119, 163, 230, 139, 48, 191, 57, - 161, 79, 240, 7, 167, 28, 62, 213, 170, 132, 195, 255, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]; - - test_generic_serialization_regression_serde(data_expected, buf_expected); -} diff --git a/poly-commitment/src/utils.rs b/poly-commitment/src/utils.rs new file mode 100644 index 0000000000..fc43883aff --- /dev/null +++ b/poly-commitment/src/utils.rs @@ -0,0 +1,191 @@ +use crate::{commitment::CommitmentCurve, PolynomialsToCombine}; +use ark_ff::{FftField, Field, One, Zero}; +use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Evaluations}; +use o1_utils::ExtendedDensePolynomial; +use rayon::prelude::*; + +/// Represent a polynomial either with its coefficients or its 
evaluations
+pub enum DensePolynomialOrEvaluations<'a, F: FftField, D: EvaluationDomain<F>> {
+    /// Polynomial represented by its coefficients
+    DensePolynomial(&'a DensePolynomial<F>),
+    /// Polynomial represented by its evaluations over a domain D
+    Evaluations(&'a Evaluations<F, D>, D),
+}
+
+/// A formal sum of the form
+/// `s_0 * p_0 + ... s_n * p_n`
+/// where each `s_i` is a scalar and each `p_i` is a polynomial.
+/// The parameter `P` is expected to be the coefficients of the polynomial
+/// `p_i`, even though we could treat it as the evaluations.
+///
+/// This hypothesis is important if `to_dense_polynomial` is called.
+#[derive(Default)]
+struct ScaledChunkedPolynomial<F, P>(Vec<(F, P)>);
+
+impl<F, P> ScaledChunkedPolynomial<F, P> {
+    fn add_poly(&mut self, scale: F, p: P) {
+        self.0.push((scale, p))
+    }
+}
+
+impl<'a, F: Field> ScaledChunkedPolynomial<F, &'a [F]> {
+    /// Compute the resulting scaled polynomial.
+    /// Example:
+    /// Given the two polynomials `1 + 2X` and `3 + 4X`, and the scaling
+    /// factors `2` and `3`, the result is the polynomial `11 + 16X`.
+    /// ```text
+    /// 2 * [1, 2] + 3 * [3, 4] = [2, 4] + [9, 12] = [11, 16]
+    /// ```
+    fn to_dense_polynomial(&self) -> DensePolynomial<F> {
+        // Note: `+=` takes the summand by reference to avoid reallocating the
+        // result.
+        let mut res = DensePolynomial::<F>::zero();
+
+        let scaled: Vec<_> = self
+            .0
+            .par_iter()
+            .map(|(scale, segment)| {
+                let scale = *scale;
+                // We scale each coefficient manually because DensePolynomial
+                // doesn't have a `scale` method.
+                let v = segment.par_iter().map(|x| scale * *x).collect();
+                DensePolynomial::from_coefficients_vec(v)
+            })
+            .collect();
+
+        for p in scaled {
+            res += &p;
+        }
+
+        res
+    }
+}
+
+/// Combine the polynomials using a scalar (`polyscale`), creating a single
+/// unified polynomial to open. This function also accepts polynomials in
+/// evaluation form. In this case it applies an IFFT and, if necessary,
+/// chunks the result (i.e. splits it into multiple polynomials of
+/// degree less than the SRS size).
+///
+/// Parameters:
+/// - `plnms`: vector of polynomials, either in evaluation or coefficient form, together with
+///   a set of scalars representing their blinders.
+/// - `polyscale`: scalar used to combine the polynomials; each polynomial (or chunk) is
+///   multiplied by a successive power of it.
+///
+/// Output:
+/// - `combined_poly`: combined polynomial. The order of the output follows the order of `plnms`.
+/// - `combined_comm`: combined scalars representing the blinder commitment.
+///
+/// Example:
+/// Given three polynomials `p1(X)`, `p2(X)` and `p3(X)`, where `p1` and `p3`
+/// are in coefficient form and `p2` is in evaluation form,
+/// and the scaling factor `polyscale`, the result is the polynomial:
+///
+/// ```text
+/// p1(X) + polyscale * i_fft(chunks(p2))(X) + polyscale^2 * p3(X)
+/// ```
+///
+/// Additional complexity is added to handle chunks.
+pub fn combine_polys<G: CommitmentCurve, D: EvaluationDomain<G::ScalarField>>(
+    plnms: PolynomialsToCombine<G, D>,
+    polyscale: G::ScalarField,
+    srs_length: usize,
+) -> (DensePolynomial<G::ScalarField>, G::ScalarField) {
+    // Initialising the output for the combined coefficient form
+    let mut plnm_coefficients =
+        ScaledChunkedPolynomial::<G::ScalarField, &[G::ScalarField]>::default();
+    // Initialising the output for the combined evaluation form
+    let mut plnm_evals_part = {
+        // For now just check that all the evaluation polynomials are the same
+        // degree so that we can do just a single FFT.
+        // If/when we change this, we can add more complicated code to handle
+        // different degrees.
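+        // For example, two polynomials given in evaluation form over a
+        // domain of size 8 yield degree = 8 here; mixing domains of size 8
+        // and 16 would trip the assertion below.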
+        let degree = plnms
+            .iter()
+            .fold(None, |acc, (p, _)| match p {
+                DensePolynomialOrEvaluations::DensePolynomial(_) => acc,
+                DensePolynomialOrEvaluations::Evaluations(_, d) => {
+                    if let Some(n) = acc {
+                        assert_eq!(n, d.size());
+                    }
+                    Some(d.size())
+                }
+            })
+            .unwrap_or(0);
+        vec![G::ScalarField::zero(); degree]
+    };
+
+    // Will contain ∑ comm_chunk * polyscale^i
+    let mut combined_comm = G::ScalarField::zero();
+
+    // Will contain polyscale^i
+    let mut polyscale_to_i = G::ScalarField::one();
+
+    // Iterating over polynomials in the batch.
+    // Note that the chunks in the commitment `p_i_comm` are given as
+    // `PolyComm`s. They are evaluations.
+    // We modify two different structures depending on the form of the
+    // polynomial we are currently processing: `plnm` and `plnm_evals_part`.
+    // The two forms need to be treated separately.
+    for (p_i, p_i_comm) in plnms {
+        match p_i {
+            // Here we scale the polynomial in evaluation form
+            // Note that based on the check above, sub_domain.size() always
+            // gives the same value
+            DensePolynomialOrEvaluations::Evaluations(evals_i, sub_domain) => {
+                let stride = evals_i.evals.len() / sub_domain.size();
+                let evals = &evals_i.evals;
+                plnm_evals_part
+                    .par_iter_mut()
+                    .enumerate()
+                    .for_each(|(i, x)| {
+                        *x += polyscale_to_i * evals[i * stride];
+                    });
+                for comm_chunk in p_i_comm.into_iter() {
+                    combined_comm += &(*comm_chunk * polyscale_to_i);
+                    polyscale_to_i *= &polyscale;
+                }
+            }
+
+            // Here we scale the polynomial in coefficient form
+            DensePolynomialOrEvaluations::DensePolynomial(p_i) => {
+                let mut offset = 0;
+                // iterating over chunks of the polynomial
+                for comm_chunk in p_i_comm.into_iter() {
+                    let segment = &p_i.coeffs[std::cmp::min(offset, p_i.coeffs.len())
+                        ..std::cmp::min(offset + srs_length, p_i.coeffs.len())];
+                    plnm_coefficients.add_poly(polyscale_to_i, segment);
+
+                    combined_comm += &(*comm_chunk * polyscale_to_i);
+                    polyscale_to_i *= &polyscale;
+                    offset += srs_length;
+                }
+            }
+        }
+    }
+
+    // Now, we will combine both evaluation and coefficient forms
+
+    // plnm will be our final combined polynomial. We first treat the
+    // polynomials in coefficient form, which simply amounts to scaling the
+    // coefficients and adding them.
+    let mut combined_plnm = plnm_coefficients.to_dense_polynomial();
+
+    if !plnm_evals_part.is_empty() {
+        // n is the number of evaluations, which is a multiple of the
+        // domain size.
+        // We now treat each chunk.
+        let n = plnm_evals_part.len();
+        let max_poly_size = srs_length;
+        // equivalent to div_ceil, which is unstable in Rust < 1.73.
+        let num_chunks = n / max_poly_size + if n % max_poly_size == 0 { 0 } else { 1 };
+        // Interpolation on the whole domain, i.e. it can be d2, d4, etc.
+        combined_plnm += &Evaluations::from_vec_and_domain(plnm_evals_part, D::new(n).unwrap())
+            .interpolate()
+            .to_chunked_polynomial(num_chunks, max_poly_size)
+            .linearize(polyscale);
+    }
+
+    (combined_plnm, combined_comm)
+}
diff --git a/poly-commitment/src/tests/batch_15_wires.rs b/poly-commitment/tests/batch_15_wires.rs
similarity index 98%
rename from poly-commitment/src/tests/batch_15_wires.rs
rename to poly-commitment/tests/batch_15_wires.rs
index 788ca1bc43..eafeba8f5b 100644
--- a/poly-commitment/src/tests/batch_15_wires.rs
+++ b/poly-commitment/tests/batch_15_wires.rs
@@ -1,12 +1,6 @@
 //! This module tests polynomial commitments, batched openings and
 //!
verification of a batch of batched opening proofs of polynomial commitments -use crate::{ - commitment::{combined_inner_product, BatchEvaluationProof, CommitmentCurve, Evaluation}, - evaluation_proof::DensePolynomialOrEvaluations, - srs::SRS, - SRS as _, -}; use ark_ff::{UniformRand, Zero}; use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Radix2EvaluationDomain}; use colored::Colorize; @@ -16,6 +10,12 @@ use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi as SC, sponge::DefaultFqSponge, FqSponge, }; use o1_utils::ExtendedDensePolynomial as _; +use poly_commitment::{ + commitment::{combined_inner_product, BatchEvaluationProof, CommitmentCurve, Evaluation}, + ipa::SRS, + utils::DensePolynomialOrEvaluations, + SRS as _, +}; use rand::Rng; use std::time::{Duration, Instant}; diff --git a/poly-commitment/tests/commitment.rs b/poly-commitment/tests/commitment.rs new file mode 100644 index 0000000000..b22b1e9853 --- /dev/null +++ b/poly-commitment/tests/commitment.rs @@ -0,0 +1,439 @@ +use ark_ff::{UniformRand, Zero}; +use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Radix2EvaluationDomain}; +use colored::Colorize; +use groupmap::GroupMap; +use mina_curves::pasta::{Fp, Vesta, VestaParameters}; +use mina_poseidon::{ + constants::PlonkSpongeConstantsKimchi as SC, sponge::DefaultFqSponge, FqSponge as _, +}; +use o1_utils::{ + serialization::test_generic_serialization_regression_serde, ExtendedDensePolynomial as _, +}; +use poly_commitment::{ + commitment::{ + combined_inner_product, BatchEvaluationProof, BlindedCommitment, CommitmentCurve, + Evaluation, PolyComm, + }, + ipa::{OpeningProof, SRS}, + utils::DensePolynomialOrEvaluations, + SRS as _, +}; +use rand::{CryptoRng, Rng, SeedableRng}; +use std::time::{Duration, Instant}; + +// Note: Because the current API uses large tuples of types, I re-create types +// in this test to facilitate aggregated proofs and batch verification of proofs. +// TODO: improve the polynomial commitment API + +/// A commitment +pub struct Commitment { + /// the commitment itself, potentially in chunks + chunked_commitment: PolyComm, +} + +/// An evaluated commitment (given a number of evaluation points) +pub struct EvaluatedCommitment { + /// the commitment + commit: Commitment, + /// the chunked evaluations given in the same order as the evaluation points + chunked_evals: Vec, +} + +/// A polynomial commitment evaluated at a point. Since a commitment can be +/// chunked, the evaluations can also be chunked. +pub type ChunkedCommitmentEvaluation = Vec; + +mod prover { + use super::*; + + /// This struct represents a commitment with associated secret information + pub struct CommitmentAndSecrets { + /// the commitment evaluated at some points + pub eval_commit: EvaluatedCommitment, + /// the polynomial + pub poly: DensePolynomial, + /// the blinding part + pub chunked_blinding: PolyComm, + } +} + +/// This struct represents an aggregated evaluation proof for a number of +/// polynomial commitments, as well as a number of evaluation points. 
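+/// The fields `polymask` and `evalmask` play the roles usually called
+/// `polyscale` and `evalscale`: the proof opens the combination
+/// `p_0 + polymask * p_1 + polymask^2 * p_2 + ...` at the evaluation points,
+/// with the evaluations themselves combined by powers of `evalmask`.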
+pub struct AggregatedEvaluationProof { + /// a number of evaluation points + eval_points: Vec, + /// a number of commitments evaluated at these evaluation points + eval_commitments: Vec, + /// the random value used to separate polynomials + polymask: Fp, + /// the random value used to separate evaluations + evalmask: Fp, + /// an Fq-sponge + fq_sponge: DefaultFqSponge, + /// the actual evaluation proof + pub proof: OpeningProof, +} + +impl AggregatedEvaluationProof { + /// This function converts an aggregated evaluation proof into something the + /// verify API understands + pub fn verify_type( + &self, + ) -> BatchEvaluationProof, OpeningProof> + { + let mut coms = vec![]; + for eval_com in &self.eval_commitments { + assert_eq!(self.eval_points.len(), eval_com.chunked_evals.len()); + coms.push(Evaluation { + commitment: eval_com.commit.chunked_commitment.clone(), + evaluations: eval_com.chunked_evals.clone(), + }); + } + + let combined_inner_product = { + let es: Vec<_> = coms + .iter() + .map(|Evaluation { evaluations, .. }| evaluations.clone()) + .collect(); + combined_inner_product(&self.polymask, &self.evalmask, &es) + }; + + BatchEvaluationProof { + sponge: self.fq_sponge.clone(), + evaluation_points: self.eval_points.clone(), + polyscale: self.polymask, + evalscale: self.evalmask, + evaluations: coms, + opening: &self.proof, + combined_inner_product, + } + } +} + +pub fn generate_random_opening_proof( + mut rng: &mut RNG, + group_map: &::Map, + srs: &SRS, +) -> (Vec, Duration, Duration) { + let num_chunks = 1; + + let fq_sponge = DefaultFqSponge::::new( + mina_poseidon::pasta::fq_kimchi::static_params(), + ); + + let mut time_commit = Duration::new(0, 0); + let mut time_open = Duration::new(0, 0); + + // create 7 distinct "aggregated evaluation proofs" + let mut proofs = vec![]; + for _ in 0..7 { + // generate 7 random evaluation points + let eval_points: Vec = (0..7).map(|_| Fp::rand(&mut rng)).collect(); + + // create 11 polynomials of random degree (of at most 500) + // and commit to them + let mut commitments = vec![]; + for _ in 0..11 { + let len: usize = rng.gen(); + let len = len % 500; + // TODO @volhovm maybe remove the second case. 
+ // every other polynomial is upperbounded + let poly = if len == 0 { + DensePolynomial::::zero() + } else { + DensePolynomial::::rand(len, &mut rng) + }; + + // create commitments for each polynomial, and evaluate each + // polynomial at the 7 random points + let timer = Instant::now(); + let BlindedCommitment { + commitment: chunked_commitment, + blinders: chunked_blinding, + } = srs.commit(&poly, num_chunks, &mut rng); + time_commit += timer.elapsed(); + + let mut chunked_evals = vec![]; + for point in eval_points.clone() { + let n = poly.len(); + let num_chunks = if n == 0 { + 1 + } else { + n / srs.g.len() + if n % srs.g.len() == 0 { 0 } else { 1 } + }; + chunked_evals.push( + poly.to_chunked_polynomial(num_chunks, srs.g.len()) + .evaluate_chunks(point), + ); + } + + let commit = Commitment { chunked_commitment }; + + let eval_commit = EvaluatedCommitment { + commit, + chunked_evals, + }; + + commitments.push(prover::CommitmentAndSecrets { + eval_commit, + poly, + chunked_blinding, + }); + } + + // create aggregated evaluation proof + let mut polynomials: Vec<( + DensePolynomialOrEvaluations>, + PolyComm<_>, + )> = vec![]; + for c in &commitments { + polynomials.push(( + DensePolynomialOrEvaluations::DensePolynomial(&c.poly), + c.chunked_blinding.clone(), + )); + } + + let polymask = Fp::rand(&mut rng); + let evalmask = Fp::rand(&mut rng); + + let timer = Instant::now(); + let proof = srs.open::, _, _>( + group_map, + &polynomials, + &eval_points.clone(), + polymask, + evalmask, + fq_sponge.clone(), + &mut rng, + ); + time_open += timer.elapsed(); + + // prepare for batch verification + let eval_commitments = commitments.into_iter().map(|c| c.eval_commit).collect(); + proofs.push(AggregatedEvaluationProof { + eval_points, + eval_commitments, + polymask, + evalmask, + fq_sponge: fq_sponge.clone(), + proof, + }); + } + + (proofs, time_commit, time_open) +} + +fn test_randomised(mut rng: &mut RNG) { + let group_map = ::Map::setup(); + // create an SRS optimized for polynomials of degree 2^7 - 1 + let srs = SRS::::create(1 << 7); + + // TODO: move to bench + + let (proofs, time_commit, time_open) = + generate_random_opening_proof(&mut rng, &group_map, &srs); + + println!("{} {:?}", "total commitment time:".yellow(), time_commit); + println!( + "{} {:?}", + "total evaluation proof creation time:".magenta(), + time_open + ); + + let timer = Instant::now(); + + // batch verify all the proofs + let mut batch: Vec<_> = proofs.iter().map(|p| p.verify_type()).collect(); + assert!(srs.verify::, _>(&group_map, &mut batch, &mut rng)); + + // TODO: move to bench + println!( + "{} {:?}", + "batch verification time:".green(), + timer.elapsed() + ); +} + +#[test] +/// Tests polynomial commitments, batched openings and +/// verification of a batch of batched opening proofs of polynomial commitments +fn test_commit() +where + ::Err: std::fmt::Debug, +{ + // setup + let mut rng = o1_utils::tests::make_test_rng(None); + test_randomised(&mut rng) +} + +#[test] +/// Deterministic tests of polynomial commitments, batched openings and +/// verification of a batch of batched opening proofs of polynomial commitments +fn test_commit_deterministic() +where + ::Err: std::fmt::Debug, +{ + // Seed deliberately chosen to exercise zero commitments + let seed = [ + 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, + ]; + + let mut rng = ::from_seed(seed); + test_randomised(&mut rng) +} + +#[test] +pub fn ser_regression_canonical_srs() { + use mina_curves::pasta::{Fp, Fq, 
Pallas, Vesta}; + use poly_commitment::ipa::SRS; + + let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32])); + + let td1 = Fp::rand(rng); + let data_expected = unsafe { SRS::::create_trusted_setup(td1, 1 << 3) }; + // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc + let buf_expected: Vec = vec![ + 146, 152, 196, 33, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 196, 33, 2, 208, 196, 173, 72, 216, 133, 169, 56, 56, 33, 35, + 130, 78, 164, 182, 235, 78, 233, 153, 95, 96, 113, 78, 110, 205, 98, 59, 183, 156, 34, 26, + 128, 196, 33, 6, 150, 178, 230, 252, 128, 97, 30, 248, 199, 147, 159, 227, 118, 248, 138, + 60, 143, 178, 158, 37, 232, 110, 15, 134, 143, 127, 109, 206, 204, 155, 27, 128, 196, 33, + 64, 90, 121, 139, 173, 254, 255, 108, 129, 22, 165, 14, 110, 147, 48, 189, 183, 210, 237, + 108, 189, 170, 107, 238, 149, 155, 227, 211, 89, 63, 121, 46, 0, 196, 33, 90, 220, 159, + 218, 37, 222, 219, 32, 63, 233, 183, 226, 174, 205, 38, 189, 143, 33, 160, 169, 226, 235, + 216, 43, 17, 29, 215, 31, 150, 233, 163, 36, 0, 196, 33, 98, 225, 126, 245, 162, 255, 249, + 60, 120, 105, 186, 96, 169, 208, 83, 62, 19, 64, 187, 79, 11, 120, 130, 242, 249, 79, 249, + 99, 210, 225, 25, 36, 0, 196, 33, 49, 147, 222, 224, 242, 240, 198, 119, 133, 90, 152, 19, + 122, 52, 255, 181, 14, 55, 81, 250, 47, 167, 47, 195, 36, 7, 187, 103, 225, 0, 169, 26, 0, + 196, 33, 160, 235, 225, 204, 186, 77, 26, 177, 237, 210, 27, 246, 174, 136, 126, 204, 93, + 48, 11, 220, 178, 86, 174, 156, 6, 6, 86, 90, 105, 215, 117, 16, 128, 196, 33, 1, 34, 38, + 38, 91, 206, 178, 229, 168, 199, 139, 226, 117, 121, 162, 156, 54, 54, 120, 117, 99, 242, + 180, 170, 153, 201, 1, 99, 56, 96, 32, 9, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]; + + test_generic_serialization_regression_serde(data_expected, buf_expected); + + let td2 = Fq::rand(rng); + let data_expected = unsafe { SRS::::create_trusted_setup(td2, 1 << 3) }; + // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc + let buf_expected: Vec = vec![ + 146, 152, 196, 33, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 196, 33, 144, 162, 168, 123, 25, 245, 211, 151, 234, 53, 230, + 184, 254, 200, 193, 156, 214, 207, 155, 171, 186, 143, 221, 21, 24, 49, 159, 187, 221, 215, + 92, 61, 0, 196, 33, 172, 18, 76, 72, 86, 115, 105, 59, 32, 148, 37, 128, 195, 62, 60, 165, + 201, 121, 67, 175, 73, 51, 78, 160, 25, 215, 242, 243, 71, 93, 49, 29, 128, 196, 33, 130, + 19, 51, 124, 82, 66, 228, 1, 66, 62, 240, 98, 240, 241, 101, 177, 252, 8, 42, 36, 76, 215, + 244, 24, 170, 221, 102, 204, 51, 183, 231, 52, 128, 196, 33, 121, 187, 189, 65, 178, 95, + 164, 135, 161, 57, 194, 93, 76, 12, 253, 165, 236, 21, 171, 199, 162, 16, 185, 75, 0, 120, + 171, 4, 1, 21, 184, 1, 0, 196, 33, 165, 208, 157, 8, 127, 129, 67, 81, 56, 223, 87, 125, + 139, 239, 36, 18, 139, 239, 53, 114, 116, 81, 1, 174, 76, 50, 97, 213, 108, 193, 107, 46, + 128, 196, 33, 83, 43, 226, 38, 146, 122, 97, 205, 114, 214, 23, 21, 165, 138, 211, 222, + 224, 190, 130, 70, 142, 203, 203, 89, 49, 138, 144, 104, 8, 247, 147, 11, 128, 196, 33, 73, + 128, 98, 223, 249, 164, 221, 198, 148, 190, 44, 37, 20, 106, 24, 112, 49, 72, 64, 157, 99, + 100, 170, 222, 105, 160, 160, 92, 194, 154, 93, 19, 128, 196, 33, 1, 130, 119, 215, 95, + 139, 130, 47, 90, 13, 171, 187, 79, 106, 134, 121, 50, 181, 54, 202, 63, 25, 38, 174, 42, + 5, 210, 172, 157, 149, 27, 
34, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]; + test_generic_serialization_regression_serde(data_expected, buf_expected); +} + +#[test] +pub fn ser_regression_canonical_polycomm() { + use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial}; + use mina_curves::pasta::{Fp, Vesta}; + use poly_commitment::{commitment::BlindedCommitment, ipa::SRS}; + + let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32])); + + let srs = SRS::::create(1 << 7); + + let com_length = 300; + let num_chunks = 6; + + let poly = DensePolynomial::::rand(com_length, rng); + + let BlindedCommitment { + commitment: chunked_commitment, + blinders: _blinders, + } = srs.commit(&poly, num_chunks, rng); + + let data_expected = chunked_commitment; + // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc + let buf_expected: Vec = vec![ + 145, 150, 196, 33, 36, 158, 70, 161, 147, 233, 138, 19, 54, 52, 87, 58, 158, 154, 255, 197, + 219, 225, 79, 25, 41, 193, 232, 64, 250, 71, 230, 154, 34, 145, 81, 23, 0, 196, 33, 37, + 227, 246, 88, 42, 44, 53, 244, 102, 92, 197, 246, 56, 56, 135, 155, 248, 155, 243, 23, 76, + 44, 94, 125, 60, 209, 195, 190, 73, 158, 97, 42, 128, 196, 33, 101, 254, 242, 100, 238, + 214, 56, 151, 94, 170, 69, 219, 239, 135, 253, 151, 207, 217, 47, 229, 75, 7, 41, 9, 131, + 205, 85, 171, 166, 213, 96, 52, 128, 196, 33, 53, 166, 32, 175, 166, 196, 121, 1, 25, 236, + 34, 226, 31, 145, 70, 96, 89, 179, 65, 35, 253, 161, 10, 211, 170, 116, 247, 40, 225, 104, + 155, 34, 0, 196, 33, 114, 151, 94, 73, 26, 234, 37, 98, 188, 142, 161, 165, 62, 238, 58, + 76, 200, 16, 62, 210, 124, 127, 229, 81, 119, 145, 43, 157, 254, 237, 154, 57, 128, 196, + 33, 212, 213, 38, 1, 17, 84, 147, 102, 31, 103, 242, 177, 110, 64, 239, 33, 211, 216, 40, + 103, 51, 55, 85, 96, 133, 20, 194, 6, 87, 180, 212, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, + ]; + + test_generic_serialization_regression_serde(data_expected, buf_expected); +} + +#[test] +pub fn ser_regression_canonical_opening_proof() { + use groupmap::GroupMap; + use mina_curves::pasta::Vesta; + use poly_commitment::ipa::{OpeningProof, SRS}; + + let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32])); + + let group_map = ::Map::setup(); + let srs = SRS::::create(1 << 7); + + let data_expected: OpeningProof = generate_random_opening_proof(rng, &group_map, &srs).0 + [0] + .proof + .clone(); + + // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc + let buf_expected: Vec = vec![ + 149, 151, 146, 196, 33, 166, 183, 76, 209, 121, 62, 56, 194, 134, 133, 238, 52, 190, 163, + 234, 149, 245, 202, 234, 70, 56, 187, 130, 25, 73, 135, 136, 95, 248, 80, 124, 39, 128, + 196, 33, 78, 148, 92, 114, 104, 136, 170, 61, 102, 158, 80, 42, 184, 109, 110, 33, 228, + 116, 55, 98, 16, 192, 10, 159, 156, 72, 87, 129, 68, 66, 248, 26, 0, 146, 196, 33, 251, 20, + 242, 183, 87, 51, 106, 226, 80, 224, 139, 186, 33, 52, 203, 117, 6, 129, 167, 88, 252, 193, + 163, 38, 21, 37, 63, 254, 106, 136, 63, 21, 0, 196, 33, 4, 215, 169, 8, 207, 56, 209, 41, + 107, 189, 92, 110, 124, 186, 112, 193, 204, 173, 82, 46, 110, 8, 194, 193, 93, 130, 12, + 216, 24, 151, 94, 61, 128, 146, 196, 33, 55, 26, 135, 155, 211, 30, 67, 184, 93, 78, 146, + 166, 31, 11, 120, 93, 17, 24, 164, 39, 177, 98, 25, 156, 33, 5, 179, 64, 237, 69, 199, 11, + 128, 196, 33, 249, 229, 29, 113, 38, 26, 30, 205, 26, 217, 71, 248, 199, 157, 244, 196, 1, 
+
+#[test]
+pub fn ser_regression_canonical_polycomm() {
+    use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial};
+    use mina_curves::pasta::{Fp, Vesta};
+    use poly_commitment::{commitment::BlindedCommitment, ipa::SRS};
+
+    let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32]));
+
+    let srs = SRS::<Vesta>::create(1 << 7);
+
+    let com_length = 300;
+    let num_chunks = 6;
+
+    let poly = DensePolynomial::<Fp>::rand(com_length, rng);
+
+    let BlindedCommitment {
+        commitment: chunked_commitment,
+        blinders: _blinders,
+    } = srs.commit(&poly, num_chunks, rng);
+
+    let data_expected = chunked_commitment;
+    // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc
+    let buf_expected: Vec<u8> = vec![
+        145, 150, 196, 33, 36, 158, 70, 161, 147, 233, 138, 19, 54, 52, 87, 58, 158, 154, 255, 197,
+        219, 225, 79, 25, 41, 193, 232, 64, 250, 71, 230, 154, 34, 145, 81, 23, 0, 196, 33, 37,
+        227, 246, 88, 42, 44, 53, 244, 102, 92, 197, 246, 56, 56, 135, 155, 248, 155, 243, 23, 76,
+        44, 94, 125, 60, 209, 195, 190, 73, 158, 97, 42, 128, 196, 33, 101, 254, 242, 100, 238,
+        214, 56, 151, 94, 170, 69, 219, 239, 135, 253, 151, 207, 217, 47, 229, 75, 7, 41, 9, 131,
+        205, 85, 171, 166, 213, 96, 52, 128, 196, 33, 53, 166, 32, 175, 166, 196, 121, 1, 25, 236,
+        34, 226, 31, 145, 70, 96, 89, 179, 65, 35, 253, 161, 10, 211, 170, 116, 247, 40, 225, 104,
+        155, 34, 0, 196, 33, 114, 151, 94, 73, 26, 234, 37, 98, 188, 142, 161, 165, 62, 238, 58,
+        76, 200, 16, 62, 210, 124, 127, 229, 81, 119, 145, 43, 157, 254, 237, 154, 57, 128, 196,
+        33, 212, 213, 38, 1, 17, 84, 147, 102, 31, 103, 242, 177, 110, 64, 239, 33, 211, 216, 40,
+        103, 51, 55, 85, 96, 133, 20, 194, 6, 87, 180, 212, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0, 0,
+    ];
+
+    test_generic_serialization_regression_serde(data_expected, buf_expected);
+}
+
+#[test]
+pub fn ser_regression_canonical_opening_proof() {
+    use groupmap::GroupMap;
+    use mina_curves::pasta::Vesta;
+    use poly_commitment::ipa::{OpeningProof, SRS};
+
+    let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32]));
+
+    let group_map = <Vesta as CommitmentCurve>::Map::setup();
+    let srs = SRS::<Vesta>::create(1 << 7);
+
+    let data_expected: OpeningProof<Vesta> = generate_random_opening_proof(rng, &group_map, &srs).0
+        [0]
+        .proof
+        .clone();
+
+    // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc
+    let buf_expected: Vec<u8> = vec![
+        149, 151, 146, 196, 33, 166, 183, 76, 209, 121, 62, 56, 194, 134, 133, 238, 52, 190, 163,
+        234, 149, 245, 202, 234, 70, 56, 187, 130, 25, 73, 135, 136, 95, 248, 80, 124, 39, 128,
+        196, 33, 78, 148, 92, 114, 104, 136, 170, 61, 102, 158, 80, 42, 184, 109, 110, 33, 228,
+        116, 55, 98, 16, 192, 10, 159, 156, 72, 87, 129, 68, 66, 248, 26, 0, 146, 196, 33, 251, 20,
+        242, 183, 87, 51, 106, 226, 80, 224, 139, 186, 33, 52, 203, 117, 6, 129, 167, 88, 252, 193,
+        163, 38, 21, 37, 63, 254, 106, 136, 63, 21, 0, 196, 33, 4, 215, 169, 8, 207, 56, 209, 41,
+        107, 189, 92, 110, 124, 186, 112, 193, 204, 173, 82, 46, 110, 8, 194, 193, 93, 130, 12,
+        216, 24, 151, 94, 61, 128, 146, 196, 33, 55, 26, 135, 155, 211, 30, 67, 184, 93, 78, 146,
+        166, 31, 11, 120, 93, 17, 24, 164, 39, 177, 98, 25, 156, 33, 5, 179, 64, 237, 69, 199, 11,
+        128, 196, 33, 249, 229, 29, 113, 38, 26, 30, 205, 26, 217, 71, 248, 199, 157, 244, 196, 1,
+        108, 74, 39, 173, 55, 118, 216, 191, 232, 27, 95, 190, 38, 96, 35, 128, 146, 196, 33, 109,
+        160, 19, 169, 187, 111, 247, 152, 101, 20, 161, 251, 150, 61, 204, 78, 118, 171, 81, 1,
+        253, 83, 64, 170, 93, 114, 216, 224, 33, 250, 202, 14, 0, 196, 33, 247, 34, 197, 187, 28,
+        46, 42, 6, 126, 129, 132, 151, 39, 150, 115, 138, 229, 50, 220, 8, 170, 81, 173, 13, 54,
+        57, 90, 169, 201, 212, 128, 50, 128, 146, 196, 33, 73, 206, 252, 115, 3, 45, 100, 75, 98,
+        139, 35, 227, 181, 241, 175, 2, 175, 14, 132, 86, 0, 174, 64, 84, 24, 88, 18, 163, 82, 102,
+        164, 19, 0, 196, 33, 128, 62, 85, 80, 76, 1, 35, 195, 197, 48, 46, 1, 13, 183, 105, 91,
+        243, 109, 124, 68, 41, 21, 42, 228, 124, 28, 193, 188, 85, 6, 180, 24, 128, 146, 196, 33,
+        21, 245, 240, 236, 40, 30, 75, 91, 87, 50, 153, 173, 88, 231, 34, 227, 241, 146, 134, 156,
+        217, 161, 155, 165, 76, 142, 69, 82, 45, 49, 163, 56, 0, 196, 33, 247, 220, 178, 199, 227,
+        182, 213, 3, 75, 71, 188, 175, 31, 189, 247, 209, 217, 210, 21, 56, 246, 86, 89, 71, 165,
+        90, 75, 150, 236, 68, 240, 37, 0, 146, 196, 33, 18, 70, 242, 33, 232, 246, 235, 191, 213,
+        96, 68, 75, 173, 43, 125, 25, 239, 71, 93, 71, 74, 159, 50, 34, 251, 139, 133, 228, 241,
+        49, 166, 36, 0, 196, 33, 28, 99, 102, 90, 22, 105, 167, 195, 200, 126, 132, 202, 178, 50,
+        197, 41, 204, 7, 108, 2, 12, 7, 221, 0, 61, 120, 23, 112, 11, 47, 104, 60, 128, 196, 33,
+        34, 178, 82, 48, 155, 153, 34, 99, 173, 221, 221, 236, 235, 5, 135, 165, 18, 39, 120, 175,
+        253, 216, 48, 255, 8, 67, 160, 96, 72, 49, 99, 21, 0, 196, 32, 148, 153, 156, 103, 116, 92,
+        72, 80, 249, 8, 110, 104, 44, 231, 231, 1, 62, 3, 189, 77, 153, 74, 89, 74, 191, 185, 236,
+        20, 209, 93, 77, 51, 196, 32, 33, 224, 176, 185, 62, 76, 18, 58, 48, 219, 106, 206, 35,
+        153, 234, 54, 21, 29, 87, 57, 147, 84, 219, 194, 208, 170, 158, 105, 241, 63, 76, 44, 196,
+        33, 236, 186, 22, 30, 113, 33, 148, 99, 242, 146, 146, 41, 119, 163, 230, 139, 48, 191, 57,
+        161, 79, 240, 7, 167, 28, 62, 213, 170, 132, 195, 255, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    ];
+
+    test_generic_serialization_regression_serde(data_expected, buf_expected);
+}
diff --git a/poly-commitment/tests/ipa_commitment.rs b/poly-commitment/tests/ipa_commitment.rs
new file mode 100644
index 0000000000..bbab4def03
--- /dev/null
+++ b/poly-commitment/tests/ipa_commitment.rs
@@ -0,0 +1,227 @@
+use ark_ff::{One, UniformRand, Zero};
+use ark_poly::{
+    univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Evaluations, Polynomial,
+    Radix2EvaluationDomain as D, Radix2EvaluationDomain,
+};
+use groupmap::GroupMap;
+use mina_curves::pasta::{Fp, Pallas, Vesta as VestaG};
+use mina_poseidon::{
+    constants::PlonkSpongeConstantsKimchi as SC, sponge::DefaultFqSponge, FqSponge,
+};
+use o1_utils::ExtendedDensePolynomial;
+use poly_commitment::{
+    commitment::{combined_inner_product, BatchEvaluationProof, CommitmentCurve, Evaluation},
+    ipa::SRS,
+    pbt_srs,
+    utils::DensePolynomialOrEvaluations,
+    PolyComm, SRS as _,
+};
+use rand::Rng;
+use std::array;
+
+#[test]
+fn test_lagrange_commitments() {
+    let n = 64;
+    let domain = D::<Fp>::new(n).unwrap();
+
+    let srs = SRS::<VestaG>::create(n);
+    srs.get_lagrange_basis(domain);
+
+    let num_chunks = domain.size() / srs.g.len();
+
+    let expected_lagrange_commitments: Vec<_> = (0..n)
+        .map(|i| {
+            let mut e = vec![Fp::zero(); n];
+            e[i] = Fp::one();
+            let p = Evaluations::<Fp, D<Fp>>::from_vec_and_domain(e, domain).interpolate();
+            srs.commit_non_hiding(&p, num_chunks)
+        })
+        .collect();
+
+    let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size());
+    for i in 0..n {
+        assert_eq!(
+            computed_lagrange_commitments[i],
+            expected_lagrange_commitments[i],
+        );
+    }
+}
+
+#[test]
+// This tests with multiple chunks (four, given the divisor below).
+fn test_chunked_lagrange_commitments() {
+    let n = 64;
+    let divisor = 4;
+    let domain = D::<Fp>::new(n).unwrap();
+
+    let srs = SRS::<VestaG>::create(n / divisor);
+    srs.get_lagrange_basis(domain);
+
+    let num_chunks = domain.size() / srs.g.len();
+    assert!(num_chunks == divisor);
+
+    let expected_lagrange_commitments: Vec<_> = (0..n)
+        .map(|i| {
+            // Generating the i-th element of the lagrange basis
+            let mut e = vec![Fp::zero(); n];
+            e[i] = Fp::one();
+            let p = Evaluations::<Fp, D<Fp>>::from_vec_and_domain(e, domain).interpolate();
+            // Committing, and requesting [num_chunks] chunks.
+            srs.commit_non_hiding(&p, num_chunks)
+        })
+        .collect();
+
+    let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size());
+    for i in 0..n {
+        assert_eq!(
+            computed_lagrange_commitments[i],
+            expected_lagrange_commitments[i],
+        );
+    }
+}
+
+#[test]
+// TODO @volhovm I don't understand what this test does and
+// whether it is worth leaving.
+/// Same as test_chunked_lagrange_commitments, but with a slight
+/// offset in the SRS
+fn test_offset_chunked_lagrange_commitments() {
+    let n = 64;
+    let domain = D::<Fp>::new(n).unwrap();
+
+    let srs = SRS::<VestaG>::create(n / 2 + 1);
+    srs.get_lagrange_basis(domain);
+
+    // Is this even taken into account?...
+    let num_chunks = (domain.size() + srs.g.len() - 1) / srs.g.len();
+    assert!(num_chunks == 2);
+
+    let expected_lagrange_commitments: Vec<_> = (0..n)
+        .map(|i| {
+            let mut e = vec![Fp::zero(); n];
+            e[i] = Fp::one();
+            let p = Evaluations::<Fp, D<Fp>>::from_vec_and_domain(e, domain).interpolate();
+            srs.commit_non_hiding(&p, num_chunks) // this requires max = Some(64)
+        })
+        .collect();
+
+    let computed_lagrange_commitments = srs.get_lagrange_basis_from_domain_size(domain.size());
+    for i in 0..n {
+        assert_eq!(
+            computed_lagrange_commitments[i],
+            expected_lagrange_commitments[i],
+        );
+    }
+}
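The three Lagrange-basis tests above differ only in how the SRS size relates to the domain size; the chunk count is plain ceiling division, as this self-contained model shows (`num_chunks` here is an illustrative helper, not the library's API):

```rust
// ceil(domain_size / srs_size): how many commitment chunks a polynomial
// interpolated over the domain needs against an SRS of the given size.
fn num_chunks(domain_size: usize, srs_size: usize) -> usize {
    (domain_size + srs_size - 1) / srs_size
}

fn main() {
    assert_eq!(num_chunks(64, 64), 1); // test_lagrange_commitments
    assert_eq!(num_chunks(64, 16), 4); // test_chunked_lagrange_commitments
    assert_eq!(num_chunks(64, 33), 2); // test_offset_chunked_lagrange_commitments
}
```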
+
+#[test]
+fn test_opening_proof() {
+    // create two polynomials
+    let coeffs: [Fp; 10] = array::from_fn(|i| Fp::from(i as u32));
+    let poly1 = DensePolynomial::<Fp>::from_coefficients_slice(&coeffs);
+    let poly2 = DensePolynomial::<Fp>::from_coefficients_slice(&coeffs[..5]);
+
+    // create an SRS
+    let srs = SRS::<VestaG>::create(20);
+    let mut rng = &mut o1_utils::tests::make_test_rng(None);
+
+    // commit the two polynomials
+    let commitment1 = srs.commit(&poly1, 1, rng);
+    let commitment2 = srs.commit(&poly2, 1, rng);
+
+    // create an aggregated opening proof
+    let (u, v) = (Fp::rand(rng), Fp::rand(rng));
+    let group_map = <VestaG as CommitmentCurve>::Map::setup();
+    let sponge = DefaultFqSponge::<_, SC>::new(mina_poseidon::pasta::fq_kimchi::static_params());
+
+    let polys: Vec<(
+        DensePolynomialOrEvaluations<_, Radix2EvaluationDomain<_>>,
+        PolyComm<_>,
+    )> = vec![
+        (
+            DensePolynomialOrEvaluations::DensePolynomial(&poly1),
+            commitment1.blinders,
+        ),
+        (
+            DensePolynomialOrEvaluations::DensePolynomial(&poly2),
+            commitment2.blinders,
+        ),
+    ];
+    // Generate a random number of evaluation points
+    let nb_elem: u32 = rng.gen_range(1..7);
+    let elm: Vec<Fp> = (0..nb_elem).map(|_| Fp::rand(&mut rng)).collect();
+    let opening_proof = srs.open(&group_map, &polys, &elm, v, u, sponge.clone(), rng);
+
+    // evaluate the polynomials at the points
+    let poly1_chunked_evals: Vec<Vec<Fp>> = elm
+        .iter()
+        .map(|elmi| {
+            poly1
+                .to_chunked_polynomial(1, srs.g.len())
+                .evaluate_chunks(*elmi)
+        })
+        .collect();
+
+    fn sum(c: &[Fp]) -> Fp {
+        c.iter().fold(Fp::zero(), |a, &b| a + b)
+    }
+
+    for (i, chunks) in poly1_chunked_evals.iter().enumerate() {
+        assert_eq!(sum(chunks), poly1.evaluate(&elm[i]))
+    }
+
+    let poly2_chunked_evals: Vec<Vec<Fp>> = elm
+        .iter()
+        .map(|elmi| {
+            poly2
+                .to_chunked_polynomial(1, srs.g.len())
+                .evaluate_chunks(*elmi)
+        })
+        .collect();
+
+    for (i, chunks) in poly2_chunked_evals.iter().enumerate() {
+        assert_eq!(sum(chunks), poly2.evaluate(&elm[i]))
+    }
+
+    let evaluations = vec![
+        Evaluation {
+            commitment: commitment1.commitment,
+            evaluations: poly1_chunked_evals,
+        },
+        Evaluation {
+            commitment: commitment2.commitment,
+            evaluations: poly2_chunked_evals,
+        },
+    ];
+
+    let combined_inner_product = {
+        let es: Vec<_> = evaluations
+            .iter()
+            .map(|Evaluation { evaluations, .. }| evaluations.clone())
+            .collect();
+        combined_inner_product(&v, &u, &es)
+    };
+
+    {
+        // create the proof
+        let mut batch = vec![BatchEvaluationProof {
+            sponge,
+            evaluation_points: elm,
+            polyscale: v,
+            evalscale: u,
+            evaluations,
+            opening: &opening_proof,
+            combined_inner_product,
+        }];
+
+        assert!(srs.verify(&group_map, &mut batch, rng));
+    }
+}
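The `combined_inner_product` scalar computed in `test_opening_proof` above reduces all chunk evaluations to one field element: with polyscale `v` and evalscale `u`, it is `sum_i v^i * sum_j u^j * e[i][j]`, where `i` ranges over (polynomial, chunk) pairs in order and `j` over evaluation points. A toy integer model of that double fold (illustrative only; the library may organize the loops differently):

```rust
// e[i][j]: evaluation of the i-th (polynomial, chunk) pair at the j-th point.
fn combined_inner_product(polyscale: u64, evalscale: u64, e: &[Vec<u64>]) -> u64 {
    let mut res = 0;
    let mut vi = 1; // polyscale^i
    for evals in e {
        let mut uj = 1; // evalscale^j
        for &eij in evals {
            res += vi * uj * eij;
            uj *= evalscale;
        }
        vi *= polyscale;
    }
    res
}

fn main() {
    // two chunks at two points, v = 2, u = 3:
    assert_eq!(
        combined_inner_product(2, 3, &[vec![1, 2], vec![3, 4]]),
        (1 + 3 * 2) + 2 * (3 + 3 * 4) // v^0 (e00 + u e01) + v^1 (e10 + u e11)
    );
}
```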
+
+// Testing how many chunks are generated with different polynomial sizes and
+// different number of chunks requested.
+#[test]
+fn test_regression_commit_non_hiding_expected_number_of_chunks() {
+    pbt_srs::test_regression_commit_non_hiding_expected_number_of_chunks::<VestaG, SRS<VestaG>>();
+    pbt_srs::test_regression_commit_non_hiding_expected_number_of_chunks::<Pallas, SRS<Pallas>>()
+}
diff --git a/poly-commitment/tests/kzg.rs b/poly-commitment/tests/kzg.rs
new file mode 100644
index 0000000000..0b553fa1fb
--- /dev/null
+++ b/poly-commitment/tests/kzg.rs
@@ -0,0 +1,248 @@
+use ark_bn254::{Config, Fr as ScalarField, G1Affine as G1, G2Affine as G2};
+use ark_ec::{bn::Bn, AffineRepr};
+use ark_ff::UniformRand;
+use ark_poly::{
+    univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Polynomial,
+    Radix2EvaluationDomain as D,
+};
+use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
+use mina_curves::pasta::{Fp, Vesta as VestaG};
+use poly_commitment::{
+    commitment::Evaluation,
+    ipa::SRS,
+    kzg::{combine_evaluations, KZGProof, PairingSRS},
+    pbt_srs,
+    utils::DensePolynomialOrEvaluations,
+    PolyComm, SRS as _,
+};
+
+#[test]
+fn test_combine_evaluations() {
+    let nb_of_chunks = 1;
+
+    // we ignore commitments
+    let dummy_commitments = PolyComm::<VestaG> {
+        chunks: vec![VestaG::zero(); nb_of_chunks],
+    };
+
+    let polyscale = Fp::from(2);
+    // Using only one evaluation. Starting with eval_p1
+    {
+        let eval_p1 = Evaluation {
+            commitment: dummy_commitments.clone(),
+            evaluations: vec![
+                // Eval at first point. Only one chunk.
+                vec![Fp::from(1)],
+                // Eval at second point. Only one chunk.
+                vec![Fp::from(2)],
+            ],
+        };
+
+        let output = combine_evaluations::<VestaG>(&vec![eval_p1], polyscale);
+        // We have 2 evaluation points.
+        assert_eq!(output.len(), 2);
+        // polyscale is not used.
+        let exp_output = [Fp::from(1), Fp::from(2)];
+        output.iter().zip(exp_output.iter()).for_each(|(o, e)| {
+            assert_eq!(o, e);
+        });
+    }
+
+    // And after that eval_p2
+    {
+        let eval_p2 = Evaluation {
+            commitment: dummy_commitments.clone(),
+            evaluations: vec![
+                // Eval at first point. Only one chunk.
+                vec![Fp::from(3)],
+                // Eval at second point. Only one chunk.
+                vec![Fp::from(4)],
+            ],
+        };
+
+        let output = combine_evaluations::<VestaG>(&vec![eval_p2], polyscale);
+        // We have 2 evaluation points
+        assert_eq!(output.len(), 2);
+        // polyscale is not used.
+        let exp_output = [Fp::from(3), Fp::from(4)];
+        output.iter().zip(exp_output.iter()).for_each(|(o, e)| {
+            assert_eq!(o, e);
+        });
+    }
+
+    // Now with two evaluations
+    {
+        let eval_p1 = Evaluation {
+            commitment: dummy_commitments.clone(),
+            evaluations: vec![
+                // Eval at first point. Only one chunk.
+                vec![Fp::from(1)],
+                // Eval at second point. Only one chunk.
+                vec![Fp::from(2)],
+            ],
+        };
+
+        let eval_p2 = Evaluation {
+            commitment: dummy_commitments.clone(),
+            evaluations: vec![
+                // Eval at first point. Only one chunk.
+                vec![Fp::from(3)],
+                // Eval at second point. Only one chunk.
+                vec![Fp::from(4)],
+            ],
+        };
+
+        let output = combine_evaluations::<VestaG>(&vec![eval_p1, eval_p2], polyscale);
+        // We have 2 evaluation points
+        assert_eq!(output.len(), 2);
+        let exp_output = [Fp::from(1 + 3 * 2), Fp::from(2 + 4 * 2)];
+        output.iter().zip(exp_output.iter()).for_each(|(o, e)| {
+            assert_eq!(o, e);
+        });
+    }
+
+    // Now with two evaluations and two chunks
+    {
+        let eval_p1 = Evaluation {
+            commitment: dummy_commitments.clone(),
+            evaluations: vec![
+                // Eval at first point.
+                vec![Fp::from(1), Fp::from(3)],
+                // Eval at second point.
+                vec![Fp::from(2), Fp::from(4)],
+            ],
+        };
+
+        let eval_p2 = Evaluation {
+            commitment: dummy_commitments.clone(),
+            evaluations: vec![
+                // Eval at first point.
+                vec![Fp::from(5), Fp::from(7)],
+                // Eval at second point.
+                vec![Fp::from(6), Fp::from(8)],
+            ],
+        };
+
+        let output = combine_evaluations::<VestaG>(&vec![eval_p1, eval_p2], polyscale);
+        // We have 2 evaluation points
+        assert_eq!(output.len(), 2);
+        let o1 = Fp::from(1 + 3 * 2 + 5 * 4 + 7 * 8);
+        let o2 = Fp::from(2 + 4 * 2 + 6 * 4 + 8 * 8);
+        let exp_output = [o1, o2];
+        output.iter().zip(exp_output.iter()).for_each(|(o, e)| {
+            assert_eq!(o, e);
+        });
+    }
+}
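The expected values in the last two blocks of `test_combine_evaluations` encode the flattening order: at each evaluation point, chunks of the first polynomial come first, then chunks of the second, each weighted by the next power of polyscale. A plain-integer model reproducing the test's own numbers:

```rust
// Fold chunked evaluations at one point with increasing powers of polyscale.
fn combine_at_point(chunked_evals: &[&[u64]], polyscale: u64) -> u64 {
    let mut acc = 0;
    let mut scale = 1;
    for chunks in chunked_evals {
        for &e in *chunks {
            acc += scale * e;
            scale *= polyscale;
        }
    }
    acc
}

fn main() {
    // mirrors o1 and o2 above (polyscale = 2):
    assert_eq!(combine_at_point(&[&[1, 3], &[5, 7]], 2), 1 + 3 * 2 + 5 * 4 + 7 * 8);
    assert_eq!(combine_at_point(&[&[2, 4], &[6, 8]], 2), 2 + 4 * 2 + 6 * 4 + 8 * 8);
}
```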
+
+#[test]
+fn test_kzg_proof() {
+    let n = 64;
+    let domain = D::<ScalarField>::new(n).unwrap();
+
+    let mut rng = o1_utils::tests::make_test_rng(None);
+    let x = ScalarField::rand(&mut rng);
+
+    let srs = unsafe { SRS::<G1>::create_trusted_setup(x, n) };
+    let verifier_srs = unsafe { SRS::<G2>::create_trusted_setup(x, 3) };
+    srs.get_lagrange_basis(domain);
+
+    let srs = PairingSRS {
+        full_srs: srs,
+        verifier_srs,
+    };
+
+    let polynomials: Vec<_> = (0..4)
+        .map(|_| {
+            let coeffs = (0..63).map(|_| ScalarField::rand(&mut rng)).collect();
+            DensePolynomial::from_coefficients_vec(coeffs)
+        })
+        .collect();
+
+    let comms: Vec<_> = polynomials
+        .iter()
+        .map(|p| srs.full_srs.commit(p, 1, &mut rng))
+        .collect();
+
+    let polynomials_and_blinders: Vec<(DensePolynomialOrEvaluations<_, D<_>>, _)> = polynomials
+        .iter()
+        .zip(comms.iter())
+        .map(|(p, comm)| {
+            let p = DensePolynomialOrEvaluations::DensePolynomial(p);
+            (p, comm.blinders.clone())
+        })
+        .collect();
+
+    let evaluation_points = vec![ScalarField::rand(&mut rng), ScalarField::rand(&mut rng)];
+
+    let evaluations: Vec<_> = polynomials
+        .iter()
+        .zip(comms)
+        .map(|(p, commitment)| {
+            let evaluations = evaluation_points
+                .iter()
+                .map(|x| {
+                    // Inputs are chosen to use only 1 chunk
+                    vec![p.evaluate(x)]
+                })
+                .collect();
+            Evaluation {
+                commitment: commitment.commitment,
+                evaluations,
+            }
+        })
+        .collect();
+
+    let polyscale = ScalarField::rand(&mut rng);
+
+    let kzg_proof = KZGProof::<Bn<Config>>::create(
+        &srs,
+        polynomials_and_blinders.as_slice(),
+        &evaluation_points,
+        polyscale,
+    )
+    .unwrap();
+
+    let res = kzg_proof.verify(&srs, &evaluations, polyscale, &evaluation_points);
+    assert!(res);
+}
+
+/// Checks that our points in G2 are in the correct subgroup and serialize well.
+#[test]
+fn check_srs_g2_valid_and_serializes() {
+    type BN254 = Bn<Config>;
+    type BN254G2BaseField = <G2 as AffineRepr>::BaseField;
+
+    let srs: PairingSRS<BN254> = PairingSRS::create(1 << 5);
+
+    let mut vec: Vec<u8> = vec![0u8; 1024];
+
+    for actual in [
+        srs.verifier_srs.h,
+        srs.verifier_srs.g[0],
+        srs.verifier_srs.g[1],
+    ] {
+        // Check it's valid
+        assert!(!actual.is_zero());
+        assert!(actual.is_on_curve());
+        assert!(actual.is_in_correct_subgroup_assuming_on_curve());
+
+        // Check it serializes well
+        let actual_y: BN254G2BaseField = actual.y;
+        let res = actual_y.serialize_compressed(vec.as_mut_slice());
+        assert!(res.is_ok());
+        let expected: BN254G2BaseField =
+            CanonicalDeserialize::deserialize_compressed(vec.as_slice()).unwrap();
+        assert!(expected == actual_y, "serialization failed");
+    }
+}
+
+// Testing how many chunks are generated with different polynomial sizes and
+// different number of chunks requested.
+#[test]
+fn test_regression_commit_non_hiding_expected_number_of_chunks() {
+    type BN254 = Bn<Config>;
+    type Srs = PairingSRS<BN254>;
+
+    pbt_srs::test_regression_commit_non_hiding_expected_number_of_chunks::<G1, Srs>();
+}
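For background on what `test_kzg_proof` exercises: in the single-polynomial, single-point case, a KZG opening proof $\pi$ commits to the quotient $(f(X) - f(\zeta))/(X - \zeta)$, and the verifier's pairing check is the textbook equation

$$ e\big(\mathrm{com}(f) - f(\zeta)\,[1]_1,\ [1]_2\big) \;=\; e\big(\pi,\ [\tau]_2 - \zeta\,[1]_2\big) $$

where $\tau$ is the trusted-setup secret (`x` in the test) and $[\cdot]_1$, $[\cdot]_2$ denote scalar multiples of the G1/G2 generators held in `full_srs` and `verifier_srs`. The crate's aggregated, chunked variant first folds all polynomials and points with polyscale, so its concrete check may be arranged differently; the equation is given only for orientation, not as the crate's exact formula.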
diff --git a/poseidon/benches/poseidon.rs b/poseidon/benches/poseidon.rs
new file mode 100644
index 0000000000..80bb0f213e
--- /dev/null
+++ b/poseidon/benches/poseidon.rs
@@ -0,0 +1,37 @@
+use mina_poseidon::{
+    constants::{PlonkSpongeConstantsKimchi, PlonkSpongeConstantsLegacy},
+    pasta::{fp_kimchi as SpongeParametersKimchi, fp_legacy as SpongeParametersLegacy},
+    permutation::poseidon_block_cipher,
+};
+use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};
+
+use ark_ff::UniformRand;
+use criterion::{criterion_group, criterion_main, Criterion};
+use mina_curves::pasta::Fp;
+
+pub fn bench_poseidon_absorb_permutation_pasta_fp(c: &mut Criterion) {
+    // FIXME: use o1_utils test rng
+    let seed = thread_rng().gen();
+    eprintln!("Seed: {seed:?}");
+    let mut rng = StdRng::from_seed(seed);
+
+    let input: [Fp; 3] = std::array::from_fn(|_| Fp::rand(&mut rng));
+    let mut input: Vec<Fp> = Vec::from(input);
+
+    let params = SpongeParametersKimchi::static_params();
+    c.bench_function("poseidon_absorb_permutation kimchi", |b| {
+        b.iter(|| {
+            poseidon_block_cipher::<Fp, PlonkSpongeConstantsKimchi>(params, &mut input);
+        })
+    });
+
+    let params = SpongeParametersLegacy::static_params();
+    c.bench_function("poseidon_absorb_permutation legacy", |b| {
+        b.iter(|| {
+            poseidon_block_cipher::<Fp, PlonkSpongeConstantsLegacy>(params, &mut input);
+        })
+    });
+}
+
+criterion_group!(benches, bench_poseidon_absorb_permutation_pasta_fp);
+criterion_main!(benches);
diff --git a/poseidon/src/lib.rs b/poseidon/src/lib.rs
index 943d54cb87..e2ff0f1e60 100644
--- a/poseidon/src/lib.rs
+++ b/poseidon/src/lib.rs
@@ -5,19 +5,4 @@ pub mod permutation;
 pub mod poseidon;
 pub mod sponge;

-#[cfg(test)]
-mod tests;
-
-use ark_ff::Field;
-
-pub trait FqSponge<Fq: Field, G, Fr> {
-    fn new(p: &'static poseidon::ArithmeticSpongeParams<Fq>) -> Self;
-    fn absorb_g(&mut self, g: &[G]);
-    fn absorb_fq(&mut self, x: &[Fq]);
-    fn absorb_fr(&mut self, x: &[Fr]);
-    fn challenge(&mut self) -> Fr;
-    fn challenge_fq(&mut self) -> Fq;
-
-    fn digest(self) -> Fr;
-    fn digest_fq(self) -> Fq;
-}
+pub use sponge::FqSponge; // Commonly used so reexported for convenience
diff --git a/poseidon/src/pasta/params.sage b/poseidon/src/pasta/params.sage
index eee440277f..148078e70b 100755
--- a/poseidon/src/pasta/params.sage
+++ b/poseidon/src/pasta/params.sage
@@ -35,6 +35,14 @@ parser.add_argument('name', type=str, help='Name of parameter set (e.g. \'\', 5
 parser.add_argument('--rounds', type=int, default=100, help='Number of round constants')
 args = parser.parse_args()

+# FIXME: This is a hack to make the script work for BN254. We should generalize the script later.
+# BN254/Grumpkin ("Ethereum" curves)
+## BN254 Scalar field
+_bn254_q = 21888242871839275222246405745257275088548364400416034343698204186575808495617
+## BN254 Base field
+_bn254_p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
+
+# Pasta (Pallas/Vesta) curves
 _pasta_p = 28948022309329048855892746252171976963363056481941560715954676764349967630337
 _pasta_q = 28948022309329048855892746252171976963363056481941647379679742748393362948097
diff --git a/poseidon/src/permutation.rs b/poseidon/src/permutation.rs
index b3d58776a5..c64f83f1ca 100644
--- a/poseidon/src/permutation.rs
+++ b/poseidon/src/permutation.rs
@@ -1,4 +1,5 @@
-//! The permutation module contains the function implementing the permutation used in Poseidon
+//! The permutation module contains the function implementing the permutation
+//! used in Poseidon.

 use crate::{
     constants::SpongeConstants,
@@ -30,6 +31,12 @@ fn apply_mds_matrix(
     }
 }

+/// Apply a full round of the permutation.
+/// A full round is composed of the following steps:
+/// - Apply the S-box to each element of the state.
+/// - Apply the MDS matrix to the state.
+/// - Add the round constants to the state.
+/// This function has a side effect: the parameter `state` is modified.
 pub fn full_round<F: Field, SC: SpongeConstants>(
     params: &ArithmeticSpongeParams<F>,
     state: &mut Vec<F>,
@@ -55,7 +62,10 @@ pub fn half_rounds<F: Field, SC: SpongeConstants>(
         for state_i in state.iter_mut() {
             *state_i = sbox::<F, SC>(*state_i);
         }
-        apply_mds_matrix::<F, SC>(params, state);
+        let res = apply_mds_matrix::<F, SC>(params, state);
+        for (i, state_i) in state.iter_mut().enumerate() {
+            *state_i = res[i]
+        }
     }

     for r in 0..SC::PERM_ROUNDS_PARTIAL {
@@ -66,7 +76,10 @@ pub fn half_rounds<F: Field, SC: SpongeConstants>(
             state[i].add_assign(x);
         }
         state[0] = sbox::<F, SC>(state[0]);
-        apply_mds_matrix::<F, SC>(params, state);
+        let res = apply_mds_matrix::<F, SC>(params, state);
+        res.iter().enumerate().for_each(|(i, x)| {
+            state[i] = *x;
+        });
     }

     for r in 0..SC::PERM_HALF_ROUNDS_FULL {
@@ -80,7 +93,10 @@ pub fn half_rounds<F: Field, SC: SpongeConstants>(
         for state_i in state.iter_mut() {
            *state_i = sbox::<F, SC>(*state_i);
         }
-        apply_mds_matrix::<F, SC>(params, state);
+        let res = apply_mds_matrix::<F, SC>(params, state);
+        res.iter().enumerate().for_each(|(i, x)| {
+            state[i] = *x;
+        });
     }
 }
diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs
index 2ce1d1f3d3..056a964c3e 100644
--- a/poseidon/src/poseidon.rs
+++ b/poseidon/src/poseidon.rs
@@ -25,8 +25,19 @@ pub trait Sponge<Input: Field, Digest> {
     fn reset(&mut self);
 }

-pub fn sbox<F: Field, SC: SpongeConstants>(x: F) -> F {
-    x.pow([SC::PERM_SBOX as u64])
+pub fn sbox<F: Field, SC: SpongeConstants>(mut x: F) -> F {
+    if SC::PERM_SBOX == 7 {
+        // This is much faster than using the generic `pow`. Hard-code to get the ~50% speed-up
+        // that it gives to hashing.
+        let mut square = x;
+        square.square_in_place();
+        x *= square;
+        square.square_in_place();
+        x *= square;
+        x
+    } else {
+        x.pow([SC::PERM_SBOX as u64])
+    }
 }

 #[derive(Clone, Debug)]
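The hard-coded branch above evaluates x⁷ with the addition chain x → x² → x³ → x⁴ → x⁷, i.e. two squarings and two multiplications. A plain-integer sanity check of the same sequence of operations (the real code runs over a pasta field, not `u128`):

```rust
// Mirrors the chain in `sbox`: square = x^2; x *= square (x^3);
// square = square^2 (x^4); x *= square (x^7).
fn pow7_chain(x: u128, p: u128) -> u128 {
    let mut x = x % p;
    let mut square = (x * x) % p; // x^2
    x = (x * square) % p; // x^3
    square = (square * square) % p; // x^4
    (x * square) % p // x^7
}

fn main() {
    let p = 1_000_003;
    for x in [0u128, 1, 2, 12_345] {
        let naive = (0..7).fold(1u128, |acc, _| (acc * x) % p);
        assert_eq!(pow7_chain(x, p), naive);
    }
}
```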
diff --git a/poseidon/src/sponge.rs b/poseidon/src/sponge.rs
index 10fc680680..be9ba0a319 100644
--- a/poseidon/src/sponge.rs
+++ b/poseidon/src/sponge.rs
@@ -5,7 +5,39 @@ use crate::{
 use ark_ec::models::short_weierstrass::{Affine, SWCurveConfig};
 use ark_ff::{BigInteger, Field, One, PrimeField, Zero};

-pub use crate::FqSponge;
+/// Abstracts a sponge operating on a base field `Fq` of the curve
+/// `G`. The parameter `Fr` is modelling the scalar field of the
+/// curve.
+pub trait FqSponge<Fq: Field, G, Fr> {
+    /// Creates a new sponge.
+    fn new(p: &'static ArithmeticSpongeParams<Fq>) -> Self;
+
+    /// Absorbs a base field element. This operation is the most
+    /// straightforward and calls the underlying sponge directly.
+    fn absorb_fq(&mut self, x: &[Fq]);
+
+    /// Absorbs a curve point, that is a pair of `Fq` elements.
+    fn absorb_g(&mut self, g: &[G]);
+
+    /// Absorbs an element of the scalar field `Fr` --- it is done
+    /// by converting the element to the base field first.
+    fn absorb_fr(&mut self, x: &[Fr]);
+
+    /// Squeeze out a base field challenge. This operation is the most
+    /// direct and calls the underlying sponge.
+    fn challenge_fq(&mut self) -> Fq;
+
+    /// Squeeze out a challenge in the scalar field. Implemented by
+    /// squeezing out base field elements and then converting them to a
+    /// scalar field element using their binary representation.
+    fn challenge(&mut self) -> Fr;
+
+    /// Returns a base field digest by squeezing the underlying sponge directly.
+    fn digest_fq(self) -> Fq;
+
+    /// Returns a scalar field digest using the binary representation technique.
+    fn digest(self) -> Fr;
+}

 pub const CHALLENGE_LENGTH_IN_LIMBS: usize = 2;
diff --git a/poseidon/src/tests/mod.rs b/poseidon/src/tests/mod.rs
deleted file mode 100644
index e3edb136c6..0000000000
--- a/poseidon/src/tests/mod.rs
+++ /dev/null
@@ -1 +0,0 @@
-mod poseidon_tests;
diff --git a/poseidon/src/tests/poseidon_tests.rs b/poseidon/tests/poseidon_tests.rs
similarity index 97%
rename from poseidon/src/tests/poseidon_tests.rs
rename to poseidon/tests/poseidon_tests.rs
index 4d94293f25..61345b2298 100644
--- a/poseidon/src/tests/poseidon_tests.rs
+++ b/poseidon/tests/poseidon_tests.rs
@@ -1,9 +1,9 @@
-use crate::{
+use mina_curves::pasta::Fp;
+use mina_poseidon::{
     constants::{PlonkSpongeConstantsKimchi, PlonkSpongeConstantsLegacy},
     pasta::{fp_kimchi as SpongeParametersKimchi, fp_legacy as SpongeParametersLegacy},
     poseidon::{ArithmeticSponge as Poseidon, Sponge as _},
 };
-use mina_curves::pasta::Fp;
 use o1_utils::FieldHelpers;
 use serde::Deserialize;
 use std::{fs::File, path::PathBuf}; // needed for ::new() sponge
@@ -29,7 +29,7 @@ where
 {
     // read test vectors from given file
     let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
-    path.push("src/tests/test_vectors");
+    path.push("tests/test_vectors");
     path.push(test_vector_file);
     let file = File::open(&path).expect("couldn't open test vector file");
     let test_vectors: TestVectors =
diff --git a/poseidon/src/tests/test_vectors/kimchi.json b/poseidon/tests/test_vectors/kimchi.json
similarity index 100%
rename from poseidon/src/tests/test_vectors/kimchi.json
rename to poseidon/tests/test_vectors/kimchi.json
diff --git a/poseidon/src/tests/test_vectors/legacy.json b/poseidon/tests/test_vectors/legacy.json
similarity index 100%
rename from poseidon/src/tests/test_vectors/legacy.json
rename to poseidon/tests/test_vectors/legacy.json
diff --git a/proof-systems-vendors b/proof-systems-vendors
deleted file mode 160000
index 7ea36d68a4..0000000000
--- a/proof-systems-vendors
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 7ea36d68a4f325b1444850978a174ded8b0d4afe
diff --git a/signer/Cargo.toml b/signer/Cargo.toml
index db3dc3e985..11bafec639 100644
--- a/signer/Cargo.toml
+++ b/signer/Cargo.toml
@@ -20,10 +20,10 @@ o1-utils.workspace = true
 ark-ec.workspace = true
 ark-ff.workspace = true

-rand.workspace = true
+rand.workspace = true
 blake2.workspace = true
-hex.workspace = true
-bitvec.workspace = true
+hex.workspace = true +bitvec.workspace = true sha2.workspace = true -bs58.workspace = true +bs58.workspace = true thiserror.workspace = true diff --git a/signer/src/keypair.rs b/signer/src/keypair.rs index 831e01efb7..8f76383020 100644 --- a/signer/src/keypair.rs +++ b/signer/src/keypair.rs @@ -27,7 +27,7 @@ pub type Result = std::result::Result; #[derive(Clone, PartialEq, Eq)] pub struct Keypair { /// Secret key - pub(crate) secret: SecKey, + pub secret: SecKey, /// Public key pub public: PubKey, } @@ -108,102 +108,3 @@ impl fmt::Display for Keypair { write!(f, "{}", self.public) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn from_hex() { - assert_eq!( - Keypair::from_hex(""), - Err(KeypairError::SecretKey(SecKeyError::SecretKeyBytes)) - ); - assert_eq!( - Keypair::from_hex("1428fadcf0c02396e620f14f176fddb5d769b7de2027469d027a80142ef8f07"), - Err(KeypairError::SecretKey(SecKeyError::SecretKeyHex)) - ); - assert_eq!( - Keypair::from_hex("0f5314f176fddb5d769b7de2027469d027ad428fadcf0c02396e6280142efb7d8"), - Err(KeypairError::SecretKey(SecKeyError::SecretKeyHex)) - ); - assert_eq!( - Keypair::from_hex("g64244176fddb5d769b7de2027469d027ad428fadcf0c02396e6280142efb7d8"), - Err(KeypairError::SecretKey(SecKeyError::SecretKeyHex)) - ); - assert_eq!( - Keypair::from_hex("4244176fddb5d769b7de2027469d027ad428fadcc0c02396e6280142efb718"), - Err(KeypairError::SecretKey(SecKeyError::SecretKeyBytes)) - ); - - Keypair::from_hex("164244176fddb5d769b7de2027469d027ad428fadcc0c02396e6280142efb718") - .expect("failed to decode keypair secret key"); - } - - #[test] - fn get_address() { - macro_rules! assert_get_address_eq { - ($sec_key_hex:expr, $target_address:expr) => { - let kp = Keypair::from_hex($sec_key_hex).expect("failed to create keypair"); - assert_eq!(kp.get_address(), $target_address); - }; - } - - assert_get_address_eq!( - "164244176fddb5d769b7de2027469d027ad428fadcc0c02396e6280142efb718", - "B62qnzbXmRNo9q32n4SNu2mpB8e7FYYLH8NmaX6oFCBYjjQ8SbD7uzV" - ); - assert_get_address_eq!( - "3ca187a58f09da346844964310c7e0dd948a9105702b716f4d732e042e0c172e", - "B62qicipYxyEHu7QjUqS7QvBipTs5CzgkYZZZkPoKVYBu6tnDUcE9Zt" - ); - assert_get_address_eq!( - "336eb4a19b3d8905824b0f2254fb495573be302c17582748bf7e101965aa4774", - "B62qrKG4Z8hnzZqp1AL8WsQhQYah3quN1qUj3SyfJA8Lw135qWWg1mi" - ); - assert_get_address_eq!( - "1dee867358d4000f1dafa5978341fb515f89eeddbe450bd57df091f1e63d4444", - "B62qoqiAgERjCjXhofXiD7cMLJSKD8hE8ZtMh4jX5MPNgKB4CFxxm1N" - ); - assert_get_address_eq!( - "20f84123a26e58dd32b0ea3c80381f35cd01bc22a20346cc65b0a67ae48532ba", - "B62qkiT4kgCawkSEF84ga5kP9QnhmTJEYzcfgGuk6okAJtSBfVcjm1M" - ); - assert_get_address_eq!( - "3414fc16e86e6ac272fda03cf8dcb4d7d47af91b4b726494dab43bf773ce1779", - "B62qoG5Yk4iVxpyczUrBNpwtx2xunhL48dydN53A2VjoRwF8NUTbVr4" - ); - } - - #[test] - fn to_bytes() { - let bytes = [ - 61, 18, 244, 30, 36, 241, 5, 54, 107, 96, 154, 162, 58, 78, 242, 140, 186, 233, 25, 35, - 145, 119, 39, 94, 162, 123, 208, 202, 189, 29, 235, 209, - ]; - assert_eq!( - Keypair::from_bytes(&bytes) - .expect("failed to decode keypair") - .to_bytes(), - bytes - ); - - // negative test (too many bytes) - assert_eq!( - Keypair::from_bytes(&[ - 61, 18, 244, 30, 36, 241, 5, 54, 107, 96, 154, 162, 58, 78, 242, 140, 186, 233, 25, - 35, 145, 119, 39, 94, 162, 123, 208, 202, 189, 29, 235, 209, 0 - ]), - Err(KeypairError::SecretKey(SecKeyError::SecretKeyBytes)) - ); - - // negative test (too few bytes) - assert_eq!( - Keypair::from_bytes(&[ - 61, 18, 244, 30, 36, 241, 5, 54, 107, 96, 154, 162, 58, 
78, 242, 140, 186, 233, 25, - 35, 145, 119, 39, 94, 162, 123, 208, 202, 189, 29 - ]), - Err(KeypairError::SecretKey(SecKeyError::SecretKeyBytes)) - ); - } -} diff --git a/signer/src/pubkey.rs b/signer/src/pubkey.rs index 60718e1a23..a7eaddba11 100644 --- a/signer/src/pubkey.rs +++ b/signer/src/pubkey.rs @@ -338,235 +338,3 @@ impl CompressedPubKey { hex::encode(self.to_bytes()) } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn from_hex() { - assert_eq!( - PubKey::to_hex( - &PubKey::from_hex( - "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" - ) - .expect("failed to decode pub key"), - ), - "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" - ); - - assert_eq!( - PubKey::to_hex( - &PubKey::from_hex( - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29125030f6cd465cccf2ebbaabc397a6823932bc55dc1ddf9954c9e78da07c0928" - ) - .expect("failed to decode pub key"), - ), - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29125030f6cd465cccf2ebbaabc397a6823932bc55dc1ddf9954c9e78da07c0928" - ); - - assert_eq!( - PubKey::from_hex("44100485d466a4c9f281d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c"), - Err(PubKeyError::XCoordinate) - ); - - assert_eq!( - PubKey::from_hex("44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168715883c"), - Err(PubKeyError::NonCurvePoint) - ); - - assert_eq!( - PubKey::from_hex("z8f71009a6502f89a469a037f14e97386034ea3855877159f29f4a532a2e5f28"), - Err(PubKeyError::Hex) - ); - - assert_eq!( - PubKey::from_hex("44100485d466a4c9f481d43be9a6d4aa5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c"), - Err(PubKeyError::Hex) - ); - } - - #[test] - fn from_secret_key() { - assert_eq!(PubKey::from_secret_key( - SecKey::from_hex("090dd91a2505081a158782c5a24fef63f326749b383423c29827e465f7ca262b").expect("failed to decode sec key") - ).expect("failed to decode pub key").to_hex(), - "3f8c32817851b7c1ad99495463ef5e99c3b3240524f0df3ff7fc41181d849e0086fa821d54de15c523a840c5f62df90aeabb1097b85c6a88e163e9d74e505803" - ); - - assert_eq!(PubKey::from_secret_key( - SecKey::from_hex("086d78e0e5deb62daeef8e3a5574d52a3d3bff5281b4dd49140564c7d80468c9").expect("failed to decode sec key") - ).expect("failed to decode pub key").to_hex(), - "666c450f5e888d3b2341d77b32cb6d0cd4912829ea9c41030d1fd2baff6b9a30c267208638544299e8d369e80b25a24bdd07383b6ea908028d9a406b528d4a01" - ); - - assert_eq!(PubKey::from_secret_key( - SecKey::from_hex("0859771e9394e96dd6d01d57ef074dc25313e63bd331fa5478a9fed9e24855a0").expect("failed to decode sec key") - ).expect("failed to decode pub key").to_hex(), - "6ed0776ab11e3dd3b637cce03a90529e518220132f1a61dd9c0d50aa998abf1d2f43c0f1eb73888ef6f7dac4d7094d3c92cd67abab39b828c5f10aff0b6a0002" - ); - } - - #[test] - fn from_address() { - macro_rules! 
assert_from_address_check { - ($address:expr) => { - let pk = PubKey::from_address($address).expect("failed to create pubkey"); - assert_eq!(pk.into_address(), $address); - }; - } - - assert_from_address_check!("B62qnzbXmRNo9q32n4SNu2mpB8e7FYYLH8NmaX6oFCBYjjQ8SbD7uzV"); - assert_from_address_check!("B62qicipYxyEHu7QjUqS7QvBipTs5CzgkYZZZkPoKVYBu6tnDUcE9Zt"); - assert_from_address_check!("B62qoG5Yk4iVxpyczUrBNpwtx2xunhL48dydN53A2VjoRwF8NUTbVr4"); - assert_from_address_check!("B62qrKG4Z8hnzZqp1AL8WsQhQYah3quN1qUj3SyfJA8Lw135qWWg1mi"); - assert_from_address_check!("B62qoqiAgERjCjXhofXiD7cMLJSKD8hE8ZtMh4jX5MPNgKB4CFxxm1N"); - assert_from_address_check!("B62qkiT4kgCawkSEF84ga5kP9QnhmTJEYzcfgGuk6okAJtSBfVcjm1M"); - } - - #[test] - fn to_bytes() { - let mut bytes = vec![ - 68, 16, 4, 133, 212, 102, 164, 201, 244, 129, 212, 59, 233, 166, 212, 169, 165, 233, - 122, 218, 193, 151, 119, 177, 75, 107, 122, 129, 237, 238, 57, 9, 49, 121, 244, 241, - 40, 151, 121, 124, 254, 120, 191, 47, 214, 50, 31, 54, 176, 11, 208, 89, 45, 239, 191, - 57, 161, 153, 180, 22, 135, 53, 136, 60, - ]; - assert_eq!( - PubKey::from_bytes(&bytes) - .expect("failed to decode pub key") - .to_bytes(), - bytes - ); - - bytes[0] = 0; // negative test: invalid curve point - assert_eq!(PubKey::from_bytes(&bytes), Err(PubKeyError::NonCurvePoint)); - - bytes[0] = 68; - let mut bytes = [bytes, vec![255u8, 102u8]].concat(); // negative test: to many bytes - assert_eq!( - PubKey::from_bytes(&bytes), - Err(PubKeyError::YCoordinateBytes) - ); - - bytes.remove(0); // negative test: to few bytes - bytes.remove(0); - assert_eq!( - PubKey::from_bytes(&bytes), - Err(PubKeyError::XCoordinateBytes) - ); - } - - #[test] - fn compressed_from_hex() { - assert_eq!(PubKey::from_hex( - "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" - ).expect("failed to decode pub key").into_address(), - CompressedPubKey::from_hex( - "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee390901" - ).expect("failed to decode compressed pub key").into_address() - ); - - assert_eq!(PubKey::from_hex( - "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" - ).expect("failed to decode pub key").into_compressed(), - CompressedPubKey::from_hex( - "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee390901" - ).expect("failed to decode compressed pub key") - ); - - assert_ne!(PubKey::from_hex( - "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" - ).expect("failed to decode pub key").into_compressed(), - CompressedPubKey::from_hex( // Invalid parity bit - "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee390900" - ).expect("failed to decode compressed pub key") - ); - - assert_eq!(PubKey::from_hex( - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29125030f6cd465cccf2ebbaabc397a6823932bc55dc1ddf9954c9e78da07c0928" - ).expect("failed to decode pub key").into_compressed(), - CompressedPubKey::from_hex( - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2900" - ).expect("failed to decode compressed pub key") - ); - - assert_eq!( - CompressedPubKey::from_hex( - // Missing parity bit - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29" - ), - Err(PubKeyError::YCoordinateParityBytes) - ); - - assert_eq!( - CompressedPubKey::from_hex( - // Wrong parity bytes - 
"0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a290101" - ), - Err(PubKeyError::YCoordinateParityBytes) - ); - - assert_eq!( - CompressedPubKey::from_hex( - // Invalid parity byte - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2902" - ), - Err(PubKeyError::YCoordinateParity) - ); - - assert!(CompressedPubKey::from_hex( - // OK parity byte (odd) - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2900" - ) - .is_ok()); - - assert!(CompressedPubKey::from_hex( - // OK parity byte (odd) - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2901" - ) - .is_ok()); - - assert_ne!(PubKey::from_hex( - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29125030f6cd465cccf2ebbaabc397a6823932bc55dc1ddf9954c9e78da07c0928" - ).expect("failed to decode pub key").into_compressed(), - CompressedPubKey::from_hex( - "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2901" - ).expect("failed to decode compressed pub key") - ); - } - - #[test] - fn compressed_to_bytes() { - let mut bytes = vec![ - 68, 16, 4, 133, 212, 102, 164, 201, 244, 129, 212, 59, 233, 166, 212, 169, 165, 233, - 122, 218, 193, 151, 119, 177, 75, 107, 122, 129, 237, 238, 57, 9, 1, - ]; - assert_eq!( - CompressedPubKey::from_bytes(&bytes) - .expect("failed to decode pub key") - .to_bytes(), - bytes - ); - - bytes[4] = 73; // negative test: invalid x - assert_eq!( - CompressedPubKey::from_bytes(&bytes), - Err(PubKeyError::XCoordinate) - ); - - bytes[0] = 212; - let mut bytes = [bytes, vec![255u8]].concat(); // negative test: to many bytes - assert_eq!( - CompressedPubKey::from_bytes(&bytes), - Err(PubKeyError::YCoordinateParityBytes) - ); - - bytes.remove(0); // negative test: to few bytes - bytes.remove(0); - assert_eq!( - CompressedPubKey::from_bytes(&bytes), - Err(PubKeyError::XCoordinateBytes) - ); - } -} diff --git a/signer/src/seckey.rs b/signer/src/seckey.rs index d71d2d5939..b2979dd54e 100644 --- a/signer/src/seckey.rs +++ b/signer/src/seckey.rs @@ -151,101 +151,3 @@ impl SecKey { bs58::encode(raw).into_string() } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn from_hex() { - assert_eq!( - SecKey::to_hex( - &SecKey::from_hex( - "3d12f41e24f105366b609aa23a4ef28cbae919239177275ea27bd0cabd1debd1" - ) - .expect("failed to decode sec key"), - ), - "3d12f41e24f105366b609aa23a4ef28cbae919239177275ea27bd0cabd1debd1" - ); - - assert_eq!( - SecKey::to_hex( - &SecKey::from_hex( - "285f2e2a534a9ff25971875538ea346038974ef137a069a4892f50e60910f7d8" - ) - .expect("failed to decode sec key"), - ), - "285f2e2a534a9ff25971875538ea346038974ef137a069a4892f50e60910f7d8" - ); - - assert_eq!( - SecKey::from_hex("d8f71009e6502f89a469a037f14e97386034ea3855877159f29f4a532a2e5f28"), - Err(SecKeyError::SecretKeyBytes) - ); - - assert_eq!( - SecKey::from_hex("d8f71009g6502f89a469a037f14e97386034ea3855877159f29f4a532a2e5f28"), - Err(SecKeyError::SecretKeyHex) - ); - } - - #[test] - fn to_bytes() { - let bytes = [ - 40, 95, 46, 42, 83, 74, 159, 242, 89, 113, 135, 85, 56, 234, 52, 96, 56, 151, 78, 241, - 55, 160, 105, 164, 137, 47, 80, 230, 9, 16, 247, 216, - ]; - assert_eq!( - SecKey::from_bytes(&bytes) - .expect("failed to decode sec key") - .to_bytes(), - bytes - ); - - // negative test (too many bytes) - assert_eq!( - SecKey::from_bytes(&[ - 40, 95, 46, 42, 83, 74, 159, 242, 89, 113, 135, 85, 56, 234, 52, 96, 56, 151, 78, - 241, 55, 160, 105, 164, 137, 47, 80, 230, 9, 16, 247, 216, 10 - ]), - Err(SecKeyError::SecretKeyBytes) - ); - - // negative test (too few bytes) 
- assert_eq!( - SecKey::from_bytes(&[ - 40, 95, 46, 42, 83, 74, 159, 242, 89, 113, 135, 85, 56, 234, 52, 96, 56, 151, 78, - 241, 55, 160, 105, 164, 137, 47, 80, 230, 9, 16, 247 - ]), - Err(SecKeyError::SecretKeyBytes) - ); - } - - #[test] - fn base58() { - assert_eq!( - SecKey::from_base58("EKFS3M4Fe1VkVjPMn2jauXa1Vv6w6gRES5oLbH3vZmP26uQESodY") - .expect("failed to decode sec key") - .to_base58(), - "EKFS3M4Fe1VkVjPMn2jauXa1Vv6w6gRES5oLbH3vZmP26uQESodY" - ); - - // invalid checksum - assert_eq!( - SecKey::from_base58("EKFS3M4Fe1VkVjPMn2jauXa1Vv6w6gRES5oLbH3vZmP26uQESodZ"), - Err(SecKeyError::SecretKeyChecksum) - ); - - // invalid version - assert_eq!( - SecKey::from_base58("ETq4cWR9pAQtUFQ8L78UhhfVkMJaw6gxbXRU9jQ24F8jPEh7tn3q"), - Err(SecKeyError::SecretKeyVersion) - ); - - // invalid length - assert_eq!( - SecKey::from_base58("EKFS3M4Fe1VkVjPMn2a1Vv6w6gRES5oLbH3vZmP26uQESodY"), - Err(SecKeyError::SecretKeyLength) - ); - } -} diff --git a/signer/tests/keypair.rs b/signer/tests/keypair.rs new file mode 100644 index 0000000000..b864459463 --- /dev/null +++ b/signer/tests/keypair.rs @@ -0,0 +1,95 @@ +use mina_signer::{keypair::KeypairError, seckey::SecKeyError, Keypair}; + +#[test] +fn from_hex() { + assert_eq!( + Keypair::from_hex(""), + Err(KeypairError::SecretKey(SecKeyError::SecretKeyBytes)) + ); + assert_eq!( + Keypair::from_hex("1428fadcf0c02396e620f14f176fddb5d769b7de2027469d027a80142ef8f07"), + Err(KeypairError::SecretKey(SecKeyError::SecretKeyHex)) + ); + assert_eq!( + Keypair::from_hex("0f5314f176fddb5d769b7de2027469d027ad428fadcf0c02396e6280142efb7d8"), + Err(KeypairError::SecretKey(SecKeyError::SecretKeyHex)) + ); + assert_eq!( + Keypair::from_hex("g64244176fddb5d769b7de2027469d027ad428fadcf0c02396e6280142efb7d8"), + Err(KeypairError::SecretKey(SecKeyError::SecretKeyHex)) + ); + assert_eq!( + Keypair::from_hex("4244176fddb5d769b7de2027469d027ad428fadcc0c02396e6280142efb718"), + Err(KeypairError::SecretKey(SecKeyError::SecretKeyBytes)) + ); + + Keypair::from_hex("164244176fddb5d769b7de2027469d027ad428fadcc0c02396e6280142efb718") + .expect("failed to decode keypair secret key"); +} + +#[test] +fn get_address() { + macro_rules! 
assert_get_address_eq { + ($sec_key_hex:expr, $target_address:expr) => { + let kp = Keypair::from_hex($sec_key_hex).expect("failed to create keypair"); + assert_eq!(kp.get_address(), $target_address); + }; + } + + assert_get_address_eq!( + "164244176fddb5d769b7de2027469d027ad428fadcc0c02396e6280142efb718", + "B62qnzbXmRNo9q32n4SNu2mpB8e7FYYLH8NmaX6oFCBYjjQ8SbD7uzV" + ); + assert_get_address_eq!( + "3ca187a58f09da346844964310c7e0dd948a9105702b716f4d732e042e0c172e", + "B62qicipYxyEHu7QjUqS7QvBipTs5CzgkYZZZkPoKVYBu6tnDUcE9Zt" + ); + assert_get_address_eq!( + "336eb4a19b3d8905824b0f2254fb495573be302c17582748bf7e101965aa4774", + "B62qrKG4Z8hnzZqp1AL8WsQhQYah3quN1qUj3SyfJA8Lw135qWWg1mi" + ); + assert_get_address_eq!( + "1dee867358d4000f1dafa5978341fb515f89eeddbe450bd57df091f1e63d4444", + "B62qoqiAgERjCjXhofXiD7cMLJSKD8hE8ZtMh4jX5MPNgKB4CFxxm1N" + ); + assert_get_address_eq!( + "20f84123a26e58dd32b0ea3c80381f35cd01bc22a20346cc65b0a67ae48532ba", + "B62qkiT4kgCawkSEF84ga5kP9QnhmTJEYzcfgGuk6okAJtSBfVcjm1M" + ); + assert_get_address_eq!( + "3414fc16e86e6ac272fda03cf8dcb4d7d47af91b4b726494dab43bf773ce1779", + "B62qoG5Yk4iVxpyczUrBNpwtx2xunhL48dydN53A2VjoRwF8NUTbVr4" + ); +} + +#[test] +fn to_bytes() { + let bytes = [ + 61, 18, 244, 30, 36, 241, 5, 54, 107, 96, 154, 162, 58, 78, 242, 140, 186, 233, 25, 35, + 145, 119, 39, 94, 162, 123, 208, 202, 189, 29, 235, 209, + ]; + assert_eq!( + Keypair::from_bytes(&bytes) + .expect("failed to decode keypair") + .to_bytes(), + bytes + ); + + // negative test (too many bytes) + assert_eq!( + Keypair::from_bytes(&[ + 61, 18, 244, 30, 36, 241, 5, 54, 107, 96, 154, 162, 58, 78, 242, 140, 186, 233, 25, 35, + 145, 119, 39, 94, 162, 123, 208, 202, 189, 29, 235, 209, 0 + ]), + Err(KeypairError::SecretKey(SecKeyError::SecretKeyBytes)) + ); + + // negative test (too few bytes) + assert_eq!( + Keypair::from_bytes(&[ + 61, 18, 244, 30, 36, 241, 5, 54, 107, 96, 154, 162, 58, 78, 242, 140, 186, 233, 25, 35, + 145, 119, 39, 94, 162, 123, 208, 202, 189, 29 + ]), + Err(KeypairError::SecretKey(SecKeyError::SecretKeyBytes)) + ); +} diff --git a/signer/tests/pubkey.rs b/signer/tests/pubkey.rs new file mode 100644 index 0000000000..eec3d0ca87 --- /dev/null +++ b/signer/tests/pubkey.rs @@ -0,0 +1,228 @@ +use mina_signer::{pubkey::PubKeyError, CompressedPubKey, PubKey, SecKey}; + +#[test] +fn from_hex() { + assert_eq!( + PubKey::to_hex( + &PubKey::from_hex( + "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" + ) + .expect("failed to decode pub key"), + ), + "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" + ); + + assert_eq!( + PubKey::to_hex( + &PubKey::from_hex( + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29125030f6cd465cccf2ebbaabc397a6823932bc55dc1ddf9954c9e78da07c0928" + ) + .expect("failed to decode pub key"), + ), + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29125030f6cd465cccf2ebbaabc397a6823932bc55dc1ddf9954c9e78da07c0928" + ); + + assert_eq!( + PubKey::from_hex("44100485d466a4c9f281d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c"), + Err(PubKeyError::XCoordinate) + ); + + assert_eq!( + PubKey::from_hex("44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168715883c"), + Err(PubKeyError::NonCurvePoint) + ); + + assert_eq!( + 
PubKey::from_hex("z8f71009a6502f89a469a037f14e97386034ea3855877159f29f4a532a2e5f28"), + Err(PubKeyError::Hex) + ); + + assert_eq!( + PubKey::from_hex("44100485d466a4c9f481d43be9a6d4aa5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c"), + Err(PubKeyError::Hex) + ); +} + +#[test] +fn from_secret_key() { + assert_eq!(PubKey::from_secret_key( + SecKey::from_hex("090dd91a2505081a158782c5a24fef63f326749b383423c29827e465f7ca262b").expect("failed to decode sec key") + ).expect("failed to decode pub key").to_hex(), + "3f8c32817851b7c1ad99495463ef5e99c3b3240524f0df3ff7fc41181d849e0086fa821d54de15c523a840c5f62df90aeabb1097b85c6a88e163e9d74e505803" + ); + + assert_eq!(PubKey::from_secret_key( + SecKey::from_hex("086d78e0e5deb62daeef8e3a5574d52a3d3bff5281b4dd49140564c7d80468c9").expect("failed to decode sec key") + ).expect("failed to decode pub key").to_hex(), + "666c450f5e888d3b2341d77b32cb6d0cd4912829ea9c41030d1fd2baff6b9a30c267208638544299e8d369e80b25a24bdd07383b6ea908028d9a406b528d4a01" + ); + + assert_eq!(PubKey::from_secret_key( + SecKey::from_hex("0859771e9394e96dd6d01d57ef074dc25313e63bd331fa5478a9fed9e24855a0").expect("failed to decode sec key") + ).expect("failed to decode pub key").to_hex(), + "6ed0776ab11e3dd3b637cce03a90529e518220132f1a61dd9c0d50aa998abf1d2f43c0f1eb73888ef6f7dac4d7094d3c92cd67abab39b828c5f10aff0b6a0002" + ); +} + +#[test] +fn from_address() { + macro_rules! assert_from_address_check { + ($address:expr) => { + let pk = PubKey::from_address($address).expect("failed to create pubkey"); + assert_eq!(pk.into_address(), $address); + }; + } + + assert_from_address_check!("B62qnzbXmRNo9q32n4SNu2mpB8e7FYYLH8NmaX6oFCBYjjQ8SbD7uzV"); + assert_from_address_check!("B62qicipYxyEHu7QjUqS7QvBipTs5CzgkYZZZkPoKVYBu6tnDUcE9Zt"); + assert_from_address_check!("B62qoG5Yk4iVxpyczUrBNpwtx2xunhL48dydN53A2VjoRwF8NUTbVr4"); + assert_from_address_check!("B62qrKG4Z8hnzZqp1AL8WsQhQYah3quN1qUj3SyfJA8Lw135qWWg1mi"); + assert_from_address_check!("B62qoqiAgERjCjXhofXiD7cMLJSKD8hE8ZtMh4jX5MPNgKB4CFxxm1N"); + assert_from_address_check!("B62qkiT4kgCawkSEF84ga5kP9QnhmTJEYzcfgGuk6okAJtSBfVcjm1M"); +} + +#[test] +fn to_bytes() { + let mut bytes = vec![ + 68, 16, 4, 133, 212, 102, 164, 201, 244, 129, 212, 59, 233, 166, 212, 169, 165, 233, 122, + 218, 193, 151, 119, 177, 75, 107, 122, 129, 237, 238, 57, 9, 49, 121, 244, 241, 40, 151, + 121, 124, 254, 120, 191, 47, 214, 50, 31, 54, 176, 11, 208, 89, 45, 239, 191, 57, 161, 153, + 180, 22, 135, 53, 136, 60, + ]; + assert_eq!( + PubKey::from_bytes(&bytes) + .expect("failed to decode pub key") + .to_bytes(), + bytes + ); + + bytes[0] = 0; // negative test: invalid curve point + assert_eq!(PubKey::from_bytes(&bytes), Err(PubKeyError::NonCurvePoint)); + + bytes[0] = 68; + let mut bytes = [bytes, vec![255u8, 102u8]].concat(); // negative test: to many bytes + assert_eq!( + PubKey::from_bytes(&bytes), + Err(PubKeyError::YCoordinateBytes) + ); + + bytes.remove(0); // negative test: to few bytes + bytes.remove(0); + assert_eq!( + PubKey::from_bytes(&bytes), + Err(PubKeyError::XCoordinateBytes) + ); +} + +#[test] +fn compressed_from_hex() { + assert_eq!(PubKey::from_hex( + "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" + ).expect("failed to decode pub key").into_address(), + CompressedPubKey::from_hex( + "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee390901" + ).expect("failed to decode compressed pub key").into_address() + ); + + 
assert_eq!(PubKey::from_hex( + "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" + ).expect("failed to decode pub key").into_compressed(), + CompressedPubKey::from_hex( + "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee390901" + ).expect("failed to decode compressed pub key") + ); + + assert_ne!(PubKey::from_hex( + "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee39093179f4f12897797cfe78bf2fd6321f36b00bd0592defbf39a199b4168735883c" + ).expect("failed to decode pub key").into_compressed(), + CompressedPubKey::from_hex( // Invalid parity bit + "44100485d466a4c9f481d43be9a6d4a9a5e97adac19777b14b6b7a81edee390900" + ).expect("failed to decode compressed pub key") + ); + + assert_eq!(PubKey::from_hex( + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29125030f6cd465cccf2ebbaabc397a6823932bc55dc1ddf9954c9e78da07c0928" + ).expect("failed to decode pub key").into_compressed(), + CompressedPubKey::from_hex( + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2900" + ).expect("failed to decode compressed pub key") + ); + + assert_eq!( + CompressedPubKey::from_hex( + // Missing parity bit + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29" + ), + Err(PubKeyError::YCoordinateParityBytes) + ); + + assert_eq!( + CompressedPubKey::from_hex( + // Wrong parity bytes + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a290101" + ), + Err(PubKeyError::YCoordinateParityBytes) + ); + + assert_eq!( + CompressedPubKey::from_hex( + // Invalid parity byte + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2902" + ), + Err(PubKeyError::YCoordinateParity) + ); + + assert!(CompressedPubKey::from_hex( + // OK parity byte (odd) + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2900" + ) + .is_ok()); + + assert!(CompressedPubKey::from_hex( + // OK parity byte (odd) + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2901" + ) + .is_ok()); + + assert_ne!(PubKey::from_hex( + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a29125030f6cd465cccf2ebbaabc397a6823932bc55dc1ddf9954c9e78da07c0928" + ).expect("failed to decode pub key").into_compressed(), + CompressedPubKey::from_hex( + "0c18d735252f4401eb845d08a87d780e9dc6ed053d0e071395277f0a34d45a2901" + ).expect("failed to decode compressed pub key") + ); +} + +#[test] +fn compressed_to_bytes() { + let mut bytes = vec![ + 68, 16, 4, 133, 212, 102, 164, 201, 244, 129, 212, 59, 233, 166, 212, 169, 165, 233, 122, + 218, 193, 151, 119, 177, 75, 107, 122, 129, 237, 238, 57, 9, 1, + ]; + assert_eq!( + CompressedPubKey::from_bytes(&bytes) + .expect("failed to decode pub key") + .to_bytes(), + bytes + ); + + bytes[4] = 73; // negative test: invalid x + assert_eq!( + CompressedPubKey::from_bytes(&bytes), + Err(PubKeyError::XCoordinate) + ); + + bytes[0] = 212; + let mut bytes = [bytes, vec![255u8]].concat(); // negative test: to many bytes + assert_eq!( + CompressedPubKey::from_bytes(&bytes), + Err(PubKeyError::YCoordinateParityBytes) + ); + + bytes.remove(0); // negative test: to few bytes + bytes.remove(0); + assert_eq!( + CompressedPubKey::from_bytes(&bytes), + Err(PubKeyError::XCoordinateBytes) + ); +} diff --git a/signer/tests/seckey.rs b/signer/tests/seckey.rs new file mode 100644 index 0000000000..4b161b8822 --- /dev/null +++ b/signer/tests/seckey.rs @@ -0,0 +1,90 @@ +use mina_signer::{seckey::SecKeyError, SecKey}; + +#[test] +fn from_hex() { + assert_eq!( + 
SecKey::to_hex( + &SecKey::from_hex("3d12f41e24f105366b609aa23a4ef28cbae919239177275ea27bd0cabd1debd1") + .expect("failed to decode sec key"), + ), + "3d12f41e24f105366b609aa23a4ef28cbae919239177275ea27bd0cabd1debd1" + ); + + assert_eq!( + SecKey::to_hex( + &SecKey::from_hex("285f2e2a534a9ff25971875538ea346038974ef137a069a4892f50e60910f7d8") + .expect("failed to decode sec key"), + ), + "285f2e2a534a9ff25971875538ea346038974ef137a069a4892f50e60910f7d8" + ); + + assert_eq!( + SecKey::from_hex("d8f71009e6502f89a469a037f14e97386034ea3855877159f29f4a532a2e5f28"), + Err(SecKeyError::SecretKeyBytes) + ); + + assert_eq!( + SecKey::from_hex("d8f71009g6502f89a469a037f14e97386034ea3855877159f29f4a532a2e5f28"), + Err(SecKeyError::SecretKeyHex) + ); +} + +#[test] +fn to_bytes() { + let bytes = [ + 40, 95, 46, 42, 83, 74, 159, 242, 89, 113, 135, 85, 56, 234, 52, 96, 56, 151, 78, 241, 55, + 160, 105, 164, 137, 47, 80, 230, 9, 16, 247, 216, + ]; + assert_eq!( + SecKey::from_bytes(&bytes) + .expect("failed to decode sec key") + .to_bytes(), + bytes + ); + + // negative test (too many bytes) + assert_eq!( + SecKey::from_bytes(&[ + 40, 95, 46, 42, 83, 74, 159, 242, 89, 113, 135, 85, 56, 234, 52, 96, 56, 151, 78, 241, + 55, 160, 105, 164, 137, 47, 80, 230, 9, 16, 247, 216, 10 + ]), + Err(SecKeyError::SecretKeyBytes) + ); + + // negative test (too few bytes) + assert_eq!( + SecKey::from_bytes(&[ + 40, 95, 46, 42, 83, 74, 159, 242, 89, 113, 135, 85, 56, 234, 52, 96, 56, 151, 78, 241, + 55, 160, 105, 164, 137, 47, 80, 230, 9, 16, 247 + ]), + Err(SecKeyError::SecretKeyBytes) + ); +} + +#[test] +fn base58() { + assert_eq!( + SecKey::from_base58("EKFS3M4Fe1VkVjPMn2jauXa1Vv6w6gRES5oLbH3vZmP26uQESodY") + .expect("failed to decode sec key") + .to_base58(), + "EKFS3M4Fe1VkVjPMn2jauXa1Vv6w6gRES5oLbH3vZmP26uQESodY" + ); + + // invalid checksum + assert_eq!( + SecKey::from_base58("EKFS3M4Fe1VkVjPMn2jauXa1Vv6w6gRES5oLbH3vZmP26uQESodZ"), + Err(SecKeyError::SecretKeyChecksum) + ); + + // invalid version + assert_eq!( + SecKey::from_base58("ETq4cWR9pAQtUFQ8L78UhhfVkMJaw6gxbXRU9jQ24F8jPEh7tn3q"), + Err(SecKeyError::SecretKeyVersion) + ); + + // invalid length + assert_eq!( + SecKey::from_base58("EKFS3M4Fe1VkVjPMn2a1Vv6w6gRES5oLbH3vZmP26uQESodY"), + Err(SecKeyError::SecretKeyLength) + ); +} diff --git a/srs/test_bn254.srs b/srs/test_bn254.srs new file mode 100644 index 0000000000..543839ae9b Binary files /dev/null and b/srs/test_bn254.srs differ diff --git a/srs/test_pallas.srs b/srs/test_pallas.srs new file mode 100644 index 0000000000..c16a644540 Binary files /dev/null and b/srs/test_pallas.srs differ diff --git a/srs/test_vesta.srs b/srs/test_vesta.srs new file mode 100644 index 0000000000..8ed35e5b17 Binary files /dev/null and b/srs/test_vesta.srs differ diff --git a/tools/kimchi-visu/src/lib.rs b/tools/kimchi-visu/src/lib.rs index 390a6924b5..7a5d770172 100644 --- a/tools/kimchi-visu/src/lib.rs +++ b/tools/kimchi-visu/src/lib.rs @@ -13,7 +13,7 @@ use kimchi::{ curve::KimchiCurve, prover_index::ProverIndex, }; -use poly_commitment::{commitment::CommitmentCurve, evaluation_proof::OpeningProof}; +use poly_commitment::{commitment::CommitmentCurve, ipa::OpeningProof}; use serde::Serialize; use std::{ collections::HashMap, diff --git a/turshi/src/helper.rs b/turshi/src/helper.rs index 10b72bcad7..f397a70af7 100644 --- a/turshi/src/helper.rs +++ b/turshi/src/helper.rs @@ -1,4 +1,4 @@ -//! This module inlcudes some field helpers that are useful for Cairo +//! 
This module includes some field helpers that are useful for Cairo

 use ark_ff::Field;
 use o1_utils::FieldHelpers;
@@ -46,37 +46,3 @@ impl<F: Field> CairoFieldHelpers for F {
         hex::encode(bytes)
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use ark_ec::AffineRepr;
-    use mina_curves::pasta::Pallas as CurvePoint;
-    use o1_utils::FieldHelpers;
-
-    /// Base field element type
-    pub type BaseField = <CurvePoint as AffineRepr>::BaseField;
-
-    #[test]
-    fn test_field_to_bits() {
-        let fe = BaseField::from(256u32);
-        let bits = fe.to_bits();
-        println!("{:?}", &bits[0..16]);
-    }
-
-    #[test]
-    fn test_field_to_chunks() {
-        let fe = BaseField::from(0x480680017fff8000u64);
-        let chunk = fe.u16_chunk(1);
-        assert_eq!(chunk, BaseField::from(0x7fff));
-    }
-
-    #[test]
-    fn test_hex_and_u64() {
-        let fe = BaseField::from(0x480680017fff8000u64);
-        let change = BaseField::from_hex(&fe.to_hex()).unwrap();
-        assert_eq!(fe, change);
-        let word = change.to_u64();
-        assert_eq!(word, 0x480680017fff8000u64);
-    }
-}
diff --git a/turshi/src/lib.rs b/turshi/src/lib.rs
index 0439932a88..58180b749a 100644
--- a/turshi/src/lib.rs
+++ b/turshi/src/lib.rs
@@ -1,9 +1,11 @@
-#![warn(missing_docs)]
-//! This module contains the code that executes a compiled Cairo program and generates the memory.
+//! This module contains the code that executes a compiled Cairo program and
+//! generates the memory.
 //! The Cairo runner includes code to execute a bytecode compiled Cairo program,
 //! and obtain a memory instantiation after the execution. It uses some code to
-//! represent Cairo instructions and their decomposition, together with their logic
-//! which is represented as steps of computation making up the full program.
+//! represent Cairo instructions and their decomposition, together with their
+//! logic which is represented as steps of computation making up the full
+//! program.
+
 pub mod flags;
 pub mod helper;
 pub mod memory;
diff --git a/turshi/src/memory.rs b/turshi/src/memory.rs
index 3b720f9327..b9b972c877 100644
--- a/turshi/src/memory.rs
+++ b/turshi/src/memory.rs
@@ -100,40 +100,3 @@ impl<F: Field> CairoMemory<F> {
         self[addr].map(|x| x.word())
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use ark_ff::One;
-    use mina_curves::pasta::Fp as F;
-
-    #[test]
-    fn test_cairo_bytecode() {
-        // This test starts with the public memory corresponding to a simple Cairo program
-        // func main{}():
-        //    tempvar x = 10;
-        //    return()
-        // end
-        // And checks that memory writing and reading works as expected by completing
-        // the total memory of executing the program
-        let instrs = [0x480680017fff8000, 10, 0x208b7fff7fff7ffe]
-            .iter()
-            .map(|&i: &i64| F::from(i))
-            .collect();
-        let mut memory = CairoMemory::new(instrs);
-        memory.write(F::from(memory.len()), F::from(7u64));
-        memory.write(F::from(memory.len()), F::from(7u64));
-        memory.write(F::from(memory.len()), F::from(10u64));
-        println!("{memory}");
-        // Check content of an address
-        assert_eq!(
-            memory.read(F::one()).unwrap(),
-            F::from(0x480680017fff8000u64)
-        );
-        // Check that the program contained 3 words
-        assert_eq!(3, memory.get_codelen());
-        // Check we have 6 words, excluding the dummy entry
-        assert_eq!(6, memory.len() - 1);
-        memory.read(F::from(10u32));
-    }
-}
diff --git a/turshi/src/runner.rs b/turshi/src/runner.rs
index f7d2616bfa..4629c24baf 100644
--- a/turshi/src/runner.rs
+++ b/turshi/src/runner.rs
@@ -249,14 +249,14 @@ impl<F: Field> FlagBits<F> for CairoInstruction<F> {
 /// A data structure to store a current step of Cairo computation
 pub struct CairoStep<'a, F> {
     /// state of the computation
-    mem: &'a mut CairoMemory<F>,
+    pub mem: &'a mut CairoMemory<F>,
     // comment instr for efficiency
     /// current pointers
-    curr: CairoState<F>,
+    pub curr: CairoState<F>,
     /// (if any) next pointers
-    next: Option<CairoState<F>>,
+    pub next: Option<CairoState<F>>,
     /// state auxiliary variables
-    vars: CairoContext<F>,
+    pub vars: CairoContext<F>,
 }

 impl<'a, F: Field> CairoStep<'a, F> {
@@ -490,15 +490,15 @@
 /// This struct stores the needed information to run a program
 pub struct CairoProgram<'a, F> {
     /// total number of steps
-    steps: F,
+    pub steps: F,
     /// full execution memory
-    mem: &'a mut CairoMemory<F>,
+    pub mem: &'a mut CairoMemory<F>,
     /// initial computation registers
-    ini: CairoState<F>,
+    pub ini: CairoState<F>,
     /// final computation pointers
-    fin: CairoState<F>,
+    pub fin: CairoState<F>,
     /// execution trace as a vector of [CairoInstruction]
-    trace: Vec<CairoInstruction<F>>,
+    pub trace: Vec<CairoInstruction<F>>,
 }

 impl<'a, F: Field> CairoProgram<'a, F> {
@@ -574,127 +574,3 @@ impl<'a, F: Field> CairoProgram<'a, F> {
         self.fin = CairoState::new(curr.pc, curr.ap, curr.fp);
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use mina_curves::pasta::Fp as F;
-
-    #[test]
-    fn test_cairo_step() {
-        // This tests that CairoStep works for a 2 word instruction
-        //    tempvar x = 10;
-        let instrs = [0x480680017fff8000, 10, 0x208b7fff7fff7ffe]
-            .iter()
-            .map(|&i: &i64| F::from(i))
-            .collect();
-        let mut mem = CairoMemory::new(instrs);
-        // Need to know how to find out
-        // Is it final ap and/or final fp? Will write to starkware guys to learn about this
Will write to starkware guys to learn about this - mem.write(F::from(4u32), F::from(7u32)); - mem.write(F::from(5u32), F::from(7u32)); - let ptrs = CairoState::new(F::from(1u32), F::from(6u32), F::from(6u32)); - let mut step = CairoStep::new(&mut mem, ptrs); - - step.execute(); - assert_eq!(step.next.unwrap().pc, F::from(3u32)); - assert_eq!(step.next.unwrap().ap, F::from(7u32)); - assert_eq!(step.next.unwrap().fp, F::from(6u32)); - - println!("{}", step.mem); - } - - #[test] - fn test_cairo_program() { - let instrs = [0x480680017fff8000, 10, 0x208b7fff7fff7ffe] - .iter() - .map(|&i: &i64| F::from(i)) - .collect(); - let mut mem = CairoMemory::::new(instrs); - // Need to know how to find out - // Is it final ap and/or final fp? Will write to starkware guys to learn about this - mem.write(F::from(4u32), F::from(7u32)); //beginning of output - mem.write(F::from(5u32), F::from(7u32)); //end of output - let prog = CairoProgram::new(&mut mem, 1); - println!("{}", prog.mem); - } - - #[test] - fn test_cairo_output() { - // This is a test for a longer program, involving builtins, imports and outputs - // One can generate more tests here: https://www.cairo-lang.org/playground/ - /* - %builtins output - from starkware.cairo.common.serialize import serialize_word - func main{output_ptr : felt*}(): - tempvar x = 10 - tempvar y = x + x - tempvar z = y * y + x - serialize_word(x) - serialize_word(y) - serialize_word(z) - return () - end - */ - let instrs = [ - 0x400380007ffc7ffd, - 0x482680017ffc8000, - 1, - 0x208b7fff7fff7ffe, - 0x480680017fff8000, - 10, - 0x48307fff7fff8000, - 0x48507fff7fff8000, - 0x48307ffd7fff8000, - 0x480a7ffd7fff8000, - 0x48127ffb7fff8000, - 0x1104800180018000, - -11, - 0x48127ff87fff8000, - 0x1104800180018000, - -14, - 0x48127ff67fff8000, - 0x1104800180018000, - -17, - 0x208b7fff7fff7ffe, - /*41, // beginning of outputs - 44, // end of outputs - 44, // input - */ - ] - .iter() - .map(|&i: &i64| F::from(i)) - .collect(); - - let mut mem = CairoMemory::::new(instrs); - // Need to know how to find out - mem.write(F::from(21u32), F::from(41u32)); // beginning of outputs - mem.write(F::from(22u32), F::from(44u32)); // end of outputs - mem.write(F::from(23u32), F::from(44u32)); //end of program - let prog = CairoProgram::new(&mut mem, 5); - assert_eq!(prog.fin().pc, F::from(20u32)); - assert_eq!(prog.fin().ap, F::from(41u32)); - assert_eq!(prog.fin().fp, F::from(24u32)); - println!("{}", prog.mem); - assert_eq!(prog.mem.read(F::from(24u32)).unwrap(), F::from(10u32)); - assert_eq!(prog.mem.read(F::from(25u32)).unwrap(), F::from(20u32)); - assert_eq!(prog.mem.read(F::from(26u32)).unwrap(), F::from(400u32)); - assert_eq!(prog.mem.read(F::from(27u32)).unwrap(), F::from(410u32)); - assert_eq!(prog.mem.read(F::from(28u32)).unwrap(), F::from(41u32)); - assert_eq!(prog.mem.read(F::from(29u32)).unwrap(), F::from(10u32)); - assert_eq!(prog.mem.read(F::from(30u32)).unwrap(), F::from(24u32)); - assert_eq!(prog.mem.read(F::from(31u32)).unwrap(), F::from(14u32)); - assert_eq!(prog.mem.read(F::from(32u32)).unwrap(), F::from(42u32)); - assert_eq!(prog.mem.read(F::from(33u32)).unwrap(), F::from(20u32)); - assert_eq!(prog.mem.read(F::from(34u32)).unwrap(), F::from(24u32)); - assert_eq!(prog.mem.read(F::from(35u32)).unwrap(), F::from(17u32)); - assert_eq!(prog.mem.read(F::from(36u32)).unwrap(), F::from(43u32)); - assert_eq!(prog.mem.read(F::from(37u32)).unwrap(), F::from(410u32)); - assert_eq!(prog.mem.read(F::from(38u32)).unwrap(), F::from(24u32)); - assert_eq!(prog.mem.read(F::from(39u32)).unwrap(), 
F::from(20u32)); - assert_eq!(prog.mem.read(F::from(40u32)).unwrap(), F::from(44u32)); - assert_eq!(prog.mem.read(F::from(41u32)).unwrap(), F::from(10u32)); - assert_eq!(prog.mem.read(F::from(42u32)).unwrap(), F::from(20u32)); - assert_eq!(prog.mem.read(F::from(43u32)).unwrap(), F::from(410u32)); - } -} diff --git a/turshi/src/word.rs b/turshi/src/word.rs index 714798f263..e7cd5c1d55 100644 --- a/turshi/src/word.rs +++ b/turshi/src/word.rs @@ -2,18 +2,21 @@ //! modulus 0x800000000000011000000000000000000000000000000000000000000000001 //! This is the hexadecimal value for 2 ^ 251 + 17 * 2 ^ 192 + 1 //! Our Pallas curves have 255 bits, so Cairo native instructions will fit. -//! This means that our Cairo implementation can admit a larger domain for immediate values than theirs. +//! This means that our Cairo implementation can admit a larger domain for +//! immediate values than theirs. use crate::{flags::*, helper::CairoFieldHelpers}; use ark_ff::Field; use o1_utils::field_helpers::FieldHelpers; -/// A Cairo word for the runner. Some words are instructions (which fit inside a `u64`). Others are immediate values (any `F` element). +/// A Cairo word for the runner. Some words are instructions (which fit inside a +/// `u64`). Others are immediate values (any `F` element). #[derive(Clone, Copy)] pub struct CairoWord(F); -/// Returns an offset of 16 bits to its biased representation in the interval `[-2^15,2^15)` as a field element -fn bias(offset: F) -> F { +/// Returns an offset of 16 bits to its biased representation in the interval +/// `[-2^15,2^15)` as a field element +pub fn bias(offset: F) -> F { offset - F::from(2u16.pow(15u32)) // -2^15 + sum_(i=0..15) b_i * 2^i } @@ -34,7 +37,8 @@ impl CairoWord { } } -/// This trait contains methods to obtain the offset decomposition of a [CairoWord] +/// This trait contains methods to obtain the offset decomposition of a +/// [CairoWord] pub trait Offsets { /// Returns the destination offset in biased representation fn off_dst(&self) -> F; @@ -46,7 +50,8 @@ pub trait Offsets { fn off_op1(&self) -> F; } -/// This trait contains methods that decompose a field element into [CairoWord] flagbits +/// This trait contains methods that decompose a field element into [CairoWord] +/// flagbits pub trait FlagBits { /// Returns bit-flag for destination register as `F` fn f_dst_fp(&self) -> F; @@ -75,13 +80,16 @@ pub trait FlagBits { /// Returns bit-flag for program counter update being relative jump as `F` fn f_pc_rel(&self) -> F; - /// Returns bit-flag for program counter update being conditional jump as `F` + /// Returns bit-flag for program counter update being conditional jump as + /// `F` fn f_pc_jnz(&self) -> F; - /// Returns bit-flag for allocation counter update being a manual addition as `F` + /// Returns bit-flag for allocation counter update being a manual addition + /// as `F` fn f_ap_add(&self) -> F; - /// Returns bit-flag for allocation counter update being a self increment as `F` + /// Returns bit-flag for allocation counter update being a self increment as + /// `F` fn f_ap_one(&self) -> F; /// Returns bit-flag for operation being a call as `F` @@ -97,7 +105,8 @@ pub trait FlagBits { fn f15(&self) -> F; } -/// This trait contains methods that decompose a field element into [CairoWord] flagsets +/// This trait contains methods that decompose a field element into [CairoWord] +/// flagsets pub trait FlagSets { /// Returns flagset for destination register fn dst_reg(&self) -> u8; @@ -240,67 +249,3 @@ impl FlagSets for CairoWord { 2 * (2 * 
self.f_opc_aeq().lsb() + self.f_opc_ret().lsb()) + self.f_opc_call().lsb() } } - -#[cfg(test)] -mod tests { - use crate::{ - flags::*, - word::{FlagBits, FlagSets, Offsets}, - }; - use ark_ff::{One, Zero}; - use mina_curves::pasta::Fp as F; - - #[test] - fn test_biased() { - assert_eq!(F::one(), super::bias(F::from(0x8001))); - assert_eq!(F::zero(), super::bias(F::from(0x8000))); - assert_eq!(-F::one(), super::bias(F::from(0x7fff))); - } - - #[test] - fn test_cairo_word() { - // Tests the structure of a Cairo word corresponding to the Cairo instruction: tempvar x = val - // This unit test checks offsets computation, flagbits and flagsets. - let word = super::CairoWord::new(F::from(0x480680017fff8000u64)); - - assert_eq!(word.off_dst(), F::zero()); - assert_eq!(word.off_op0(), -F::one()); - assert_eq!(word.off_op1(), F::one()); - - assert_eq!(word.f_dst_fp(), F::zero()); - assert_eq!(word.f_op0_fp(), F::one()); - assert_eq!(word.f_op1_val(), F::one()); - assert_eq!(word.f_op1_fp(), F::zero()); - assert_eq!(word.f_op1_ap(), F::zero()); - assert_eq!(word.f_res_add(), F::zero()); - assert_eq!(word.f_res_mul(), F::zero()); - assert_eq!(word.f_pc_abs(), F::zero()); - assert_eq!(word.f_pc_rel(), F::zero()); - assert_eq!(word.f_pc_jnz(), F::zero()); - assert_eq!(word.f_ap_add(), F::zero()); - assert_eq!(word.f_ap_one(), F::one()); - assert_eq!(word.f_opc_call(), F::zero()); - assert_eq!(word.f_opc_ret(), F::zero()); - assert_eq!(word.f_opc_aeq(), F::one()); - assert_eq!(word.f15(), F::zero()); - - assert_eq!(word.dst_reg(), DST_AP); - assert_eq!(word.op0_reg(), 1 - OP0_AP); - assert_eq!(word.op1_src(), OP1_VAL); - assert_eq!(word.res_log(), RES_ONE); - assert_eq!(word.pc_up(), PC_SIZ); - assert_eq!(word.ap_up(), AP_ONE); - assert_eq!(word.opcode(), OPC_AEQ); - - assert_eq!( - 0x4806, - u32::from(word.dst_reg()) - + 2 * u32::from(word.op0_reg()) - + 2u32.pow(2) * u32::from(word.op1_src()) - + 2u32.pow(5) * u32::from(word.res_log()) - + 2u32.pow(7) * u32::from(word.pc_up()) - + 2u32.pow(10) * u32::from(word.ap_up()) - + 2u32.pow(12) * u32::from(word.opcode()) - ); - } -} diff --git a/turshi/tests/helper.rs b/turshi/tests/helper.rs new file mode 100644 index 0000000000..6192673c3e --- /dev/null +++ b/turshi/tests/helper.rs @@ -0,0 +1,30 @@ +use ark_ec::AffineRepr; +use mina_curves::pasta::Pallas as CurvePoint; +use o1_utils::FieldHelpers; +use turshi::helper::CairoFieldHelpers; + +/// Base field element type +pub type BaseField = ::BaseField; + +#[test] +fn test_field_to_bits() { + let fe = BaseField::from(256u32); + let bits = fe.to_bits(); + println!("{:?}", &bits[0..16]); +} + +#[test] +fn test_field_to_chunks() { + let fe = BaseField::from(0x480680017fff8000u64); + let chunk = fe.u16_chunk(1); + assert_eq!(chunk, BaseField::from(0x7fff)); +} + +#[test] +fn test_hex_and_u64() { + let fe = BaseField::from(0x480680017fff8000u64); + let change = BaseField::from_hex(&fe.to_hex()).unwrap(); + assert_eq!(fe, change); + let word = change.to_u64(); + assert_eq!(word, 0x480680017fff8000u64); +} diff --git a/turshi/tests/memory.rs b/turshi/tests/memory.rs new file mode 100644 index 0000000000..3868f9d2e2 --- /dev/null +++ b/turshi/tests/memory.rs @@ -0,0 +1,33 @@ +use ark_ff::One; +use mina_curves::pasta::Fp as F; +use turshi::CairoMemory; + +#[test] +fn test_cairo_bytecode() { + // This test starts with the public memory corresponding to a simple Cairo program + // func main{}(): + // tempvar x = 10; + // return() + // end + // And checks that memory writing and reading works as expected by completing + // 
the total memory of executing the program + let instrs = [0x480680017fff8000, 10, 0x208b7fff7fff7ffe] + .iter() + .map(|&i: &i64| F::from(i)) + .collect(); + let mut memory = CairoMemory::new(instrs); + memory.write(F::from(memory.len()), F::from(7u64)); + memory.write(F::from(memory.len()), F::from(7u64)); + memory.write(F::from(memory.len()), F::from(10u64)); + println!("{memory}"); + // Check content of an address + assert_eq!( + memory.read(F::one()).unwrap(), + F::from(0x480680017fff8000u64) + ); + // Check that the program contained 3 words + assert_eq!(3, memory.get_codelen()); + // Check we have 6 words, excluding the dummy entry + assert_eq!(6, memory.len() - 1); + memory.read(F::from(10u32)); +} diff --git a/turshi/tests/runner.rs b/turshi/tests/runner.rs new file mode 100644 index 0000000000..3be3dca2b0 --- /dev/null +++ b/turshi/tests/runner.rs @@ -0,0 +1,123 @@ +use mina_curves::pasta::Fp as F; +use turshi::{ + runner::{CairoState, CairoStep}, + CairoMemory, CairoProgram, Pointers, +}; + +#[test] +fn test_cairo_step() { + // This tests that CairoStep works for a 2 word instruction + // tempvar x = 10; + let instrs = [0x480680017fff8000, 10, 0x208b7fff7fff7ffe] + .iter() + .map(|&i: &i64| F::from(i)) + .collect(); + let mut mem = CairoMemory::new(instrs); + // Need to know how to find out + // Is it final ap and/or final fp? Will write to starkware guys to learn about this + mem.write(F::from(4u32), F::from(7u32)); + mem.write(F::from(5u32), F::from(7u32)); + let ptrs = CairoState::new(F::from(1u32), F::from(6u32), F::from(6u32)); + let mut step = CairoStep::new(&mut mem, ptrs); + + step.execute(); + assert_eq!(step.next.unwrap().pc(), F::from(3u32)); + assert_eq!(step.next.unwrap().ap(), F::from(7u32)); + assert_eq!(step.next.unwrap().fp(), F::from(6u32)); + + println!("{}", step.mem); +} + +#[test] +fn test_cairo_program() { + let instrs = [0x480680017fff8000, 10, 0x208b7fff7fff7ffe] + .iter() + .map(|&i: &i64| F::from(i)) + .collect(); + let mut mem = CairoMemory::::new(instrs); + // Need to know how to find out + // Is it final ap and/or final fp? 
Will write to starkware guys to learn about this + mem.write(F::from(4u32), F::from(7u32)); //beginning of output + mem.write(F::from(5u32), F::from(7u32)); //end of output + let prog = CairoProgram::new(&mut mem, 1); + println!("{}", prog.mem); +} + +#[test] +fn test_cairo_output() { + // This is a test for a longer program, involving builtins, imports and outputs + // One can generate more tests here: https://www.cairo-lang.org/playground/ + /* + %builtins output + from starkware.cairo.common.serialize import serialize_word + func main{output_ptr : felt*}(): + tempvar x = 10 + tempvar y = x + x + tempvar z = y * y + x + serialize_word(x) + serialize_word(y) + serialize_word(z) + return () + end + */ + let instrs = [ + 0x400380007ffc7ffd, + 0x482680017ffc8000, + 1, + 0x208b7fff7fff7ffe, + 0x480680017fff8000, + 10, + 0x48307fff7fff8000, + 0x48507fff7fff8000, + 0x48307ffd7fff8000, + 0x480a7ffd7fff8000, + 0x48127ffb7fff8000, + 0x1104800180018000, + -11, + 0x48127ff87fff8000, + 0x1104800180018000, + -14, + 0x48127ff67fff8000, + 0x1104800180018000, + -17, + 0x208b7fff7fff7ffe, + /*41, // beginning of outputs + 44, // end of outputs + 44, // input + */ + ] + .iter() + .map(|&i: &i64| F::from(i)) + .collect(); + + let mut mem = CairoMemory::::new(instrs); + // Need to know how to find out + mem.write(F::from(21u32), F::from(41u32)); // beginning of outputs + mem.write(F::from(22u32), F::from(44u32)); // end of outputs + mem.write(F::from(23u32), F::from(44u32)); //end of program + let prog = CairoProgram::new(&mut mem, 5); + assert_eq!(prog.fin().pc(), F::from(20u32)); + assert_eq!(prog.fin().ap(), F::from(41u32)); + assert_eq!(prog.fin().fp(), F::from(24u32)); + println!("{}", prog.mem); + assert_eq!(prog.mem.read(F::from(24u32)).unwrap(), F::from(10u32)); + assert_eq!(prog.mem.read(F::from(25u32)).unwrap(), F::from(20u32)); + assert_eq!(prog.mem.read(F::from(26u32)).unwrap(), F::from(400u32)); + assert_eq!(prog.mem.read(F::from(27u32)).unwrap(), F::from(410u32)); + assert_eq!(prog.mem.read(F::from(28u32)).unwrap(), F::from(41u32)); + assert_eq!(prog.mem.read(F::from(29u32)).unwrap(), F::from(10u32)); + assert_eq!(prog.mem.read(F::from(30u32)).unwrap(), F::from(24u32)); + assert_eq!(prog.mem.read(F::from(31u32)).unwrap(), F::from(14u32)); + assert_eq!(prog.mem.read(F::from(32u32)).unwrap(), F::from(42u32)); + assert_eq!(prog.mem.read(F::from(33u32)).unwrap(), F::from(20u32)); + assert_eq!(prog.mem.read(F::from(34u32)).unwrap(), F::from(24u32)); + assert_eq!(prog.mem.read(F::from(35u32)).unwrap(), F::from(17u32)); + assert_eq!(prog.mem.read(F::from(36u32)).unwrap(), F::from(43u32)); + assert_eq!(prog.mem.read(F::from(37u32)).unwrap(), F::from(410u32)); + assert_eq!(prog.mem.read(F::from(38u32)).unwrap(), F::from(24u32)); + assert_eq!(prog.mem.read(F::from(39u32)).unwrap(), F::from(20u32)); + assert_eq!(prog.mem.read(F::from(40u32)).unwrap(), F::from(44u32)); + assert_eq!(prog.mem.read(F::from(41u32)).unwrap(), F::from(10u32)); + assert_eq!(prog.mem.read(F::from(42u32)).unwrap(), F::from(20u32)); + assert_eq!(prog.mem.read(F::from(43u32)).unwrap(), F::from(410u32)); +} diff --git a/turshi/tests/word.rs b/turshi/tests/word.rs new file mode 100644 index 0000000000..a984d63874 --- /dev/null +++ b/turshi/tests/word.rs @@ -0,0 +1,60 @@ +use ark_ff::{One, Zero}; +use mina_curves::pasta::Fp as F; +use turshi::{ + flags::*, + word::{bias, CairoWord, FlagBits, FlagSets, Offsets}, +}; + +#[test] +fn test_biased() { + assert_eq!(F::one(), bias(F::from(0x8001))); + assert_eq!(F::zero(), 
bias(F::from(0x8000))); + assert_eq!(-F::one(), bias(F::from(0x7fff))); +} + +#[test] +fn test_cairo_word() { + // Tests the structure of a Cairo word corresponding to the Cairo instruction: tempvar x = val + // This unit test checks offsets computation, flagbits and flagsets. + let word = CairoWord::new(F::from(0x480680017fff8000u64)); + + assert_eq!(word.off_dst(), F::zero()); + assert_eq!(word.off_op0(), -F::one()); + assert_eq!(word.off_op1(), F::one()); + + assert_eq!(word.f_dst_fp(), F::zero()); + assert_eq!(word.f_op0_fp(), F::one()); + assert_eq!(word.f_op1_val(), F::one()); + assert_eq!(word.f_op1_fp(), F::zero()); + assert_eq!(word.f_op1_ap(), F::zero()); + assert_eq!(word.f_res_add(), F::zero()); + assert_eq!(word.f_res_mul(), F::zero()); + assert_eq!(word.f_pc_abs(), F::zero()); + assert_eq!(word.f_pc_rel(), F::zero()); + assert_eq!(word.f_pc_jnz(), F::zero()); + assert_eq!(word.f_ap_add(), F::zero()); + assert_eq!(word.f_ap_one(), F::one()); + assert_eq!(word.f_opc_call(), F::zero()); + assert_eq!(word.f_opc_ret(), F::zero()); + assert_eq!(word.f_opc_aeq(), F::one()); + assert_eq!(word.f15(), F::zero()); + + assert_eq!(word.dst_reg(), DST_AP); + assert_eq!(word.op0_reg(), 1 - OP0_AP); + assert_eq!(word.op1_src(), OP1_VAL); + assert_eq!(word.res_log(), RES_ONE); + assert_eq!(word.pc_up(), PC_SIZ); + assert_eq!(word.ap_up(), AP_ONE); + assert_eq!(word.opcode(), OPC_AEQ); + + assert_eq!( + 0x4806, + u32::from(word.dst_reg()) + + 2 * u32::from(word.op0_reg()) + + 2u32.pow(2) * u32::from(word.op1_src()) + + 2u32.pow(5) * u32::from(word.res_log()) + + 2u32.pow(7) * u32::from(word.pc_up()) + + 2u32.pow(10) * u32::from(word.ap_up()) + + 2u32.pow(12) * u32::from(word.opcode()) + ); +} diff --git a/utils/Cargo.toml b/utils/Cargo.toml index 453f6201db..73252da05f 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -23,14 +23,12 @@ num-bigint.workspace = true num-integer.workspace = true num-traits.workspace = true rmp-serde.workspace = true +secp256k1.workspace = true sha2.workspace = true thiserror.workspace = true rand.workspace = true rand_core.workspace = true -mina-curves.workspace = true - [dev-dependencies] ark-ec.workspace = true -num-bigint.workspace = true -secp256k1.workspace = true +mina-curves.workspace = true diff --git a/utils/src/array.rs b/utils/src/array.rs new file mode 100644 index 0000000000..d8beeae43f --- /dev/null +++ b/utils/src/array.rs @@ -0,0 +1,189 @@ +//! This module provides different helpers in creating constant sized +//! arrays and converting them to different formats. +//! +//! Functions in this module are not necessarily optimal in terms of +//! allocations, as they tend to create intermediate vectors. For +//! better performance, either optimise this code, or use +//! (non-fixed-sized) vectors. + +// Use a generic function so that the pointer cast remains type-safe +/// Converts a vector of elements to a boxed one. Semantically +/// equivalent to `vector.into_boxed_slice().try_into().unwrap()`. +pub fn vec_to_boxed_array(vec: Vec) -> Box<[T; N]> { + vec.into_boxed_slice() + .try_into() + .unwrap_or_else(|_| panic!("vec_to_boxed_array: length mismatch, expected {}", N)) +} + +// @volhovm It could potentially be more efficient with unsafe tricks. +/// Converts a two-dimensional vector to a constant sized two-dimensional array. 
+pub fn vec_to_boxed_array2<T: Clone, const N: usize, const M: usize>( + vec: Vec<Vec<T>>, +) -> Box<[[T; N]; M]> { + let mut vec_of_slices2: Vec<[T; N]> = vec![]; + vec.into_iter().for_each(|x: Vec<T>| { + let y: Box<[T]> = x.into_boxed_slice(); + let z: Box<[T; N]> = y + .try_into() + .unwrap_or_else(|_| panic!("vec_to_boxed_array2: length mismatch inner array")); + let zz: &[[T; N]] = std::slice::from_ref(z.as_ref()); + vec_of_slices2.extend_from_slice(zz); + }); + + vec_of_slices2 + .into_boxed_slice() + .try_into() + .unwrap_or_else(|_| panic!("vec_to_boxed_array2: length mismatch outer array")) +} + +/// Converts a three-dimensional vector to a constant sized three-dimensional array. +pub fn vec_to_boxed_array3<T: Clone, const N: usize, const M: usize, const K: usize>( + vec: Vec<Vec<Vec<T>>>, +) -> Box<[[[T; N]; M]; K]> { + let mut vec_of_slices2: Vec<[[T; N]; M]> = vec![]; + vec.into_iter().for_each(|x| { + let r: Box<[[T; N]; M]> = vec_to_boxed_array2(x); + let zz: &[[[T; N]; M]] = std::slice::from_ref(r.as_ref()); + vec_of_slices2.extend_from_slice(zz); + }); + + vec_of_slices2 + .into_boxed_slice() + .try_into() + .unwrap_or_else(|_| panic!("vec_to_boxed_array3: length mismatch outer array")) +} + +/// A macro similar to `vec![$elem; $size]` which returns a boxed +/// array, allocated directly on the heap (via a vector, with reallocations). +/// +/// ```rustc +/// let _: Box<[u8; 1024]> = box_array![0; 1024]; +/// ``` +/// +/// See + +#[macro_export] +macro_rules! box_array { + ($val:expr ; $len:expr) => {{ + // Use a generic function so that the pointer cast remains type-safe + fn vec_to_boxed_array<T>(vec: Vec<T>) -> Box<[T; $len]> { + (vec.into_boxed_slice()) + .try_into() + .unwrap_or_else(|_| panic!("box_array: length mismatch")) + } + + vec_to_boxed_array(vec![$val; $len]) + }}; +} + +/// A macro similar to `vec![vec![$elem; $size1]; $size2]` which +/// returns a two-dimensional boxed array, allocated directly on the +/// heap (via a vector, with reallocations). +/// +/// ```rustc +/// let _: Box<[[u8; 1024]; 512]> = box_array2![0; 1024; 512]; +/// ``` +/// +#[macro_export] +macro_rules! box_array2 { + ($val:expr; $len1:expr; $len2:expr) => {{ + fn vec_to_boxed_array2<T: Clone, const N: usize, const M: usize>( + vec: Vec<Vec<T>>, + ) -> Box<[[T; N]; M]> { + let mut vec_of_slices2: Vec<[T; N]> = vec![]; + vec.into_iter().for_each(|x: Vec<T>| { + let y: Box<[T]> = x.into_boxed_slice(); + let z: Box<[T; N]> = y + .try_into() + .unwrap_or_else(|_| panic!("vec_to_boxed_array2: length mismatch inner array")); + let zz: &[[T; N]] = std::slice::from_ref(z.as_ref()); + vec_of_slices2.extend_from_slice(zz); + }); + + vec_of_slices2 + .into_boxed_slice() + .try_into() + .unwrap_or_else(|_| panic!("vec_to_boxed_array2: length mismatch outer array")) + } + + vec_to_boxed_array2(vec![vec![$val; $len1]; $len2]) + }}; +} + +#[cfg(test)] +mod tests { + use super::*; + + use ark_ec::AffineRepr; + use ark_ff::{UniformRand, Zero}; + use mina_curves::pasta::Pallas as CurvePoint; + + pub type BaseField = <CurvePoint as AffineRepr>::BaseField; + + #[test] + /// Tests whether initialising different arrays creates a stack + /// overflow. The usual default size of the stack is 128kB. + fn test_boxed_stack_overflow() { + // Each field element is assumed to be 256 bits, so 1.000.000 + // elements is 30MB. This often overflows the stack if created + // as an array.
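As a self-contained illustration of the conversion helper and the macros above (toy element types and sizes, using the `o1_utils::array` paths this file exports):

```rust
use o1_utils::{array::vec_to_boxed_array2, box_array};

fn main() {
    // A 2 x 3 grid allocated on the heap; the const dimensions must
    // match the shape of the input vector or the helper panics.
    let grid: Box<[[u64; 3]; 2]> =
        vec_to_boxed_array2(vec![vec![1, 2, 3], vec![4, 5, 6]]);
    assert_eq!(grid[1][2], 6);

    // The macro form avoids building the nested vector by hand.
    let zeros: Box<[u8; 16]> = box_array![0u8; 16];
    assert_eq!(zeros.iter().copied().sum::<u8>(), 0);
}
```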
+ let _boxed: Box<[[BaseField; 1000000]; 1]> = + vec_to_boxed_array2(vec![vec![BaseField::zero(); 1000000]; 1]); + let _boxed: Box<[[BaseField; 250000]; 4]> = + vec_to_boxed_array2(vec![vec![BaseField::zero(); 250000]; 4]); + let _boxed: Box<[BaseField; 1000000]> = box_array![BaseField::zero(); 1000000]; + let _boxed: Box<[[BaseField; 250000]; 4]> = box_array2![BaseField::zero(); 250000; 4]; + } + + #[test] + /// Tests whether boxed array tranformations preserve the elements. + fn test_boxed_stack_completeness() { + let mut rng = crate::tests::make_test_rng(None); + + const ARR_SIZE: usize = 100; + + let vector1: Vec = (0..ARR_SIZE).map(|i| i * i).collect(); + let boxed1: Box<[usize; ARR_SIZE]> = vec_to_boxed_array(vector1); + assert!(boxed1[0] == 0); + assert!(boxed1[ARR_SIZE - 1] == (ARR_SIZE - 1) * (ARR_SIZE - 1)); + let n_queries = ARR_SIZE; + for _ in 0..n_queries { + let index = usize::rand(&mut rng) % ARR_SIZE; + assert!(boxed1[index] == index * index); + } + + let vector2: Vec> = (0..ARR_SIZE) + .map(|i| (0..ARR_SIZE).map(|j| i * j).collect()) + .collect(); + let boxed2: Box<[[usize; ARR_SIZE]; ARR_SIZE]> = vec_to_boxed_array2(vector2); + assert!(boxed2[0][0] == 0); + assert!(boxed2[ARR_SIZE - 1][ARR_SIZE - 1] == (ARR_SIZE - 1) * (ARR_SIZE - 1)); + let n_queries = ARR_SIZE; + for _ in 0..n_queries { + let index1 = usize::rand(&mut rng) % ARR_SIZE; + let index2 = usize::rand(&mut rng) % ARR_SIZE; + assert!(boxed2[index1][index2] == index1 * index2); + } + + let vector3: Vec>> = (0..ARR_SIZE) + .map(|i| { + (0..ARR_SIZE) + .map(|j| (0..ARR_SIZE).map(|k| i * j + k).collect()) + .collect() + }) + .collect(); + let boxed3: Box<[[[usize; ARR_SIZE]; ARR_SIZE]; ARR_SIZE]> = vec_to_boxed_array3(vector3); + assert!(boxed3[0][0][0] == 0); + assert!( + boxed3[ARR_SIZE - 1][ARR_SIZE - 1][ARR_SIZE - 1] + == (ARR_SIZE - 1) * (ARR_SIZE - 1) + (ARR_SIZE - 1) + ); + let n_queries = ARR_SIZE; + for _ in 0..n_queries { + let index1 = usize::rand(&mut rng) % ARR_SIZE; + let index2 = usize::rand(&mut rng) % ARR_SIZE; + let index3 = usize::rand(&mut rng) % ARR_SIZE; + assert!(boxed3[index1][index2][index3] == index1 * index2 + index3); + } + } +} diff --git a/utils/src/biguint_helpers.rs b/utils/src/biguint_helpers.rs index 73c5955e2a..be97df1625 100644 --- a/utils/src/biguint_helpers.rs +++ b/utils/src/biguint_helpers.rs @@ -7,6 +7,9 @@ pub trait BigUintHelpers { /// Returns the minimum number of bits required to represent a BigUint /// As opposed to BigUint::bits, this function returns 1 for the input zero fn bitlen(&self) -> usize; + + /// Creates a BigUint from an hexadecimal string in big endian + fn from_hex(s: &str) -> Self; } impl BigUintHelpers for BigUint { @@ -17,4 +20,7 @@ impl BigUintHelpers for BigUint { self.bits() as usize } } + fn from_hex(s: &str) -> Self { + BigUint::parse_bytes(s.as_bytes(), 16).unwrap() + } } diff --git a/utils/src/bitwise_operations.rs b/utils/src/bitwise_operations.rs index 00cfb8aa91..adfc3b8e0f 100644 --- a/utils/src/bitwise_operations.rs +++ b/utils/src/bitwise_operations.rs @@ -102,90 +102,3 @@ impl ToBigUint for Vec { bigvalue } } - -#[cfg(test)] -mod tests { - - use super::*; - - #[test] - fn test_xor_256bits() { - let input1: Vec = vec![ - 123, 18, 7, 249, 123, 134, 183, 124, 11, 37, 29, 2, 76, 29, 3, 1, 100, 101, 102, 103, - 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 200, 201, 202, 203, 204, - 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, - ]; - let input2: Vec = vec![ - 33, 76, 13, 224, 2, 0, 21, 96, 131, 137, 229, 200, 128, 255, 
127, 15, 1, 2, 3, 4, 5, 6, - 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 80, 81, 82, 93, 94, 95, 76, 77, 78, 69, 60, 61, - 52, 53, 54, 45, - ]; - let output: Vec = vec![ - 90, 94, 10, 25, 121, 134, 162, 28, 136, 172, 248, 202, 204, 226, 124, 14, 101, 103, - 101, 99, 109, 111, 109, 99, 101, 103, 101, 99, 125, 127, 125, 99, 152, 152, 152, 150, - 146, 146, 130, 130, 158, 148, 238, 238, 224, 224, 224, 250, - ]; - let big1 = BigUint::from_bytes_le(&input1); - let big2 = BigUint::from_bytes_le(&input2); - assert_eq!( - BigUint::bitwise_xor(&big1, &big2), - BigUint::from_bytes_le(&output) - ); - } - - #[test] - fn test_and_256bits() { - let input1: Vec = vec![ - 123, 18, 7, 249, 123, 134, 183, 124, 11, 37, 29, 2, 76, 29, 3, 1, 100, 101, 102, 103, - 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 200, 201, 202, 203, 204, - 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, - ]; - let input2: Vec = vec![ - 33, 76, 13, 224, 2, 0, 21, 96, 131, 137, 229, 200, 128, 255, 127, 15, 1, 2, 3, 4, 5, 6, - 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 80, 81, 82, 93, 94, 95, 76, 77, 78, 69, 60, 61, - 52, 53, 54, 45, - ]; - let output: Vec = vec![ - 33, 0, 5, 224, 2, 0, 21, 96, 3, 1, 5, 0, 0, 29, 3, 1, 0, 0, 2, 4, 0, 0, 2, 8, 8, 8, 10, - 12, 0, 0, 2, 16, 64, 65, 66, 73, 76, 77, 76, 77, 64, 65, 16, 17, 20, 21, 22, 5, - ]; - assert_eq!( - BigUint::bitwise_and( - &BigUint::from_bytes_le(&input1), - &BigUint::from_bytes_le(&input2), - 256, - ), - BigUint::from_bytes_le(&output) - ); - } - - #[test] - fn test_xor_all_byte() { - for byte1 in 0..256 { - for byte2 in 0..256 { - let input1 = BigUint::from(byte1 as u8); - let input2 = BigUint::from(byte2 as u8); - assert_eq!( - BigUint::bitwise_xor(&input1, &input2), - BigUint::from((byte1 ^ byte2) as u8) - ); - } - } - } - - #[test] - fn test_not_all_byte() { - for byte in 0..256 { - let input = BigUint::from(byte as u8); - let negated = BigUint::from(!byte as u8); // full 8 bits - assert_eq!(BigUint::bitwise_not(&input, Some(8)), negated); // full byte - let bits = input.bitlen(); - let min_negated = 2u32.pow(bits as u32) - 1 - byte; - // only up to needed - assert_eq!( - BigUint::bitwise_not(&input, None), - BigUint::from(min_negated) - ); - } - } -} diff --git a/utils/src/chunked_polynomial.rs b/utils/src/chunked_polynomial.rs index 45433df93b..7e1d64259f 100644 --- a/utils/src/chunked_polynomial.rs +++ b/utils/src/chunked_polynomial.rs @@ -48,35 +48,3 @@ impl ChunkedPolynomial { DensePolynomial { coeffs } } } - -#[cfg(test)] -mod tests { - use crate::ExtendedDensePolynomial; - - use super::*; - use ark_ff::One; - use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial}; - use mina_curves::pasta::Fp; - - #[test] - - fn test_chunk_poly() { - let one = Fp::one(); - let zeta = one + one; - let zeta_n = zeta.square(); - let num_chunks = 4; - let res = (one + zeta) - * (one + zeta_n + zeta_n * zeta.square() + zeta_n * zeta.square() * zeta.square()); - - // 1 + x + x^2 + x^3 + x^4 + x^5 + x^6 + x^7 = (1+x) + x^2 (1+x) + x^4 (1+x) + x^6 (1+x) - let coeffs = [one, one, one, one, one, one, one, one]; - let f = DensePolynomial::from_coefficients_slice(&coeffs); - - let eval = f - .to_chunked_polynomial(num_chunks, 2) - .linearize(zeta_n) - .evaluate(&zeta); - - assert!(eval == res); - } -} diff --git a/utils/src/dense_polynomial.rs b/utils/src/dense_polynomial.rs index 2c7859d5d4..7f489d037a 100644 --- a/utils/src/dense_polynomial.rs +++ b/utils/src/dense_polynomial.rs @@ -67,32 +67,3 @@ impl ExtendedDensePolynomial for DensePolynomial { } } } - -// -// Tests -// - 
-#[cfg(test)] -mod tests { - use super::*; - use ark_ff::One; - use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial}; - use mina_curves::pasta::Fp; - - #[test] - fn test_chunk() { - let one = Fp::one(); - let two = one + one; - let three = two + one; - let num_chunks = 4; - - // 1 + x + x^2 + x^3 + x^4 + x^5 + x^6 + x^7 - let coeffs = [one, one, one, one, one, one, one, one]; - let f = DensePolynomial::from_coefficients_slice(&coeffs); - let evals = f.to_chunked_polynomial(num_chunks, 2).evaluate_chunks(two); - assert_eq!(evals.len(), num_chunks); - for eval in evals.into_iter().take(num_chunks) { - assert!(eval == three); - } - } -} diff --git a/utils/src/field_helpers.rs b/utils/src/field_helpers.rs index c81e87e77b..00f8827d06 100644 --- a/utils/src/field_helpers.rs +++ b/utils/src/field_helpers.rs @@ -1,7 +1,7 @@ //! Useful helper methods to extend [ark_ff::Field]. use ark_ff::{BigInteger, Field, PrimeField}; -use num_bigint::{BigUint, RandBigInt}; +use num_bigint::{BigInt, BigUint, RandBigInt, ToBigInt}; use rand::rngs::StdRng; use std::ops::Neg; use thiserror::Error; @@ -106,6 +106,14 @@ pub trait FieldHelpers { BigUint::from_bytes_le(&self.to_bytes()) } + /// Serialize field element f to a (positive) BigInt directly. + fn to_bigint_positive(&self) -> BigInt + where + F: PrimeField, + { + Self::to_biguint(self).to_bigint().unwrap() + } + /// Create a new field element from this field elements bits fn bits_to_field(&self, start: usize, end: usize) -> Result; @@ -137,6 +145,7 @@ impl FieldHelpers for F { .map_err(|_| FieldHelpersError::DeserializeBytes) } + /// Creates a field element from bits (little endian) fn from_bits(bits: &[bool]) -> Result { let bytes = bits .iter() @@ -162,6 +171,7 @@ impl FieldHelpers for F { hex::encode(self.to_bytes()) } + /// Converts a field element into bit representation (little endian) fn to_bits(&self) -> Vec { self.to_bytes().iter().fold(vec![], |mut bits, byte| { let mut byte = *byte; @@ -199,164 +209,32 @@ pub fn i32_to_field + Neg>(i: i32) -> F { } } -#[cfg(test)] -mod tests { - use super::*; - - use ark_ec::AffineRepr; - use ark_ff::One; - use mina_curves::pasta::Pallas as CurvePoint; - - /// Base field element type - pub type BaseField = ::BaseField; - - #[test] - fn field_hex() { - assert_eq!( - BaseField::from_hex(""), - Err(FieldHelpersError::DeserializeBytes) - ); - assert_eq!( - BaseField::from_hex("1428fadcf0c02396e620f14f176fddb5d769b7de2027469d027a80142ef8f07"), - Err(FieldHelpersError::DecodeHex) - ); - assert_eq!( - BaseField::from_hex( - "0f5314f176fddb5d769b7de2027469d027ad428fadcf0c02396e6280142efb7d8" - ), - Err(FieldHelpersError::DecodeHex) - ); - assert_eq!( - BaseField::from_hex("g64244176fddb5d769b7de2027469d027ad428fadcf0c02396e6280142efb7d8"), - Err(FieldHelpersError::DecodeHex) - ); - assert_eq!( - BaseField::from_hex("0cdaf334e9632268a5aa959c2781fb32bf45565fe244ae42c849d3fdc7c644fd"), - Err(FieldHelpersError::DeserializeBytes) - ); - - assert!(BaseField::from_hex( - "25b89cf1a14e2de6124fea18758bf890af76fff31b7fc68713c7653c61b49d39" - ) - .is_ok()); - - let field_hex = "f2eee8d8f6e5fb182c610cae6c5393fce69dc4d900e7b4923b074e54ad00fb36"; - assert_eq!( - BaseField::to_hex( - &BaseField::from_hex(field_hex).expect("Failed to deserialize field hex") - ), - field_hex - ); - } - - #[test] - fn field_bytes() { - assert!(BaseField::from_bytes(&[ - 46, 174, 218, 228, 42, 116, 97, 213, 149, 45, 39, 185, 126, 202, 208, 104, 182, 152, - 235, 185, 78, 138, 14, 76, 69, 56, 139, 182, 19, 222, 126, 8 - ]) - .is_ok(),); - - 
assert_eq!( - BaseField::from_bytes(&[46, 174, 218, 228, 42, 116, 97, 213]), - Err(FieldHelpersError::DeserializeBytes) - ); - - assert_eq!( - BaseField::to_hex( - &BaseField::from_bytes(&[ - 46, 174, 218, 228, 42, 116, 97, 213, 149, 45, 39, 185, 126, 202, 208, 104, 182, - 152, 235, 185, 78, 138, 14, 76, 69, 56, 139, 182, 19, 222, 126, 8 - ]) - .expect("Failed to deserialize field bytes") - ), - "2eaedae42a7461d5952d27b97ecad068b698ebb94e8a0e4c45388bb613de7e08" - ); - - fn lifetime_test() -> Result<BaseField> { - let bytes = [0; 32]; - BaseField::from_bytes(&bytes) - } - assert!(lifetime_test().is_ok()); +/// `pows(d, x)` returns a vector containing the first `d` powers of the +/// field element `x` (from `1` to `x^(d-1)`). +pub fn pows<F: Field>(d: usize, x: F) -> Vec<F> { + let mut acc = F::one(); + let mut res = vec![]; + for _ in 1..=d { + res.push(acc); + acc *= x; } + res +} - - #[test] - fn field_bits() { - let fe = - BaseField::from_hex("2cc3342ad3cd516175b8f0d0189bc3bdcb7947a4cc96c7cfc8d5df10cc443832") - .expect("Failed to deserialize field hex"); - - let fe_check = - BaseField::from_bits(&fe.to_bits()).expect("Failed to deserialize field bits"); - assert_eq!(fe, fe_check); - - assert!(BaseField::from_bits( - &BaseField::from_hex( - "e9a8f3b489990ed7eddce497b7138c6a06ff802d1b58fca1997c5f2ee971cd32" - ) - .expect("Failed to deserialize field hex") - .to_bits() - ) - .is_ok()); - - assert_eq!( - BaseField::from_bits(&vec![ - true; - <BaseField as PrimeField>::MODULUS_BIT_SIZE as usize - ]), - Err(FieldHelpersError::DeserializeBytes) - ); - - assert!(BaseField::from_bits(&[false, true, false, true]).is_ok(),); - - assert_eq!( - BaseField::from_bits(&[true, false, false]).expect("Failed to deserialize field bytes"), - BaseField::one() - ); +/// Returns the product of all the field elements belonging to an iterator. +pub fn product<F: Field>(xs: impl Iterator<Item = F>) -> F { + let mut res = F::one(); + for x in xs { + res *= &x; } + res +} - - #[test] - fn field_big() { - let fe_1024 = BaseField::from(1024u32); - let big_1024: BigUint = fe_1024.into(); - assert_eq!(big_1024, BigUint::new(vec![1024])); - - assert_eq!( - BaseField::from_biguint(&big_1024).expect("Failed to deserialize big uint"), - fe_1024 - ); - - let be_zero_32bytes = vec![0x00, 0x00, 0x00, 0x00, 0x00]; - let be_zero_1byte = vec![0x00]; - let big_zero_32 = BigUint::from_bytes_be(&be_zero_32bytes); - let big_zero_1 = BigUint::from_bytes_be(&be_zero_1byte); - let field_zero = BaseField::from(0u32); - - assert_eq!( - BigUint::from_bytes_be(&field_zero.into_bigint().to_bytes_be()), - BigUint::from_bytes_be(&be_zero_32bytes) - ); - - assert_eq!( - BaseField::from_biguint(&BigUint::from_bytes_be(&be_zero_32bytes)) - .expect("Failed to convert big uint"), - field_zero - ); - - assert_eq!(big_zero_32, big_zero_1); - - assert_eq!( - BaseField::from_biguint(&big_zero_32).expect("Failed"), - BaseField::from_biguint(&big_zero_1).expect("Failed") - ); - - let bytes = [ - 46, 174, 218, 228, 42, 116, 97, 213, 149, 45, 39, 185, 126, 202, 208, 104, 182, 152, - 235, 185, 78, 138, 14, 76, 69, 56, 139, 182, 19, 222, 126, 8, - ]; - let fe = BaseField::from_bytes(&bytes).expect("failed to create field element from bytes"); - let bi = BigUint::from_bytes_le(&bytes); - assert_eq!(fe.to_biguint(), bi); - assert_eq!(bi.to_field::<BaseField>().unwrap(), fe); +/// Compute the inner product of two slices of field elements.
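A quick hand-checked example of the three helpers introduced in this hunk (`pows` and `product` above, `inner_prod` right below); the `o1_utils::field_helpers` paths assume the module's usual public visibility:

```rust
use mina_curves::pasta::Fp;
use o1_utils::field_helpers::{inner_prod, pows, product};

fn main() {
    // pows(4, 3) = [1, 3, 9, 27]: the first four powers of 3.
    let ps = pows(4, Fp::from(3u64));
    assert_eq!(ps[3], Fp::from(27u64));

    // 1 * 3 * 9 * 27 = 729.
    assert_eq!(product(ps.iter().copied()), Fp::from(729u64));

    // <(1, 3, 9, 27), (1, 1, 1, 1)> = 1 + 3 + 9 + 27 = 40.
    let ones = vec![Fp::from(1u64); 4];
    assert_eq!(inner_prod(&ps, &ones), Fp::from(40u64));
}
```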
+pub fn inner_prod(xs: &[F], ys: &[F]) -> F { + let mut res = F::zero(); + for (&x, y) in xs.iter().zip(ys) { + res += &(x * y); } + res } diff --git a/utils/src/foreign_field.rs b/utils/src/foreign_field.rs index 00fc1ac9cd..fb05f4eb0d 100644 --- a/utils/src/foreign_field.rs +++ b/utils/src/foreign_field.rs @@ -1,50 +1,27 @@ //! Describes helpers for foreign field arithmetics +//! Generic parameters are as follows: +//! - `B` is a bit length of one limb +//! - `N` is a number of limbs that is used to represent one foreign field element use crate::field_helpers::FieldHelpers; use ark_ff::{Field, PrimeField}; use num_bigint::BigUint; -use num_traits::{One, Zero}; use std::{ - array, fmt::{Debug, Formatter}, ops::{Index, IndexMut}, }; -/// Index of low limb (in 3-limb foreign elements) -pub const LO: usize = 0; -/// Index of middle limb (in 3-limb foreign elements) -pub const MI: usize = 1; -/// Index of high limb (in 3-limb foreign elements) -pub const HI: usize = 2; - -/// Limb length for foreign field elements -pub const LIMB_BITS: usize = 88; - -/// Two to the power of the limb length -pub const TWO_TO_LIMB: u128 = 2u128.pow(LIMB_BITS as u32); - -/// Number of desired limbs for foreign field elements -pub const LIMB_COUNT: usize = 3; - -/// Exponent of binary modulus (i.e. t) -pub const BINARY_MODULUS_EXP: usize = LIMB_BITS * LIMB_COUNT; - -/// Two to the power of the limb length -pub fn two_to_limb() -> BigUint { - BigUint::from(TWO_TO_LIMB) -} - /// Represents a foreign field element #[derive(Clone, PartialEq, Eq)] /// Represents a foreign field element -pub struct ForeignElement { +pub struct ForeignElement { /// limbs in little endian order pub limbs: [F; N], /// number of limbs used for the foreign field element len: usize, } -impl ForeignElement { +impl ForeignElement { /// Creates a new foreign element from an array containing N limbs pub fn new(limbs: [F; N]) -> Self { Self { limbs, len: N } @@ -61,7 +38,7 @@ impl ForeignElement { /// Initializes a new foreign element from a big unsigned integer /// Panics if the BigUint is too large to fit in the `N` limbs pub fn from_biguint(big: BigUint) -> Self { - let vec = ForeignElement::::big_to_vec(big); + let vec = ForeignElement::::big_to_vec(big); // create an array of N native elements containing the limbs // until the array is full in big endian, so most significant @@ -100,47 +77,84 @@ impl ForeignElement { /// Obtains the big integer representation of the foreign field element pub fn to_biguint(&self) -> BigUint { let mut bytes = vec![]; - // limbs are stored in little endian - for limb in self.limbs { - let crumb = &limb.to_bytes()[0..LIMB_BITS / 8]; - bytes.extend_from_slice(crumb); + if B % 8 == 0 { + // limbs are stored in little endian + for limb in self.limbs { + let crumb = &limb.to_bytes()[0..B / 8]; + bytes.extend_from_slice(crumb); + } + } else { + let mut bits: Vec = vec![]; + for limb in self.limbs { + // Only take lower B bits, as there might be more (zeroes) in the high ones. + let f_bits_lower: Vec = limb.to_bits().into_iter().take(B).collect(); + bits.extend(&f_bits_lower); + } + + let bytes_len = if (B * N) % 8 == 0 { + (B * N) / 8 + } else { + ((B * N) / 8) + 1 + }; + bytes = vec![0u8; bytes_len]; + for i in 0..bits.len() { + bytes[i / 8] |= u8::from(bits[i]) << (i % 8); + } } BigUint::from_bytes_le(&bytes) } - /// Split a foreign field element into a vector of `LIMB_BITS` bits field elements of type `F` in little-endian. 
- /// Right now it is written so that it gives `LIMB_COUNT` limbs, even if it fits in less bits. + /// Split a foreign field element into a vector of `B` (limb bitsize) bits field + /// elements of type `F` in little-endian. Right now it is written + /// so that it gives `N` (limb count) limbs, even if it fits in less bits. fn big_to_vec(fe: BigUint) -> Vec { - let bytes = fe.to_bytes_le(); - let chunks: Vec<&[u8]> = bytes.chunks(LIMB_BITS / 8).collect(); - chunks - .iter() - .map(|chunk| F::from_random_bytes(chunk).expect("failed to deserialize")) - .collect() + if B % 8 == 0 { + let bytes = fe.to_bytes_le(); + let chunks: Vec<&[u8]> = bytes.chunks(B / 8).collect(); + chunks + .iter() + .map(|chunk| F::from_random_bytes(chunk).expect("failed to deserialize")) + .collect() + } else { + // Quite inefficient + let mut bits = vec![]; // can be slice not vec, but B*N is not const? + assert!( + fe.bits() <= (B * N) as u64, + "BigUint too big to be represented in B*N elements" + ); + for i in 0..B * N { + bits.push(fe.bit(i as u64)); + } + let chunks: Vec<_> = bits.chunks(B).collect(); + chunks + .into_iter() + .map(|chunk| F::from_bits(chunk).expect("failed to deserialize")) + .collect() + } } } -impl ForeignElement { +impl ForeignElement { /// Initializes a new foreign element from an element in the native field pub fn from_field(field: F) -> Self { Self::from_biguint(field.into()) } } -impl Index for ForeignElement { +impl Index for ForeignElement { type Output = F; fn index(&self, idx: usize) -> &Self::Output { &self.limbs[idx] } } -impl IndexMut for ForeignElement { +impl IndexMut for ForeignElement { fn index_mut(&mut self, idx: usize) -> &mut Self::Output { &mut self.limbs[idx] } } -impl Debug for ForeignElement { +impl Debug for ForeignElement { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "ForeignElement(")?; for i in 0..self.len { @@ -153,387 +167,31 @@ impl Debug for ForeignElement { } } -/// Foreign field helpers -pub trait ForeignFieldHelpers { - /// 2^{LIMB_BITS} - fn two_to_limb() -> T; - - /// 2^{2 * LIMB_BITS} - fn two_to_2limb() -> T; - - /// 2^{3 * LIMB_BITS} - fn two_to_3limb() -> T; -} - -impl ForeignFieldHelpers for F { - fn two_to_limb() -> Self { - F::from(2u64).pow([LIMB_BITS as u64]) - } - - fn two_to_2limb() -> Self { - F::from(2u64).pow([2 * LIMB_BITS as u64]) - } - - fn two_to_3limb() -> Self { - F::from(2u64).pow([3 * LIMB_BITS as u64]) - } -} - -/// Foreign field helpers -pub trait BigUintForeignFieldHelpers { - /// 2 - fn two() -> Self; - - /// 2^{LIMB_SIZE} - fn two_to_limb() -> Self; - - /// 2^{2 * LIMB_SIZE} - fn two_to_2limb() -> Self; - - /// 2^t - fn binary_modulus() -> Self; - - /// 2^259 (see foreign field multiplication RFC) - fn max_foreign_field_modulus() -> Self; - - /// Convert to 3 limbs of LIMB_BITS each - fn to_limbs(&self) -> [BigUint; 3]; - - /// Convert to 2 limbs of 2 * LIMB_BITS each. The compressed term is the bottom part - fn to_compact_limbs(&self) -> [BigUint; 2]; - - /// Convert to 3 PrimeField limbs of LIMB_BITS each - fn to_field_limbs(&self) -> [F; 3]; - - /// Convert to 2 PrimeField limbs of 2 * LIMB_BITS each. The compressed term is the bottom part. 
- fn to_compact_field_limbs(&self) -> [F; 2]; - - /// Negate: 2^T - self - fn negate(&self) -> BigUint; -} - -impl BigUintForeignFieldHelpers for BigUint { - fn two() -> Self { - Self::from(2u32) - } - - fn two_to_limb() -> Self { - BigUint::two().pow(LIMB_BITS as u32) - } - - fn two_to_2limb() -> Self { - BigUint::two().pow(2 * LIMB_BITS as u32) - } - - fn binary_modulus() -> Self { - BigUint::two().pow(3 * LIMB_BITS as u32) - } - - fn max_foreign_field_modulus() -> Self { - // For simplicity and efficiency we use the approximation m = 2^259 - 1 - BigUint::two().pow(259) - BigUint::one() - } - - fn to_limbs(&self) -> [Self; 3] { - let mut limbs = biguint_to_limbs(self, LIMB_BITS); - assert!(limbs.len() <= 3); - limbs.resize(3, BigUint::zero()); - - array::from_fn(|i| limbs[i].clone()) - } - - fn to_compact_limbs(&self) -> [Self; 2] { - let mut limbs = biguint_to_limbs(self, 2 * LIMB_BITS); - assert!(limbs.len() <= 2); - limbs.resize(2, BigUint::zero()); - - array::from_fn(|i| limbs[i].clone()) - } - - fn to_field_limbs(&self) -> [F; 3] { - self.to_limbs().to_field_limbs() - } - - fn to_compact_field_limbs(&self) -> [F; 2] { - self.to_compact_limbs().to_field_limbs() - } - - fn negate(&self) -> BigUint { - assert!(*self < BigUint::binary_modulus()); - let neg_self = BigUint::binary_modulus() - self; - assert_eq!(neg_self.bits(), BINARY_MODULUS_EXP as u64); - neg_self - } -} - -/// PrimeField array BigUint helpers -pub trait FieldArrayBigUintHelpers { - /// Convert limbs from field elements to BigUint - fn to_limbs(&self) -> [BigUint; N]; - - /// Alias for to_limbs - fn to_biguints(&self) -> [BigUint; N] { - self.to_limbs() - } -} - -impl FieldArrayBigUintHelpers for [F; N] { - fn to_limbs(&self) -> [BigUint; N] { - array::from_fn(|i| self[i].to_biguint()) - } -} - -/// PrimeField array compose BigUint -pub trait FieldArrayCompose { - /// Compose field limbs into BigUint - fn compose(&self) -> BigUint; -} +/// Foreign field helpers for `B` the limb size. 
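To make the const-generic rework concrete: the old hard-coded 88-bit, 3-limb layout is now just one instantiation of `ForeignElement`. A minimal round-trip sketch, assuming the `<F, B, N>` parameter order described in the module doc at the top of this file:

```rust
use mina_curves::pasta::Fp;
use num_bigint::BigUint;
use o1_utils::ForeignElement;

fn main() {
    // 88-bit limbs, 3 limbs: the layout the removed LIMB_BITS and
    // LIMB_COUNT constants used to fix globally.
    let big = BigUint::from(0x1234_5678_9abc_def0_u64);
    let fe = ForeignElement::<Fp, 88, 3>::from_biguint(big.clone());
    // Limbs are little-endian; the round trip recovers the integer.
    assert_eq!(fe.to_biguint(), big);
}
```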
+pub trait ForeignFieldHelpers { + /// 2^{B} + fn two_to_limb() -> F; -impl FieldArrayCompose for [F; 2] { - fn compose(&self) -> BigUint { - fields_compose(self, &BigUint::two_to_2limb()) - } -} + /// 2^{2 * B} + fn two_to_2limb() -> F; -impl FieldArrayCompose for [F; 3] { - fn compose(&self) -> BigUint { - fields_compose(self, &BigUint::two_to_limb()) - } -} - -/// PrimeField array compact limbs -pub trait FieldArrayCompact { - /// Compose field limbs into BigUint - fn to_compact_limbs(&self) -> [F; 2]; -} - -impl FieldArrayCompact for [F; 3] { - fn to_compact_limbs(&self) -> [F; 2] { - [self[0] + F::two_to_limb() * self[1], self[2]] - } + /// 2^{3 * B} + fn two_to_3limb() -> F; } -/// BigUint array PrimeField helpers -pub trait BigUintArrayFieldHelpers { - /// Convert limbs from BigUint to field element - fn to_field_limbs(&self) -> [F; N]; - - /// Alias for to_field_limbs - fn to_fields(&self) -> [F; N] { - self.to_field_limbs() - } -} - -impl BigUintArrayFieldHelpers for [BigUint; N] { - fn to_field_limbs(&self) -> [F; N] { - biguints_to_fields(self) - } -} - -/// BigUint array compose helper -pub trait BigUintArrayCompose { - /// Compose limbs into BigUint - fn compose(&self) -> BigUint; -} - -impl BigUintArrayCompose<2> for [BigUint; 2] { - fn compose(&self) -> BigUint { - bigunits_compose(self, &BigUint::two_to_2limb()) - } -} - -impl BigUintArrayCompose<3> for [BigUint; 3] { - fn compose(&self) -> BigUint { - bigunits_compose(self, &BigUint::two_to_limb()) - } -} - -// Compose field limbs into BigUint value -fn fields_compose(limbs: &[F; N], base: &BigUint) -> BigUint { - limbs - .iter() - .cloned() - .enumerate() - .fold(BigUint::zero(), |x, (i, limb)| { - x + base.pow(i as u32) * limb.to_biguint() - }) -} - -// Convert array of BigUint to an array of PrimeField -fn biguints_to_fields(limbs: &[BigUint; N]) -> [F; N] { - array::from_fn(|i| { - F::from_random_bytes(&limbs[i].to_bytes_le()) - .expect("failed to convert BigUint to field element") - }) -} - -// Compose limbs into BigUint value -fn bigunits_compose(limbs: &[BigUint; N], base: &BigUint) -> BigUint { - limbs - .iter() - .cloned() - .enumerate() - .fold(BigUint::zero(), |x, (i, limb)| { - x + base.pow(i as u32) * limb - }) -} - -// Split a BigUint up into limbs of size limb_size (in little-endian order) -fn biguint_to_limbs(x: &BigUint, limb_bits: usize) -> Vec { - let bytes = x.to_bytes_le(); - let chunks: Vec<&[u8]> = bytes.chunks(limb_bits / 8).collect(); - chunks - .iter() - .map(|chunk| BigUint::from_bytes_le(chunk)) - .collect() -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::field_helpers::FieldHelpers; - use ark_ec::AffineRepr; - use ark_ff::One; - use mina_curves::pasta::Pallas as CurvePoint; - use num_bigint::RandBigInt; - - /// Base field element type - pub type BaseField = ::BaseField; - - fn secp256k1_modulus() -> BigUint { - BigUint::from_bytes_be(&secp256k1::constants::FIELD_SIZE) - } - - #[test] - fn test_big_be() { - let big = secp256k1_modulus(); - let bytes = big.to_bytes_be(); - assert_eq!( - ForeignElement::::from_be(&bytes), - ForeignElement::::from_biguint(big) - ); - } - - #[test] - fn test_to_biguint() { - let big = secp256k1_modulus(); - let bytes = big.to_bytes_be(); - let fe = ForeignElement::::from_be(&bytes); - assert_eq!(fe.to_biguint(), big); - } - - #[test] - fn test_from_biguint() { - let one = ForeignElement::::from_be(&[0x01]); - assert_eq!( - BaseField::from_biguint(&one.to_biguint()).unwrap(), - BaseField::one() - ); - - let max_big = BaseField::modulus_biguint() - 1u32; 
- let max_fe = ForeignElement::::from_biguint(max_big.clone()); - assert_eq!( - BaseField::from_biguint(&max_fe.to_biguint()).unwrap(), - BaseField::from_bytes(&max_big.to_bytes_le()).unwrap(), - ); - } - - #[test] - fn test_negate_modulus_safe1() { - secp256k1_modulus().negate(); - } - - #[test] - fn test_negate_modulus_safe2() { - BigUint::binary_modulus().sqrt().negate(); - } - - #[test] - fn test_negate_modulus_safe3() { - (BigUint::binary_modulus() / BigUint::from(2u32)).negate(); - } - - #[test] - #[should_panic] - fn test_negate_modulus_unsafe1() { - (BigUint::binary_modulus() - BigUint::one()).negate(); - } - - #[test] - #[should_panic] - fn test_negate_modulus_unsafe2() { - (BigUint::binary_modulus() + BigUint::one()).negate(); - } - - #[test] - #[should_panic] - fn test_negate_modulus_unsafe3() { - BigUint::binary_modulus().negate(); - } - - #[test] - fn check_negation() { - let rng = &mut crate::tests::make_test_rng(None); - for _ in 0..10 { - rng.gen_biguint(256).negate(); - } - } - - #[test] - fn check_good_limbs() { - let rng = &mut crate::tests::make_test_rng(None); - for _ in 0..100 { - let x = rng.gen_biguint(264); - assert_eq!(x.to_limbs().len(), 3); - assert_eq!(x.to_limbs().compose(), x); - assert_eq!(x.to_compact_limbs().len(), 2); - assert_eq!(x.to_compact_limbs().compose(), x); - assert_eq!(x.to_compact_limbs().compose(), x.to_limbs().compose()); - - assert_eq!(x.to_field_limbs::().len(), 3); - assert_eq!(x.to_field_limbs::().compose(), x); - assert_eq!(x.to_compact_field_limbs::().len(), 2); - assert_eq!(x.to_compact_field_limbs::().compose(), x); - assert_eq!( - x.to_compact_field_limbs::().compose(), - x.to_field_limbs::().compose() - ); - - assert_eq!(x.to_limbs().to_fields::(), x.to_field_limbs()); - assert_eq!(x.to_field_limbs::().to_biguints(), x.to_limbs()); - } - } - - #[test] - #[should_panic] - fn check_bad_limbs_1() { - let rng = &mut crate::tests::make_test_rng(None); - assert_ne!(rng.gen_biguint(265).to_limbs().len(), 3); - } - - #[test] - #[should_panic] - fn check_bad_limbs_2() { - let rng = &mut crate::tests::make_test_rng(None); - assert_ne!(rng.gen_biguint(265).to_compact_limbs().len(), 2); +impl ForeignFieldHelpers + for ForeignElement +{ + fn two_to_limb() -> F { + F::from(2u64).pow([B as u64]) } - #[test] - #[should_panic] - fn check_bad_limbs_3() { - let rng = &mut crate::tests::make_test_rng(None); - assert_ne!(rng.gen_biguint(265).to_field_limbs::().len(), 3); + fn two_to_2limb() -> F { + F::from(2u64).pow([2 * B as u64]) } - #[test] - #[should_panic] - fn check_bad_limbs_4() { - let rng = &mut crate::tests::make_test_rng(None); - assert_ne!( - rng.gen_biguint(265) - .to_compact_field_limbs::() - .len(), - 2 - ); + fn two_to_3limb() -> F { + // Self? + F::from(2u64).pow([3 * B as u64]) } } diff --git a/utils/src/lib.rs b/utils/src/lib.rs index c16e5e25d2..aaba775072 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -1,8 +1,8 @@ -#![deny(missing_docs)] - -//! A collection of utility functions and constants that can be reused from multiple projects +//! A collection of utility functions and constants that can be reused from +//! 
multiple projects pub mod adjacent_pairs; +pub mod array; pub mod biguint_helpers; pub mod bitwise_operations; pub mod chunked_evaluations; @@ -21,7 +21,7 @@ pub use chunked_evaluations::ChunkedEvaluations; pub use dense_polynomial::ExtendedDensePolynomial; pub use evaluations::ExtendedEvaluations; pub use field_helpers::{BigUintFieldHelpers, FieldHelpers, RandomField, Two}; -pub use foreign_field::{ForeignElement, LIMB_COUNT}; +pub use foreign_field::ForeignElement; /// Utils only for testing pub mod tests { diff --git a/utils/src/math.rs b/utils/src/math.rs index e2df4219ef..b7781fdccf 100644 --- a/utils/src/math.rs +++ b/utils/src/math.rs @@ -21,26 +21,3 @@ pub fn ceil_log2(d: usize) -> usize { pub fn div_ceil(a: usize, b: usize) -> usize { (a + b - 1) / b } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_log2() { - let tests = [ - (1, 0), - (2, 1), - (3, 2), - (9, 4), - (15, 4), - (15430, 14), - (usize::MAX, 64), - ]; - for (d, expected_res) in tests.iter() { - let res = ceil_log2(*d); - println!("ceil(log2({d})) = {res}, expected = {expected_res}"); - assert!(res == *expected_res) - } - } -} diff --git a/utils/src/serialization.rs b/utils/src/serialization.rs index 01f1e94a0b..49fb498736 100644 --- a/utils/src/serialization.rs +++ b/utils/src/serialization.rs @@ -89,6 +89,46 @@ where } } +/// Same as `SerdeAs` but using unchecked and uncompressed (de)serialization. +pub struct SerdeAsUnchecked; + +impl serde_with::SerializeAs for SerdeAsUnchecked +where + T: CanonicalSerialize, +{ + fn serialize_as(val: &T, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut bytes = vec![]; + val.serialize_uncompressed(&mut bytes) + .map_err(serde::ser::Error::custom)?; + + if serializer.is_human_readable() { + hex::serde::serialize(bytes, serializer) + } else { + Bytes::serialize_as(&bytes, serializer) + } + } +} + +impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAsUnchecked +where + T: CanonicalDeserialize, +{ + fn deserialize_as(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let bytes: Vec = if deserializer.is_human_readable() { + hex::serde::deserialize(deserializer)? + } else { + Bytes::deserialize_as(deserializer)? + }; + T::deserialize_uncompressed_unchecked(&mut &bytes[..]).map_err(serde::de::Error::custom) + } +} + /// A generic regression serialization test for serialization via /// `CanonicalSerialize` and `CanonicalDeserialize`. 
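A sketch of how the new adapter is meant to be used, modeled on the `SerdeAs` test struct this patch removes below; only the `as`-path changes, trading the curve-point validity checks for faster, uncompressed (de)serialization, so it should only be applied to data from trusted sources:

```rust
use mina_curves::pasta::Pallas;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

#[serde_as]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
struct UncheckedPoint {
    // No on-curve or subgroup check is performed when deserializing.
    #[serde_as(as = "o1_utils::serialization::SerdeAsUnchecked")]
    point: Pallas,
}
```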
pub fn test_generic_serialization_regression_canonical< @@ -172,85 +212,3 @@ pub fn test_generic_serialization_regression_serde< "Serde: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}" ); } - -#[cfg(test)] -mod tests { - - use super::{ - test_generic_serialization_regression_canonical, - test_generic_serialization_regression_serde, - }; - use ark_ec::short_weierstrass::SWCurveConfig; - use mina_curves::pasta::{Pallas, PallasParameters, Vesta, VestaParameters}; - use serde::{Deserialize, Serialize}; - use serde_with::serde_as; - - #[test] - pub fn ser_regression_canonical_bigint() { - use mina_curves::pasta::Fp; - - // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc - let samples: Vec<(Fp, Vec)> = vec![ - ( - Fp::from(5u64), - vec![ - 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - ], - ), - ( - Fp::from((1u64 << 62) + 7u64), - vec![ - 7, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - ], - ), - ( - Fp::from((1u64 << 30) * 13u64 * 7u64 * 5u64 * 3u64 + 7u64), - vec![ - 7, 0, 0, 64, 85, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - ], - ), - ( - Fp::from((1u64 << 63) + 7u64) - * Fp::from((1u64 << 63) + 13u64) - * Fp::from((1u64 << 63) + 17u64), - vec![ - 11, 6, 0, 0, 0, 0, 0, 128, 215, 0, 0, 0, 0, 0, 0, 64, 9, 0, 0, 0, 0, 0, 0, 32, - 0, 0, 0, 0, 0, 0, 0, 0, - ], - ), - ]; - - for (data_expected, buf_expected) in samples { - test_generic_serialization_regression_canonical(data_expected, buf_expected); - } - } - - #[test] - pub fn ser_regression_canonical_pasta() { - #[serde_as] - #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] - struct TestStruct { - #[serde_as(as = "crate::serialization::SerdeAs")] - pallas: Pallas, - #[serde_as(as = "crate::serialization::SerdeAs")] - vesta: Vesta, - } - - let data_expected = TestStruct { - pallas: PallasParameters::GENERATOR, - vesta: VestaParameters::GENERATOR, - }; - - // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc - let buf_expected: Vec = vec![ - 146, 196, 33, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 196, 33, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]; - - test_generic_serialization_regression_serde(data_expected, buf_expected); - } -} diff --git a/utils/tests/bitwise_operations.rs b/utils/tests/bitwise_operations.rs new file mode 100644 index 0000000000..8a37ab333f --- /dev/null +++ b/utils/tests/bitwise_operations.rs @@ -0,0 +1,83 @@ +use num_bigint::BigUint; +use o1_utils::{BigUintHelpers, BitwiseOps}; + +#[test] +fn test_xor_256bits() { + let input1: Vec = vec![ + 123, 18, 7, 249, 123, 134, 183, 124, 11, 37, 29, 2, 76, 29, 3, 1, 100, 101, 102, 103, 104, + 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 200, 201, 202, 203, 204, 205, 206, + 207, 208, 209, 210, 211, 212, 213, 214, 215, + ]; + let input2: Vec = vec![ + 33, 76, 13, 224, 2, 0, 21, 96, 131, 137, 229, 200, 128, 255, 127, 15, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 80, 81, 82, 93, 94, 95, 76, 77, 78, 69, 60, 61, 52, 53, + 54, 45, + ]; + let output: Vec = vec![ + 90, 94, 10, 25, 121, 134, 162, 28, 136, 172, 248, 202, 204, 226, 124, 14, 101, 103, 101, + 99, 109, 111, 109, 99, 101, 103, 101, 99, 125, 127, 125, 99, 152, 152, 152, 150, 146, 146, + 130, 130, 158, 148, 238, 238, 224, 224, 224, 250, + ]; + let 
diff --git a/utils/tests/bitwise_operations.rs b/utils/tests/bitwise_operations.rs
new file mode 100644
index 0000000000..8a37ab333f
--- /dev/null
+++ b/utils/tests/bitwise_operations.rs
@@ -0,0 +1,83 @@
+use num_bigint::BigUint;
+use o1_utils::{BigUintHelpers, BitwiseOps};
+
+#[test]
+fn test_xor_256bits() {
+    let input1: Vec<u8> = vec![
+        123, 18, 7, 249, 123, 134, 183, 124, 11, 37, 29, 2, 76, 29, 3, 1, 100, 101, 102, 103, 104,
+        105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 200, 201, 202, 203, 204, 205, 206,
+        207, 208, 209, 210, 211, 212, 213, 214, 215,
+    ];
+    let input2: Vec<u8> = vec![
+        33, 76, 13, 224, 2, 0, 21, 96, 131, 137, 229, 200, 128, 255, 127, 15, 1, 2, 3, 4, 5, 6, 7,
+        8, 9, 10, 11, 12, 13, 14, 15, 16, 80, 81, 82, 93, 94, 95, 76, 77, 78, 69, 60, 61, 52, 53,
+        54, 45,
+    ];
+    let output: Vec<u8> = vec![
+        90, 94, 10, 25, 121, 134, 162, 28, 136, 172, 248, 202, 204, 226, 124, 14, 101, 103, 101,
+        99, 109, 111, 109, 99, 101, 103, 101, 99, 125, 127, 125, 99, 152, 152, 152, 150, 146, 146,
+        130, 130, 158, 148, 238, 238, 224, 224, 224, 250,
+    ];
+    let big1 = BigUint::from_bytes_le(&input1);
+    let big2 = BigUint::from_bytes_le(&input2);
+    assert_eq!(
+        BigUint::bitwise_xor(&big1, &big2),
+        BigUint::from_bytes_le(&output)
+    );
+}
+
+#[test]
+fn test_and_256bits() {
+    let input1: Vec<u8> = vec![
+        123, 18, 7, 249, 123, 134, 183, 124, 11, 37, 29, 2, 76, 29, 3, 1, 100, 101, 102, 103, 104,
+        105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 200, 201, 202, 203, 204, 205, 206,
+        207, 208, 209, 210, 211, 212, 213, 214, 215,
+    ];
+    let input2: Vec<u8> = vec![
+        33, 76, 13, 224, 2, 0, 21, 96, 131, 137, 229, 200, 128, 255, 127, 15, 1, 2, 3, 4, 5, 6, 7,
+        8, 9, 10, 11, 12, 13, 14, 15, 16, 80, 81, 82, 93, 94, 95, 76, 77, 78, 69, 60, 61, 52, 53,
+        54, 45,
+    ];
+    let output: Vec<u8> = vec![
+        33, 0, 5, 224, 2, 0, 21, 96, 3, 1, 5, 0, 0, 29, 3, 1, 0, 0, 2, 4, 0, 0, 2, 8, 8, 8, 10,
+        12, 0, 0, 2, 16, 64, 65, 66, 73, 76, 77, 76, 77, 64, 65, 16, 17, 20, 21, 22, 5,
+    ];
+    assert_eq!(
+        BigUint::bitwise_and(
+            &BigUint::from_bytes_le(&input1),
+            &BigUint::from_bytes_le(&input2),
+            256,
+        ),
+        BigUint::from_bytes_le(&output)
+    );
+}
+
+#[test]
+fn test_xor_all_byte() {
+    for byte1 in 0..256 {
+        for byte2 in 0..256 {
+            let input1 = BigUint::from(byte1 as u8);
+            let input2 = BigUint::from(byte2 as u8);
+            assert_eq!(
+                BigUint::bitwise_xor(&input1, &input2),
+                BigUint::from((byte1 ^ byte2) as u8)
+            );
+        }
+    }
+}
+
+#[test]
+fn test_not_all_byte() {
+    for byte in 0..256 {
+        let input = BigUint::from(byte as u8);
+        let negated = BigUint::from(!byte as u8); // full 8 bits
+        assert_eq!(BigUint::bitwise_not(&input, Some(8)), negated); // full byte
+        let bits = input.bitlen();
+        let min_negated = 2u32.pow(bits as u32) - 1 - byte;
+        // only up to the needed bit length
+        assert_eq!(
+            BigUint::bitwise_not(&input, None),
+            BigUint::from(min_negated)
+        );
+    }
+}
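The `test_not_all_byte` case above pins down the contract of `bitwise_not`: with `Some(bits)` the complement is taken over a fixed width of `bits`, while with `None` it only spans the operand's own bit length. A small sketch of that semantics on `BigUint` — an illustration of the contract the test checks, not the crate's implementation:

```rust
use num_bigint::BigUint;

// NOT over a fixed width: for x < 2^bits, !x == (2^bits - 1) - x.
fn not_over_width(x: &BigUint, bits: usize) -> BigUint {
    let mask = (BigUint::from(1u8) << bits) - 1u8;
    mask - x
}

fn main() {
    let x = BigUint::from(5u8); // 0b101
    // Over 8 bits: !0b0000_0101 == 0b1111_1010 == 250.
    assert_eq!(not_over_width(&x, 8), BigUint::from(250u8));
    // Over the operand's own 3-bit length: 0b111 - 0b101 == 0b010.
    assert_eq!(not_over_width(&x, 3), BigUint::from(2u8));
}
```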
diff --git a/utils/tests/chunked_polynomials.rs b/utils/tests/chunked_polynomials.rs
new file mode 100644
index 0000000000..082776a241
--- /dev/null
+++ b/utils/tests/chunked_polynomials.rs
@@ -0,0 +1,26 @@
+use ark_ff::{Field, One};
+use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, Polynomial};
+use mina_curves::pasta::Fp;
+use o1_utils::ExtendedDensePolynomial;
+
+#[test]
+fn test_chunk_poly() {
+    let one = Fp::one();
+    let zeta = one + one;
+    let zeta_n = zeta.square();
+    let num_chunks = 4;
+    let res = (one + zeta)
+        * (one + zeta_n + zeta_n * zeta.square() + zeta_n * zeta.square() * zeta.square());
+
+    // 1 + x + x^2 + x^3 + x^4 + x^5 + x^6 + x^7 = (1+x) + x^2 (1+x) + x^4 (1+x) + x^6 (1+x)
+    let coeffs = [one, one, one, one, one, one, one, one];
+    let f = DensePolynomial::from_coefficients_slice(&coeffs);
+
+    let eval = f
+        .to_chunked_polynomial(num_chunks, 2)
+        .linearize(zeta_n)
+        .evaluate(&zeta);
+
+    assert!(eval == res);
+}
diff --git a/utils/tests/dense_polynomial.rs b/utils/tests/dense_polynomial.rs
new file mode 100644
index 0000000000..8e67a31656
--- /dev/null
+++ b/utils/tests/dense_polynomial.rs
@@ -0,0 +1,21 @@
+use ark_ff::One;
+use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial};
+use mina_curves::pasta::Fp;
+use o1_utils::ExtendedDensePolynomial;
+
+#[test]
+fn test_chunk() {
+    let one = Fp::one();
+    let two = one + one;
+    let three = two + one;
+    let num_chunks = 4;
+
+    // 1 + x + x^2 + x^3 + x^4 + x^5 + x^6 + x^7
+    let coeffs = [one, one, one, one, one, one, one, one];
+    let f = DensePolynomial::from_coefficients_slice(&coeffs);
+    let evals = f.to_chunked_polynomial(num_chunks, 2).evaluate_chunks(two);
+    assert_eq!(evals.len(), num_chunks);
+    for eval in evals.into_iter().take(num_chunks) {
+        assert!(eval == three);
+    }
+}
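To spell out the identity the two chunking tests above rely on: splitting a polynomial into chunks of size n writes it as a sum of shifted chunk polynomials, and `linearize(zeta_n)` recombines the chunks with powers of ζⁿ before evaluation. For the degree-7 all-ones polynomial with n = 2:

```latex
% Chunking the degree-7 all-ones polynomial with chunk size n = 2:
f(x) = \sum_{k=0}^{7} x^k = f_0(x) + x^2 f_1(x) + x^4 f_2(x) + x^6 f_3(x),
\qquad f_i(x) = 1 + x.
% linearize(\zeta^n) recombines the chunks with powers of \zeta^2, so
g(\zeta) = \sum_{i=0}^{3} \zeta^{2i} f_i(\zeta) = (1 + \zeta)(1 + \zeta^2 + \zeta^4 + \zeta^6).
```

This g(ζ) is exactly the `res` computed in `test_chunk_poly`, while `evaluate_chunks` in `test_chunk` evaluates each chunk separately, giving f_i(2) = 3 for all four chunks.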
deserialize big uint"), + fe_1024 + ); + + let be_zero_32bytes = vec![0x00, 0x00, 0x00, 0x00, 0x00]; + let be_zero_1byte = vec![0x00]; + let big_zero_32 = BigUint::from_bytes_be(&be_zero_32bytes); + let big_zero_1 = BigUint::from_bytes_be(&be_zero_1byte); + let field_zero = BaseField::from(0u32); + + assert_eq!( + BigUint::from_bytes_be(&field_zero.into_bigint().to_bytes_be()), + BigUint::from_bytes_be(&be_zero_32bytes) + ); + + assert_eq!( + BaseField::from_biguint(&BigUint::from_bytes_be(&be_zero_32bytes)) + .expect("Failed to convert big uint"), + field_zero + ); + + assert_eq!(big_zero_32, big_zero_1); + + assert_eq!( + BaseField::from_biguint(&big_zero_32).expect("Failed"), + BaseField::from_biguint(&big_zero_1).expect("Failed") + ); + + let bytes = [ + 46, 174, 218, 228, 42, 116, 97, 213, 149, 45, 39, 185, 126, 202, 208, 104, 182, 152, 235, + 185, 78, 138, 14, 76, 69, 56, 139, 182, 19, 222, 126, 8, + ]; + let fe = BaseField::from_bytes(&bytes).expect("failed to create field element from bytes"); + let bi = BigUint::from_bytes_le(&bytes); + assert_eq!(fe.to_biguint(), bi); + assert_eq!(bi.to_field::().unwrap(), fe); +} diff --git a/utils/tests/foreign_field.rs b/utils/tests/foreign_field.rs new file mode 100644 index 0000000000..5c19f11654 --- /dev/null +++ b/utils/tests/foreign_field.rs @@ -0,0 +1,73 @@ +use ark_ec::AffineRepr; +use ark_ff::One; +use mina_curves::pasta::Pallas as CurvePoint; +use num_bigint::BigUint; +use o1_utils::{field_helpers::FieldHelpers, ForeignElement}; + +/// Base field element type +pub type BaseField = ::BaseField; + +fn secp256k1_modulus() -> BigUint { + BigUint::from_bytes_be(&secp256k1::constants::FIELD_SIZE) +} + +const TEST_B_1: usize = 88; +const TEST_N_1: usize = 3; +const TEST_B_2: usize = 15; +const TEST_N_2: usize = 18; + +#[test] +fn test_big_be() { + let big = secp256k1_modulus(); + let bytes = big.to_bytes_be(); + assert_eq!( + ForeignElement::::from_be(&bytes), + ForeignElement::::from_biguint(big.clone()) + ); + assert_eq!( + ForeignElement::::from_be(&bytes), + ForeignElement::::from_biguint(big) + ); +} + +#[test] +fn test_to_biguint() { + let big = secp256k1_modulus(); + let bytes = big.to_bytes_be(); + let fe = ForeignElement::::from_be(&bytes); + assert_eq!(fe.to_biguint(), big); + let fe2 = ForeignElement::::from_be(&bytes); + assert_eq!(fe2.to_biguint(), big); +} + +#[test] +fn test_from_biguint() { + { + let one = ForeignElement::::from_be(&[0x01]); + assert_eq!( + BaseField::from_biguint(&one.to_biguint()).unwrap(), + BaseField::one() + ); + + let max_big = BaseField::modulus_biguint() - 1u32; + let max_fe = ForeignElement::::from_biguint(max_big.clone()); + assert_eq!( + BaseField::from_biguint(&max_fe.to_biguint()).unwrap(), + BaseField::from_bytes(&max_big.to_bytes_le()).unwrap(), + ); + } + { + let one = ForeignElement::::from_be(&[0x01]); + assert_eq!( + BaseField::from_biguint(&one.to_biguint()).unwrap(), + BaseField::one() + ); + + let max_big = BaseField::modulus_biguint() - 1u32; + let max_fe = ForeignElement::::from_biguint(max_big.clone()); + assert_eq!( + BaseField::from_biguint(&max_fe.to_biguint()).unwrap(), + BaseField::from_bytes(&max_big.to_bytes_le()).unwrap(), + ); + } +} diff --git a/utils/tests/math.rs b/utils/tests/math.rs new file mode 100644 index 0000000000..75b62e8b39 --- /dev/null +++ b/utils/tests/math.rs @@ -0,0 +1,23 @@ +use o1_utils::math::ceil_log2; + +#[test] +fn test_log2() { + let tests = [ + (1, 0), + (2, 1), + (3, 2), + (9, 4), + (15, 4), + (16, 4), + (17, 5), + (15430, 14), + (usize::MAX, 64), + 
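For orientation on the constants above: `ForeignElement<F, B, N>` represents a foreign-field value as N limbs of B bits each over the native field F, so both test configurations cover the 256-bit secp256k1 modulus (88 · 3 = 264 and 15 · 18 = 270 bits). A sketch of the little-endian limb decomposition this relies on — an illustration under that assumption, not the crate's code:

```rust
use num_bigint::BigUint;

// Split `x` into `n` little-endian limbs of `b` bits each, so that
// x == sum_i limbs[i] * 2^(b * i). Assumes x fits into b * n bits.
fn to_limbs(x: &BigUint, b: usize, n: usize) -> Vec<BigUint> {
    let mask = (BigUint::from(1u8) << b) - 1u8;
    (0..n).map(|i| (x >> (b * i)) & &mask).collect()
}
```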
diff --git a/utils/tests/math.rs b/utils/tests/math.rs
new file mode 100644
index 0000000000..75b62e8b39
--- /dev/null
+++ b/utils/tests/math.rs
@@ -0,0 +1,23 @@
+use o1_utils::math::ceil_log2;
+
+#[test]
+fn test_log2() {
+    let tests = [
+        (1, 0),
+        (2, 1),
+        (3, 2),
+        (9, 4),
+        (15, 4),
+        (16, 4),
+        (17, 5),
+        (15430, 14),
+        (usize::MAX, 64),
+    ];
+    for (d, expected_res) in tests.iter() {
+        let res = ceil_log2(*d);
+        assert_eq!(
+            res, *expected_res,
+            "ceil(log2({d})) = {res}, expected = {expected_res}"
+        )
+    }
+}
diff --git a/utils/tests/serialization.rs b/utils/tests/serialization.rs
new file mode 100644
index 0000000000..15fded48c8
--- /dev/null
+++ b/utils/tests/serialization.rs
@@ -0,0 +1,76 @@
+use ark_ec::short_weierstrass::SWCurveConfig;
+use mina_curves::pasta::{Pallas, PallasParameters, Vesta, VestaParameters};
+use o1_utils::serialization::{
+    test_generic_serialization_regression_canonical, test_generic_serialization_regression_serde,
+};
+use serde::{Deserialize, Serialize};
+use serde_with::serde_as;
+
+#[test]
+pub fn ser_regression_canonical_bigint() {
+    use mina_curves::pasta::Fp;
+
+    // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc
+    let samples: Vec<(Fp, Vec<u8>)> = vec![
+        (
+            Fp::from(5u64),
+            vec![
+                5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0, 0, 0,
+            ],
+        ),
+        (
+            Fp::from((1u64 << 62) + 7u64),
+            vec![
+                7, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0, 0, 0,
+            ],
+        ),
+        (
+            Fp::from((1u64 << 30) * 13u64 * 7u64 * 5u64 * 3u64 + 7u64),
+            vec![
+                7, 0, 0, 64, 85, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                0, 0, 0, 0, 0,
+            ],
+        ),
+        (
+            Fp::from((1u64 << 63) + 7u64)
+                * Fp::from((1u64 << 63) + 13u64)
+                * Fp::from((1u64 << 63) + 17u64),
+            vec![
+                11, 6, 0, 0, 0, 0, 0, 128, 215, 0, 0, 0, 0, 0, 0, 64, 9, 0, 0, 0, 0, 0, 0, 32, 0,
+                0, 0, 0, 0, 0, 0, 0,
+            ],
+        ),
+    ];
+
+    for (data_expected, buf_expected) in samples {
+        test_generic_serialization_regression_canonical(data_expected, buf_expected);
+    }
+}
+
+#[test]
+pub fn ser_regression_canonical_pasta() {
+    #[serde_as]
+    #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
+    struct TestStruct {
+        #[serde_as(as = "o1_utils::serialization::SerdeAs")]
+        pallas: Pallas,
+        #[serde_as(as = "o1_utils::serialization::SerdeAs")]
+        vesta: Vesta,
+    }
+
+    let data_expected = TestStruct {
+        pallas: PallasParameters::GENERATOR,
+        vesta: VestaParameters::GENERATOR,
+    };
+
+    // Generated with commit 1494cf973d40fb276465929eb7db1952c5de7bdc
+    let buf_expected: Vec<u8> = vec![
+        146, 196, 33, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0, 0, 0, 0, 0, 0, 0, 196, 33, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    ];
+
+    test_generic_serialization_regression_serde(data_expected, buf_expected);
+}
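Finally, the `test_log2` vectors above (note ceil(log2(9)) = 4, ceil(log2(16)) = 4, ceil(log2(17)) = 5) fix the convention that `ceil_log2` computes ⌈log₂ d⌉. One way to realize that with pure integer arithmetic — a sketch consistent with the test vectors, not necessarily the crate's implementation:

```rust
// For d > 1, ceil(log2(d)) equals the bit length of d - 1;
// for d == 1 the subtraction yields 0, whose bit length is 0.
fn ceil_log2_sketch(d: usize) -> usize {
    assert!(d > 0);
    (usize::BITS - (d - 1).leading_zeros()) as usize
}

fn main() {
    assert_eq!(ceil_log2_sketch(1), 0);
    assert_eq!(ceil_log2_sketch(9), 4);
    assert_eq!(ceil_log2_sketch(17), 5);
    assert_eq!(ceil_log2_sketch(usize::MAX), 64);
}
```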