Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSM Calldata Updates #192

Merged
merged 12 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions hydra/garaga/starknet/tests_and_calldata_generators/msm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from functools import lru_cache

from garaga import garaga_rs
from garaga import modulo_circuit_structs as structs
from garaga.algebra import FunctionFelt, PyFelt
from garaga.definitions import CURVES, STARK, CurveID, G1Point, get_base_field
Expand Down Expand Up @@ -358,13 +359,39 @@ def to_cairo_1_test(self, test_name: str = None):
"""
return code

def _serialize_to_calldata_rust(
self,
include_digits_decomposition=True,
include_points_and_scalars=True,
serialize_as_pure_felt252_array=False,
risc0_mode=False,
) -> list[int]:
return garaga_rs.msm_calldata_builder(
[value for point in self.points for value in [point.x, point.y]],
self.scalars,
self.curve_id.value,
include_digits_decomposition,
include_points_and_scalars,
serialize_as_pure_felt252_array,
risc0_mode,
)

def serialize_to_calldata(
self,
include_digits_decomposition=True,
include_points_and_scalars=True,
serialize_as_pure_felt252_array=False,
risc0_mode=False,
use_rust=False,
) -> list[int]:
if use_rust:
return self._serialize_to_calldata_rust(
include_digits_decomposition,
include_points_and_scalars,
serialize_as_pure_felt252_array,
risc0_mode,
)

inputs = self._get_input_structs(risc0_mode)
option = (
structs.CairoOption.SOME
Expand Down
60 changes: 60 additions & 0 deletions tests/hydra/starknet/test_calldata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import random

import pytest

from garaga.definitions import CURVES, CurveID, G1Point
from garaga.starknet.tests_and_calldata_generators.msm import MSMCalldataBuilder

# Define the curves to be tested
curves = list(CurveID)


@pytest.mark.parametrize("curve_id", curves)
@pytest.mark.parametrize("msm_size", range(1, 2))
@pytest.mark.parametrize("include_digits_decomposition", [True, False])
@pytest.mark.parametrize("include_points_and_scalars", [True, False])
@pytest.mark.parametrize("serialize_as_pure_felt252_array", [True, False])
@pytest.mark.parametrize("risc0_mode", [True, False])
def test_msm_calldata_builder(
curve_id,
msm_size,
include_digits_decomposition,
include_points_and_scalars,
serialize_as_pure_felt252_array,
risc0_mode,
):
curve = CURVES[curve_id.value]
order = curve.n

scalar_limit = min(order, 2**128) if risc0_mode else order

points = [G1Point.gen_random_point(curve_id) for _ in range(msm_size)]
scalars = [random.randint(0, scalar_limit - 1) for _ in range(msm_size)]

msm = MSMCalldataBuilder(
points=points,
scalars=scalars,
curve_id=curve_id,
)

calldata1 = msm.serialize_to_calldata(
include_digits_decomposition=include_digits_decomposition,
include_points_and_scalars=include_points_and_scalars,
serialize_as_pure_felt252_array=serialize_as_pure_felt252_array,
risc0_mode=risc0_mode,
use_rust=False,
)

calldata2 = msm.serialize_to_calldata(
include_digits_decomposition=include_digits_decomposition,
include_points_and_scalars=include_points_and_scalars,
serialize_as_pure_felt252_array=serialize_as_pure_felt252_array,
risc0_mode=risc0_mode,
use_rust=True,
)

assert calldata1 == calldata2


if __name__ == "__main__":
pytest.main()
Loading
Loading