Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use KeccakBuiltin for address recovery and signature verification #232

Merged
merged 6 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 2 additions & 34 deletions cairo/src/precompiles/ec_recover.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,14 @@ namespace PrecompileEcRecover {

let (r_bigint) = uint256_to_bigint(r);
let (s_bigint) = uint256_to_bigint(s);
let (public_key_point, success) = Signature.try_recover_public_key(
let (success, recovered_address) = Signature.try_recover_eth_address(
msg_hash_bigint, r_bigint, s_bigint, v - 27
);

if (success == 0) {
let (output) = alloc();
return (0, output, GAS_COST_EC_RECOVER, 0);
}
let (is_public_key_invalid) = EcRecoverHelpers.ec_point_equal(
public_key_point, EcPoint(BigInt3(0, 0, 0), BigInt3(0, 0, 0))
);
if (is_public_key_invalid != 0) {
let (output) = alloc();
return (0, output, GAS_COST_EC_RECOVER, 0);
}

let (recovered_address) = EcRecoverHelpers.public_key_point_to_eth_address(
public_key_point
);

let (output) = alloc();
memset(output, 0, 12);
Expand All @@ -115,26 +105,4 @@ namespace EcRecoverHelpers {
}
return (is_equal=0);
}

// @notice Convert a public key point to the corresponding Ethereum address.
// @dev Uses the `KeccakBuiltin` builtin, while the one in Starkware's CairoZero library does not.
func public_key_point_to_eth_address{
pedersen_ptr: HashBuiltin*,
range_check_ptr,
bitwise_ptr: BitwiseBuiltin*,
keccak_ptr: KeccakBuiltin*,
}(public_key_point: EcPoint) -> (eth_address: felt) {
alloc_locals;
let (local elements: Uint256*) = alloc();
let (x_uint256: Uint256) = bigint_to_uint256(public_key_point.x);
assert elements[0] = x_uint256;
let (y_uint256: Uint256) = bigint_to_uint256(public_key_point.y);
assert elements[1] = y_uint256;

let (point_hash) = keccak_uint256s_bigend(n_elements=2, elements=elements);

// The Ethereum address is the 20 least significant bytes of the keccak of the public key.
let (high_high, high_low) = unsigned_div_rem(point_hash.high, 2 ** 32);
return (eth_address=point_hash.low + RC_BOUND * high_low);
}
}
99 changes: 73 additions & 26 deletions cairo/src/utils/signature.cairo
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.cairo_secp.bigint3 import BigInt3, UnreducedBigInt3
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin, KeccakBuiltin
from starkware.cairo.common.cairo_secp.bigint3 import BigInt3
from starkware.cairo.common.cairo_secp.ec_point import EcPoint
from starkware.cairo.common.cairo_secp.signature import (
validate_signature_entry,
try_get_point_from_x,
get_generator_point,
div_mod_n,
)
from starkware.cairo.common.math_cmp import RC_BOUND
from starkware.cairo.common.cairo_secp.bigint import bigint_to_uint256, uint256_to_bigint
from starkware.cairo.common.builtin_keccak.keccak import keccak_uint256s_bigend
from starkware.cairo.common.cairo_secp.ec import ec_add, ec_mul, ec_negate
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.cairo_secp.bigint import uint256_to_bigint
from src.utils.maths import unsigned_div_rem

from src.interfaces.interfaces import ICairo1Helpers

namespace Signature {
// A version of verify_eth_signature, with that msg_hash, r and s as Uint256 and
// using the Cairo1 helpers class.
func verify_eth_signature_uint256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
msg_hash: Uint256, r: Uint256, s: Uint256, y_parity: felt, eth_address: felt
) {
// A version of verify_eth_signature that uses the keccak builtin.
func verify_eth_signature_uint256{
range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*
}(msg_hash: Uint256, r: Uint256, s: Uint256, y_parity: felt, eth_address: felt) {
alloc_locals;
let (msg_hash_bigint: BigInt3) = uint256_to_bigint(msg_hash);
let (r_bigint: BigInt3) = uint256_to_bigint(r);
Expand All @@ -35,34 +37,34 @@ namespace Signature {
}

with_attr error_message("Invalid signature.") {
let (success, recovered_address) = ICairo1Helpers.recover_eth_address(
msg_hash=msg_hash, r=r, s=s, y_parity=y_parity
let (success, recovered_address) = try_recover_eth_address(
msg_hash=msg_hash_bigint, r=r_bigint, s=s_bigint, y_parity=y_parity
);
// TODO: uncomment when we have a working recover_eth_address
// assert success = 1;
// assert eth_address = recovered_address;
assert success = 1;
}

assert eth_address = recovered_address;
return ();
}

// Similar to `recover_public_key`, but handles the case where 'x' does not correspond to a point on the
// @notice Similar to `recover_public_key`, but handles the case where 'x' does not correspond to a point on the
// curve gracefully.
// Receives a signature and the signed message hash.
// Returns the public key associated with the signer, represented as a point on the curve, and `true` if valid.
// Returns the point (0, 0) and `false` otherwise.
// Note:
// Some places use the values 27 and 28 instead of 0 and 1 for v.
// In that case, a subtraction by 27 returns a v that can be used by this function.
// Prover assumptions:
// * r is the x coordinate of some nonzero point on the curve.
// * All the limbs of s and msg_hash are in the range (-2 ** 210.99, 2 ** 210.99).
// * All the limbs of r are in the range (-2 ** 124.99, 2 ** 124.99).
// @param msg_hash The signed message hash.
// @param r The r value of the signature.
// @param s The s value of the signature.
// @param y_parity The y parity value of the signature. true if odd, false if even.
// @return The public key associated with the signer, represented as a point on the curve, and `true` if valid.
// @return The point (0, 0) and `false` otherwise.
// @dev Prover assumptions:
// @dev * r is the x coordinate of some nonzero point on the curve.
// @dev * All the limbs of s and msg_hash are in the range (-2 ** 210.99, 2 ** 210.99).
// @dev * All the limbs of r are in the range (-2 ** 124.99, 2 ** 124.99).
func try_recover_public_key{range_check_ptr}(
msg_hash: BigInt3, r: BigInt3, s: BigInt3, v: felt
msg_hash: BigInt3, r: BigInt3, s: BigInt3, y_parity: felt
) -> (public_key_point: EcPoint, success: felt) {
alloc_locals;
let (local r_point: EcPoint*) = alloc();
let (is_on_curve) = try_get_point_from_x(x=r, v=v, result=r_point);
let (is_on_curve) = try_get_point_from_x(x=r, v=y_parity, result=r_point);
if (is_on_curve == 0) {
enitrat marked this conversation as resolved.
Show resolved Hide resolved
return (public_key_point=EcPoint(x=BigInt3(0, 0, 0), y=BigInt3(0, 0, 0)), success=0);
}
Expand All @@ -84,4 +86,49 @@ namespace Signature {
let (public_key_point) = ec_add(minus_point1, point2);
return (public_key_point=public_key_point, success=1);
}

// @notice Recovers the Ethereum address from a signature.
// @dev If the public key point is not on the curve, the function returns success=0.
// @dev: This function does not validate the r, s values.
// @param msg_hash The signed message hash.
// @param r The r value of the signature.
// @param s The s value of the signature.
// @param y_parity The y parity value of the signature. true if odd, false if even.
// @return The Ethereum address.
func try_recover_eth_address{
range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*
}(msg_hash: BigInt3, r: BigInt3, s: BigInt3, y_parity: felt) -> (success: felt, address: felt) {
alloc_locals;
let (public_key_point, success) = try_recover_public_key(
msg_hash=msg_hash, r=r, s=s, y_parity=y_parity
);
if (success == 0) {
return (success=0, address=0);
}
let (x_uint256) = bigint_to_uint256(public_key_point.x);
let (y_uint256) = bigint_to_uint256(public_key_point.y);
let address = Internals.public_key_point_to_eth_address(x=x_uint256, y=y_uint256);
return (success=success, address=address);
}
}

namespace Internals {
// @notice Converts a public key point to the corresponding Ethereum address.
// @param x The x coordinate of the public key point.
// @param y The y coordinate of the public key point.
// @return The Ethereum address.
func public_key_point_to_eth_address{
range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*
}(x: Uint256, y: Uint256) -> felt {
alloc_locals;
let (local elements: Uint256*) = alloc();
assert elements[0] = x;
assert elements[1] = y;
let (point_hash: Uint256) = keccak_uint256s_bigend(n_elements=2, elements=elements);

ClementWalter marked this conversation as resolved.
Show resolved Hide resolved
// The Ethereum address is the 20 least significant bytes of the keccak of the public key.
let (_, high_low) = unsigned_div_rem(point_hash.high, 2 ** 32);
let eth_address = point_hash.low + RC_BOUND * high_low;
return eth_address;
}
}
9 changes: 6 additions & 3 deletions cairo/tests/ethereum/cancun/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,14 @@ def test_bytes_to_nibble_list(self, cairo_run, bytes_: Bytes):
# def test_root(self, cairo_run, trie, get_storage_root):
# assert root(trie, get_storage_root) == cairo_run("root", trie, get_storage_root)

@given(obj=st.dictionaries(nibble, bytes32), nibble=uint4, level=uint4)
@given(
obj=st.dictionaries(nibble, bytes32).filter(lambda x: len(x) > 0),
ClementWalter marked this conversation as resolved.
Show resolved Hide resolved
nibble=uint4,
level=uint4,
)
def test_get_branche_for_nibble_at_level(self, cairo_run, obj, nibble, level):
assume(
len(obj) > 0 # no empty objects
and min(len(k) for k in obj) > level # longer than level
min(len(k) for k in obj) > level # longer than level
and len({k[:level] for k in obj}) == 1 # same prefix at level
)
branche, value = cairo_run(
Expand Down
3 changes: 2 additions & 1 deletion cairo/tests/fixtures/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _factory(entrypoint, *args, **kwargs):
output_ptr = runner.segments.add()
stack.append(output_ptr)
else:
stack.append(gen_arg(python_type, kwargs.get(arg_name, args[i])))
arg_value = kwargs[arg_name] if arg_name in kwargs else args[i]
stack.append(gen_arg(python_type, arg_value))

return_fp = runner.execution_base + 2
end = runner.program_base + len(runner.program.data)
Expand Down
21 changes: 11 additions & 10 deletions cairo/tests/src/precompiles/test_ec_recover.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ def ecrecover(data):
s = U256.from_be_bytes(data[96:128])

if v != 27 and v != 28:
return
return []
if 0 >= r or r >= SECP256K1N:
return
return []
if 0 >= s or s >= SECP256K1N:
return
return []

try:
public_key = secp256k1_recover(r, s, v - 27, message_hash)
except ValueError:
# unable to extract public key
return
return []

address = keccak256(public_key)[12:32]
padded_address = left_pad_zero_bytes(address, 32)
Expand All @@ -40,11 +40,11 @@ class TestEcRecover:
def test_valid_signature(self, message, cairo_run):
"""Test with valid signatures generated from random messages."""
private_key = generate_random_private_key()
msg = keccak256(message)
(v, r, s) = ec_sign(msg, private_key)
msg_hash = keccak256(message)
(v, r, s) = ec_sign(msg_hash, private_key)

input_data = [
*msg,
*msg_hash,
*v.to_bytes(32, "big"),
*r,
*s,
Expand Down Expand Up @@ -75,8 +75,9 @@ def test_invalid_v(self, v, msg, r, s, cairo_run):
*r.to_bytes(32, "big"),
*s.to_bytes(32, "big"),
]
py_result = ecrecover(input_data)
[output] = cairo_run("test__ec_recover", input=input_data)
assert output == []
assert output == py_result

@given(
v=st.integers(min_value=27, max_value=28),
Expand All @@ -96,8 +97,8 @@ def test_parameter_boundaries(self, cairo_run, v, msg, r, s):
py_result = ecrecover(input_data)
[cairo_result] = cairo_run("test__ec_recover", input=input_data)

if py_result is None:
assert cairo_result == []
if len(py_result) == 0:
assert cairo_result == py_result
else:
py_address, _ = py_result
assert bytes(cairo_result) == bytes(py_address)
36 changes: 30 additions & 6 deletions cairo/tests/src/utils/test_signature.cairo
Original file line number Diff line number Diff line change
@@ -1,13 +1,37 @@
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin

from ethereum.base_types import U256
from src.utils.signature import Signature

func test__verify_eth_signature_uint256{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
msg_hash: U256, r: U256, s: U256, y_parity: felt, eth_address: felt
) {
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin, KeccakBuiltin
from starkware.cairo.common.uint256 import Uint256
from src.utils.signature import Signature, Internals
from starkware.cairo.common.cairo_secp.bigint import uint256_to_bigint

func test__public_key_point_to_eth_address{
range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*
}(x: U256, y: U256) -> felt {
let eth_address = Internals.public_key_point_to_eth_address(x=[x.value], y=[y.value]);

return eth_address;
}

func test__verify_eth_signature_uint256{
range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*
}(msg_hash: U256, r: U256, s: U256, y_parity: felt, eth_address: felt) {
Signature.verify_eth_signature_uint256(
[msg_hash.value], [r.value], [s.value], y_parity, eth_address
);
return ();
}

func test__try_recover_eth_address{
range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*
}(msg_hash: U256, r: U256, s: U256, y_parity: felt) -> (success: felt, address: felt) {
let (msg_hash_bigint) = uint256_to_bigint([msg_hash.value]);
let (r_bigint) = uint256_to_bigint([r.value]);
let (s_bigint) = uint256_to_bigint([s.value]);

let (success, address) = Signature.try_recover_eth_address(
msg_hash=msg_hash_bigint, r=r_bigint, s=s_bigint, y_parity=y_parity
);

return (success=success, address=address);
}
Loading
Loading