garaga-rs : MSM calldata function updates #189

Closed
feltroidprime opened this issue Sep 10, 2024 · 0 comments
feltroidprime commented Sep 10, 2024

The current implementation of msm_calldata_builder is:

pub fn msm_calldata_builder(
    values: &[BigUint],
    scalars: &[BigUint],
    curve_id: usize,
) -> Vec<BigInt> {
    assert_eq!(values.len(), 2 * scalars.len());
    let curve_id = CurveID::from(curve_id);
    match curve_id {
        CurveID::BN254 => handle_curve::<BN254PrimeField>(values, scalars, curve_id as usize),
        CurveID::BLS12_381 => {
            handle_curve::<BLS12381PrimeField>(values, scalars, curve_id as usize)
        }
        CurveID::SECP256K1 => {
            handle_curve::<SECP256K1PrimeField>(values, scalars, curve_id as usize)
        }
        CurveID::SECP256R1 => {
            handle_curve::<SECP256R1PrimeField>(values, scalars, curve_id as usize)
        }
        CurveID::X25519 => handle_curve::<X25519PrimeField>(values, scalars, curve_id as usize),
    }
}
This function, which should be equivalent to its Python sibling, calls the internal function calldata_builder with default values for the following boolean parameters:

    include_digits_decomposition: bool,
    include_points_and_scalars: bool,
    serialize_as_pure_felt252_array: bool,
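
To make the gap concrete, here is a minimal Python sketch of a wrapper that forwards the three flags instead of fixing them. The `CalldataOptions` container, its default values, and the stand-in `calldata_builder` body are assumptions made for illustration; they are not garaga's actual API or defaults. Only the three flag names and the length precondition come from the code above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CalldataOptions:
    """Hypothetical container for the three flags.

    The default values are placeholders, not the defaults garaga uses.
    """

    include_digits_decomposition: bool = True
    include_points_and_scalars: bool = True
    serialize_as_pure_felt252_array: bool = False


def calldata_builder(
    values,
    scalars,
    curve_id,
    include_digits_decomposition,
    include_points_and_scalars,
    serialize_as_pure_felt252_array,
):
    """Stand-in for the internal builder: it only records what it received."""
    return {
        "curve_id": curve_id,
        "msm_size": len(scalars),
        "include_digits_decomposition": include_digits_decomposition,
        "include_points_and_scalars": include_points_and_scalars,
        "serialize_as_pure_felt252_array": serialize_as_pure_felt252_array,
    }


def msm_calldata_builder(values, scalars, curve_id, options=CalldataOptions()):
    """Forward caller-chosen flags to the internal builder."""
    # Same precondition as the Rust function: two entries in `values` per
    # scalar (presumably the two coordinates of each point).
    if len(values) != 2 * len(scalars):
        raise ValueError("expected two values per scalar")
    return calldata_builder(
        values,
        scalars,
        curve_id,
        options.include_digits_decomposition,
        options.include_points_and_scalars,
        options.serialize_as_pure_felt252_array,
    )
```

A caller would then override a single flag with, for example, `msm_calldata_builder(values, scalars, curve_id, CalldataOptions(serialize_as_pure_felt252_array=True))`.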

Requested tasks:

  • Modify msm_calldata_builder to include those 3 parameters as well

  • Update the Wasm binding to use those 3 parameters

  • Add a new boolean parameter risc0_mode, as in the Python class. The main difference between risc0_mode and "normal" mode is that the scalars are u128 rather than u256. This is the simpler case: the zk-ecip hint is used once instead of three times. Compare the Cairo functions msm_g1 and msm_g1_u128:

    garaga/src/src/ec_ops.cairo

    Lines 352 to 557 in 7e16413

    fn msm_g1(
        scalars_digits_decompositions: Option<Span<(Span<felt252>, Span<felt252>)>>,
        hint: MSMHint,
        derive_point_from_x_hint: DerivePointFromXHint,
        points: Span<G1Point>,
        scalars: Span<u256>,
        curve_index: usize
    ) -> G1Point {
        let n = scalars.len();
        assert!(n == points.len(), "scalars and points length mismatch");
        if n == 0 {
            panic!("Msm size must be >= 1");
        }
        // Check result points are either on curve, or the point at infinity
        if !hint.Q_low.is_infinity() {
            hint.Q_low.assert_on_curve(curve_index);
        }
        if !hint.Q_high.is_infinity() {
            hint.Q_high.assert_on_curve(curve_index);
        }
        if !hint.Q_high_shifted.is_infinity() {
            hint.Q_high_shifted.assert_on_curve(curve_index);
        }
        // Validate the degrees of the functions field elements given the msm size
        hint.SumDlogDivLow.validate_degrees(n);
        hint.SumDlogDivHigh.validate_degrees(n);
        hint.SumDlogDivHighShifted.validate_degrees(1);
        // Hash everything to obtain a x coordinate.
        let (s0, s1, s2): (felt252, felt252, felt252) = hades_permutation(
            'MSM_G1', 0, 1
        ); // Init Sponge state
        let (s0, s1, s2) = hades_permutation(
            s0 + curve_index.into(), s1 + n.into(), s2
        ); // Include curve_index and msm size
        let (s0, s1, s2) = hint.SumDlogDivLow.update_hash_state(s0, s1, s2);
        let (s0, s1, s2) = hint.SumDlogDivHigh.update_hash_state(s0, s1, s2);
        let (s0, s1, s2) = hint.SumDlogDivHighShifted.update_hash_state(s0, s1, s2);
        let mut s0 = s0;
        let mut s1 = s1;
        let mut s2 = s2;
        // Check input points are on curve and hash them at the same time.
        for point in points {
            if !point.is_infinity() {
                point.assert_on_curve(curve_index);
            }
            let (_s0, _s1, _s2) = point.update_hash_state(s0, s1, s2);
            s0 = _s0;
            s1 = _s1;
            s2 = _s2;
        };
        // Hash result points
        let (s0, s1, s2) = hint.Q_low.update_hash_state(s0, s1, s2);
        let (s0, s1, s2) = hint.Q_high.update_hash_state(s0, s1, s2);
        let (s0, s1, s2) = hint.Q_high_shifted.update_hash_state(s0, s1, s2);
        // Hash scalars and verify they are below the curve order
        let curve_order = get_n(curve_index);
        let mut s0 = s0;
        let mut s1 = s1;
        let mut s2 = s2;
        for scalar in scalars {
            assert!(*scalar <= curve_order, "One of the scalar is larger than the curve order");
            let (_s0, _s1, _s2) = core::poseidon::hades_permutation(
                s0 + (*scalar.low).into(), s1 + (*scalar.high).into(), s2
            );
            s0 = _s0;
            s1 = _s1;
            s2 = _s2;
        };
        let random_point: G1Point = derive_ec_point_from_X(
            s0,
            derive_point_from_x_hint.y_last_attempt,
            derive_point_from_x_hint.g_rhs_sqrt,
            curve_index
        );
        // Get slope, intercept and other constant from random point
        let (mb): (SlopeInterceptOutput,) = ec::run_SLOPE_INTERCEPT_SAME_POINT_circuit(
            random_point, get_a(curve_index), curve_index
        );
        // Get positive and negative multiplicities of low and high part of scalars
        let (epns_low, epns_high) = neg_3::u256_array_to_low_high_epns(
            scalars, scalars_digits_decompositions
        );
        // Hardcoded epns for 2**128
        let epns_shifted: Array<(felt252, felt252, felt252, felt252)> = array![
            (5279154705627724249993186093248666011, 345561521626566187713367793525016877467, -1, -1)
        ];
        // Verify Q_low = sum(scalar_low * P for scalar_low,P in zip(scalars_low, points))
        zk_ecip_check(
            points, epns_low, hint.Q_low, n, mb, hint.SumDlogDivLow, random_point, curve_index
        );
        // Verify Q_high = sum(scalar_high * P for scalar_high,P in zip(scalars_high, points))
        zk_ecip_check(
            points, epns_high, hint.Q_high, n, mb, hint.SumDlogDivHigh, random_point, curve_index
        );
        // Verify Q_high_shifted = 2^128 * Q_high
        zk_ecip_check(
            array![hint.Q_high].span(),
            epns_shifted,
            hint.Q_high_shifted,
            1,
            mb,
            hint.SumDlogDivHighShifted,
            random_point,
            curve_index
        );
        // Return Q_low + Q_high_shifted = Q_low + 2^128 * Q_high = Σ(ki * Pi)
        return ec_safe_add(hint.Q_low, hint.Q_high_shifted, curve_index);
    }
    // Verifies the multi scalar multiplication of a set of points on a given curve is equal to
    // hint.Q
    // Uses https://eprint.iacr.org/2022/596.pdf eq 3 and samples a random EC point from the inputs and
    // the hint.
    fn msm_g1_u128(
        scalars_digits_decompositions: Option<Span<Span<felt252>>>,
        hint: MSMHintSmallScalar,
        derive_point_from_x_hint: DerivePointFromXHint,
        points: Span<G1Point>,
        scalars: Span<u128>,
        curve_index: usize
    ) -> G1Point {
        let n = scalars.len();
        assert!(n == points.len(), "scalars and points length mismatch");
        if n == 0 {
            panic!("Msm size must be >= 1");
        }
        // Check result points are either on curve, or the point at infinity
        if !hint.Q.is_infinity() {
            hint.Q.assert_on_curve(curve_index);
        }
        // Validate the degrees of the functions field elements given the msm size
        hint.SumDlogDiv.validate_degrees(n);
        // Hash everything to obtain a x coordinate.
        let (s0, s1, s2): (felt252, felt252, felt252) = hades_permutation(
            'MSM_G1_U128', 0, 1
        ); // Init Sponge state
        let (s0, s1, s2) = hades_permutation(
            s0 + curve_index.into(), s1 + n.into(), s2
        ); // Include curve_index and msm size
        let (s0, s1, s2) = hint.SumDlogDiv.update_hash_state(s0, s1, s2);
        let mut s0 = s0;
        let mut s1 = s1;
        let mut s2 = s2;
        // Check input points are on curve and hash them at the same time.
        for point in points {
            if !point.is_infinity() {
                point.assert_on_curve(curve_index);
            }
            let (_s0, _s1, _s2) = point.update_hash_state(s0, s1, s2);
            s0 = _s0;
            s1 = _s1;
            s2 = _s2;
        };
        // Hash result points
        let (s0, s1, s2) = hint.Q.update_hash_state(s0, s1, s2);
        // Hash scalars. No need to check if scalar is below curve order since it is always at most 128
        // bits.
        let mut s0 = s0;
        let mut s1 = s1;
        let mut s2 = s2;
        for scalar in scalars {
            let (_s0, _s1, _s2) = core::poseidon::hades_permutation(s0 + (*scalar).into(), s1, s2);
            s0 = _s0;
            s1 = _s1;
            s2 = _s2;
        };
        let random_point: G1Point = derive_ec_point_from_X(
            s0,
            derive_point_from_x_hint.y_last_attempt,
            derive_point_from_x_hint.g_rhs_sqrt,
            curve_index
        );
        // Get slope, intercept and other constant from random point
        let (mb): (SlopeInterceptOutput,) = ec::run_SLOPE_INTERCEPT_SAME_POINT_circuit(
            random_point, get_a(curve_index), curve_index
        );
        // Get positive and negative multiplicities of low and high part of scalars
        let epns = neg_3::u128_array_to_epns(scalars, scalars_digits_decompositions);
        // Verify Q = sum(scalar * P for scalar,P in zip(scalars, points))
        zk_ecip_check(points, epns, hint.Q, n, mb, hint.SumDlogDiv, random_point, curve_index);
        return hint.Q;
    }

  • Create a Python binding for this msm_calldata_builder:

    • Call it from the Python class by adding an extra flag use_rust. If set to True, the class should call a new method _serialize_to_calldata_rust(self, options) that correctly serializes the class members and calls the Rust binding.
    • Create a test in tests/hydra/starknet/calldata.py. Using pytest fixtures, the test should compare the Rust and Python implementations over all curves, MSM sizes from 1 to 2, and all 4 options (include_digits, include_points_and_scalars, serialize_as_pure_felt252_array, risc0_mode), similarly to
      @pytest.mark.parametrize("curve_id", curves)
      @pytest.mark.parametrize("msm_size", range(1, 5))
      def test_verify_ecip(curve_id, msm_size):
          curve = CURVES[curve_id.value]
          order = curve.n
          # Test for G1 points
          Bs_G1 = [G1Point.get_nG(curve_id, 1) for _ in range(msm_size)]
          scalars = [random.randint(1, order - 1) for _ in range(msm_size)]
          Q, sum_dlog = zk_ecip_hint(Bs_G1, scalars, use_rust=False)
          Q_rust, sum_dlog_rust = zk_ecip_hint(Bs_G1, scalars, use_rust=True)
          assert Q == Q_rust, f"Q: {Q}, \nQ_rust: {Q_rust}"
          assert (
              sum_dlog == sum_dlog_rust
          ), f"sum_dlog: {sum_dlog}, \nsum_dlog_rust: {sum_dlog_rust}"
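
The requested comparison test could take roughly the following shape. This is a sketch under stated assumptions: `ToyMSMCalldata`, `placeholder_backend`, and the layout they return are hypothetical stand-ins so that the example runs without garaga installed, and the "Rust" path reuses the same placeholder rather than a real binding. What it does illustrate is the use_rust dispatch, the 16 combinations of the four boolean options, and the narrower scalar range (u128) that risc0_mode requires.

```python
import itertools
import random

# Order of the secp256k1 group, used here only to bound the sampled scalars.
SECP256K1_ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

FLAG_NAMES = (
    "include_digits_decomposition",
    "include_points_and_scalars",
    "serialize_as_pure_felt252_array",
    "risc0_mode",
)


def all_flag_combinations():
    """Yield one dict per combination of the four boolean options (2**4 = 16)."""
    for values in itertools.product([False, True], repeat=len(FLAG_NAMES)):
        yield dict(zip(FLAG_NAMES, values))


def sample_scalars(msm_size, curve_order, risc0_mode, rng):
    """In risc0_mode every scalar must fit in a u128; otherwise it only has
    to be below the curve order."""
    upper = min(curve_order, 2**128) if risc0_mode else curve_order
    return [rng.randint(1, upper - 1) for _ in range(msm_size)]


def placeholder_backend(points, scalars, curve_id, options):
    """Placeholder serializer: does NOT reflect the real calldata layout."""
    flags = [int(options[name]) for name in FLAG_NAMES]
    return [curve_id, len(scalars), *flags, *points, *scalars]


class ToyMSMCalldata:
    """Hypothetical stand-in for the Python class mentioned in the issue."""

    def __init__(self, points, scalars, curve_id):
        self.points = points
        self.scalars = scalars
        self.curve_id = curve_id

    def serialize_to_calldata(self, options, use_rust=False):
        if use_rust:
            return self._serialize_to_calldata_rust(options)
        return placeholder_backend(self.points, self.scalars, self.curve_id, options)

    def _serialize_to_calldata_rust(self, options):
        # The real method would serialize the members and call the Rust
        # binding; this sketch reuses the placeholder so the example runs.
        return placeholder_backend(
            list(self.points), list(self.scalars), self.curve_id, options
        )


def compare_backends(msm_sizes=(1, 2), seed=0):
    """Run both paths for every (msm_size, options) pair; return the case count."""
    rng = random.Random(seed)
    cases = 0
    for msm_size, options in itertools.product(msm_sizes, all_flag_combinations()):
        scalars = sample_scalars(msm_size, SECP256K1_ORDER, options["risc0_mode"], rng)
        # Dummy flattened coordinates (two per scalar), not real curve points.
        points = [rng.randrange(1, 2**64) for _ in range(2 * msm_size)]
        builder = ToyMSMCalldata(points, scalars, curve_id=0)  # arbitrary curve id
        python_out = builder.serialize_to_calldata(options)
        rust_out = builder.serialize_to_calldata(options, use_rust=True)
        assert python_out == rust_out, f"mismatch: msm_size={msm_size}, options={options}"
        cases += 1
    return cases
```

In an actual pytest module, the explicit loops would more naturally become stacked `@pytest.mark.parametrize` decorators, one per option, in the style of test_verify_ecip above.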
@raugfer raugfer mentioned this issue Sep 10, 2024
This issue was closed.