Skip to content

Commit

Permalink
Updates msm calldata binding to support different options (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
raugfer authored Sep 11, 2024
1 parent b973b2e commit 2fb2964
Show file tree
Hide file tree
Showing 18 changed files with 537 additions and 103 deletions.
27 changes: 27 additions & 0 deletions hydra/garaga/starknet/tests_and_calldata_generators/msm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from functools import lru_cache

from garaga import garaga_rs
from garaga import modulo_circuit_structs as structs
from garaga.algebra import FunctionFelt, PyFelt
from garaga.definitions import CURVES, STARK, CurveID, G1Point, get_base_field
Expand Down Expand Up @@ -358,13 +359,39 @@ def to_cairo_1_test(self, test_name: str = None):
"""
return code

def _serialize_to_calldata_rust(
self,
include_digits_decomposition=True,
include_points_and_scalars=True,
serialize_as_pure_felt252_array=False,
risc0_mode=False,
) -> list[int]:
return garaga_rs.msm_calldata_builder(
[value for point in self.points for value in [point.x, point.y]],
self.scalars,
self.curve_id.value,
include_digits_decomposition,
include_points_and_scalars,
serialize_as_pure_felt252_array,
risc0_mode,
)

def serialize_to_calldata(
self,
include_digits_decomposition=True,
include_points_and_scalars=True,
serialize_as_pure_felt252_array=False,
risc0_mode=False,
use_rust=False,
) -> list[int]:
if use_rust:
return self._serialize_to_calldata_rust(
include_digits_decomposition,
include_points_and_scalars,
serialize_as_pure_felt252_array,
risc0_mode,
)

inputs = self._get_input_structs(risc0_mode)
option = (
structs.CairoOption.SOME
Expand Down
60 changes: 60 additions & 0 deletions tests/hydra/starknet/test_calldata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import random

import pytest

from garaga.definitions import CURVES, CurveID, G1Point
from garaga.starknet.tests_and_calldata_generators.msm import MSMCalldataBuilder

# Define the curves to be tested
curves = list(CurveID)


@pytest.mark.parametrize("curve_id", curves)
@pytest.mark.parametrize("msm_size", range(1, 2))
@pytest.mark.parametrize("include_digits_decomposition", [True, False])
@pytest.mark.parametrize("include_points_and_scalars", [True, False])
@pytest.mark.parametrize("serialize_as_pure_felt252_array", [True, False])
@pytest.mark.parametrize("risc0_mode", [True, False])
def test_msm_calldata_builder(
curve_id,
msm_size,
include_digits_decomposition,
include_points_and_scalars,
serialize_as_pure_felt252_array,
risc0_mode,
):
curve = CURVES[curve_id.value]
order = curve.n

scalar_limit = min(order, 2**128) if risc0_mode else order

points = [G1Point.gen_random_point(curve_id) for _ in range(msm_size)]
scalars = [random.randint(0, scalar_limit - 1) for _ in range(msm_size)]

msm = MSMCalldataBuilder(
points=points,
scalars=scalars,
curve_id=curve_id,
)

calldata1 = msm.serialize_to_calldata(
include_digits_decomposition=include_digits_decomposition,
include_points_and_scalars=include_points_and_scalars,
serialize_as_pure_felt252_array=serialize_as_pure_felt252_array,
risc0_mode=risc0_mode,
use_rust=False,
)

calldata2 = msm.serialize_to_calldata(
include_digits_decomposition=include_digits_decomposition,
include_points_and_scalars=include_points_and_scalars,
serialize_as_pure_felt252_array=serialize_as_pure_felt252_array,
risc0_mode=risc0_mode,
use_rust=True,
)

assert calldata1 == calldata2


if __name__ == "__main__":
pytest.main()
Loading

0 comments on commit 2fb2964

Please sign in to comment.