Skip to content

Commit

Permalink
Merge master into chiado-network
Browse files Browse the repository at this point in the history
Signed-off-by: cyc60 <avsysoev60@gmail.com>
  • Loading branch information
cyc60 committed Mar 27, 2024
2 parents ca18922 + 0139700 commit 8a40922
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 166 deletions.
54 changes: 9 additions & 45 deletions src/common/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import click
from eth_typing import BlockNumber
from sw_utils import InterruptHandler
from sw_utils import InterruptHandler, ProtocolConfig, build_protocol_config
from web3 import Web3
from web3._utils.async_transactions import _max_fee_per_gas
from web3.exceptions import BadFunctionCallOutput
Expand All @@ -13,7 +13,7 @@
from src.common.contracts import keeper_contract, multicall_contract, vault_contract
from src.common.metrics import metrics
from src.common.tasks import BaseTask
from src.common.typings import Oracles, OraclesCache
from src.common.typings import OraclesCache
from src.common.wallet import hot_wallet
from src.config.settings import settings

Expand Down Expand Up @@ -101,53 +101,17 @@ async def update_oracles_cache() -> None:
)


async def get_oracles() -> Oracles:
async def get_protocol_config() -> ProtocolConfig:
await update_oracles_cache()

oracles_cache = cast(OraclesCache, _oracles_cache)

config = oracles_cache.config
rewards_threshold = oracles_cache.rewards_threshold
validators_threshold = oracles_cache.validators_threshold

endpoints = []
public_keys = []
for oracle in config['oracles']:
endpoints.append(oracle['endpoints'])
public_keys.append(oracle['public_key'])

if not 1 <= rewards_threshold <= len(config['oracles']):
raise ValueError('Invalid rewards threshold')

if not 1 <= validators_threshold <= len(config['oracles']):
raise ValueError('Invalid validators threshold')

exit_signature_recover_threshold = config['exit_signature_recover_threshold']

if exit_signature_recover_threshold > validators_threshold:
raise ValueError('Invalid exit signature threshold')

signature_validity_period = config['signature_validity_period']

if signature_validity_period < 0:
raise ValueError('Invalid signature validity period')

if len(public_keys) != len(set(public_keys)):
raise ValueError('Duplicate public keys in oracles config')

validators_approval_batch_limit = config['validators_approval_batch_limit']
validators_exit_rotation_batch_limit = config['validators_exit_rotation_batch_limit']

return Oracles(
rewards_threshold=rewards_threshold,
validators_threshold=validators_threshold,
exit_signature_recover_threshold=exit_signature_recover_threshold,
signature_validity_period=signature_validity_period,
public_keys=public_keys,
endpoints=endpoints,
validators_approval_batch_limit=validators_approval_batch_limit,
validators_exit_rotation_batch_limit=validators_exit_rotation_batch_limit,
pc = build_protocol_config(
config_data=oracles_cache.config,
rewards_threshold=oracles_cache.rewards_threshold,
validators_threshold=oracles_cache.validators_threshold,
)
pc.rewards_threshold = cast(int, pc.rewards_threshold)
return pc


async def get_high_priority_tx_params() -> TxParams:
Expand Down
5 changes: 3 additions & 2 deletions src/common/startup_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from src.common.execution import (
check_hot_wallet_balance,
check_vault_address,
get_oracles,
get_protocol_config,
)
from src.common.utils import format_error, warning_verbose
from src.common.wallet import hot_wallet
Expand Down Expand Up @@ -110,7 +110,8 @@ async def wait_for_execution_node() -> None:


async def collect_healthy_oracles() -> list:
endpoints = (await get_oracles()).endpoints
oracles = (await get_protocol_config()).oracles
endpoints = [oracle.endpoints for oracle in oracles]

async with ClientSession(timeout=ClientTimeout(60)) as session:
results = await asyncio.gather(
Expand Down
27 changes: 1 addition & 26 deletions src/common/typings.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,9 @@
from dataclasses import dataclass
from functools import cached_property

from eth_keys.datatypes import PublicKey
from eth_typing import BlockNumber, ChecksumAddress, HexStr
from web3 import Web3
from eth_typing import BlockNumber
from web3.types import Wei


@dataclass
# pylint: disable-next=too-many-instance-attributes
class Oracles:
rewards_threshold: int
validators_threshold: int
exit_signature_recover_threshold: int
signature_validity_period: int
public_keys: list[HexStr]
endpoints: list[list[str]]

validators_approval_batch_limit: int
validators_exit_rotation_batch_limit: int

@cached_property
def addresses(self) -> list[ChecksumAddress]:
res = []
for public_key in self.public_keys:
public_key_obj = PublicKey(Web3.to_bytes(hexstr=public_key))
res.append(public_key_obj.to_checksum_address())
return res


@dataclass
class OraclesCache:
checkpoint_block: BlockNumber
Expand Down
33 changes: 15 additions & 18 deletions src/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@
from Cryptodome.Protocol.KDF import scrypt as raw_scrypt
from eth_typing import HexAddress, HexStr
from sw_utils.tests import faker
from sw_utils.tests.factories import get_mocked_protocol_config
from sw_utils.typings import Oracle, ProtocolConfig

from src.commands.create_keys import create_keys
from src.commands.create_wallet import create_wallet
from src.commands.remote_signer_setup import remote_signer_setup
from src.common.credentials import CredentialManager, ScryptKeystore
from src.common.typings import Oracles
from src.common.vault_config import VaultConfig
from src.config.networks import HOLESKY
from src.config.settings import settings
from src.test_fixtures.hashi_vault import hashi_vault_url, mocked_hashi_vault
from src.test_fixtures.hashi_vault import hashi_vault_url, mocked_hashi_vault # noqa
from src.test_fixtures.remote_signer import mocked_remote_signer, remote_signer_url
from src.validators.keystores.remote import RemoteSignerKeystore
from src.validators.signing.tests.oracle_functions import OracleCommittee
Expand Down Expand Up @@ -156,7 +157,7 @@ def _remote_signer_setup(
remote_signer_url: str,
execution_endpoints: str,
runner: CliRunner,
mocked_oracles: Oracles,
mocked_protocol_config: ProtocolConfig,
mocked_remote_signer,
_create_keys,
) -> None:
Expand Down Expand Up @@ -247,24 +248,20 @@ def _mocked_oracle_committee(request: SubRequest) -> OracleCommittee:


@pytest.fixture
def mocked_oracles(
def mocked_protocol_config(
_mocked_oracle_committee: OracleCommittee,
) -> Oracles:
) -> ProtocolConfig:
exit_signature_recover_threshold = _mocked_oracle_committee.exit_signature_recover_threshold

return Oracles(
rewards_threshold=1,
validators_threshold=1,
oracles = []
for index, pub_key in enumerate(_mocked_oracle_committee.oracle_pubkeys):
oracle = Oracle(
public_key=HexStr(pub_key.format(compressed=False)[1:].hex()),
endpoints=[f'http://oracle-endpoint-{index}'],
)
oracles.append(oracle)
return get_mocked_protocol_config(
oracles=oracles,
exit_signature_recover_threshold=exit_signature_recover_threshold,
# Strip first byte (04 prefix) from pubkey
public_keys=[
HexStr(pubkey.format(compressed=False)[1:].hex())
for pubkey in _mocked_oracle_committee.oracle_pubkeys
],
endpoints=[
[f'http://oracle-endpoint-{i}']
for i in range(len(_mocked_oracle_committee.oracle_pubkeys))
],
validators_approval_batch_limit=1,
validators_exit_rotation_batch_limit=2,
signature_validity_period=60,
Expand Down
48 changes: 28 additions & 20 deletions src/exits/tasks.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import asyncio
import logging
import time
from itertools import chain
from random import shuffle

from aiohttp import ClientError
from eth_typing import BlockNumber
from sw_utils import InterruptHandler
from sw_utils.typings import Oracle, ProtocolConfig
from tenacity import RetryError
from web3.types import HexStr

from src.common.contracts import keeper_contract
from src.common.exceptions import NotEnoughOracleApprovalsError
from src.common.execution import get_oracles
from src.common.execution import get_protocol_config
from src.common.metrics import metrics
from src.common.tasks import BaseTask
from src.common.typings import Oracles
from src.common.utils import get_current_timestamp, is_block_finalized, warning_verbose
from src.config.settings import settings
from src.exits.consensus import get_validator_public_keys
Expand All @@ -38,30 +39,33 @@ async def process_block(self, interrupt_handler: InterruptHandler) -> None:
if self.keystore is None:
return

oracles = await get_oracles()
protocol_config = await get_protocol_config()
update_block = await _fetch_last_update_block()
if update_block and not await is_block_finalized(update_block):
logger.info('Waiting for signatures update block %d to finalize...', update_block)
return

if update_block and not await _check_majority_oracles_synced(oracles, update_block):
if update_block and not await _check_majority_oracles_synced(protocol_config, update_block):
logger.info('Waiting for the majority of oracles to sync exit signatures')
return

outdated_indexes = await _fetch_outdated_indexes(oracles, update_block)
outdated_indexes = await _fetch_outdated_indexes(protocol_config.oracles, update_block)
if outdated_indexes:
await _update_exit_signatures(
keystore=self.keystore,
oracles=oracles,
protocol_config=protocol_config,
outdated_indexes=outdated_indexes,
)


async def _check_majority_oracles_synced(oracles: Oracles, update_block: BlockNumber) -> bool:
threshold = oracles.validators_threshold
async def _check_majority_oracles_synced(
protocol_config: ProtocolConfig, update_block: BlockNumber
) -> bool:
threshold = protocol_config.validators_threshold
endpoints = [oracle.endpoints for oracle in protocol_config.oracles]

pending = {
asyncio.create_task(_fetch_last_update_block_replicas(replicas))
for replicas in oracles.endpoints
asyncio.create_task(_fetch_last_update_block_replicas(replicas)) for replicas in endpoints
}
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
Expand Down Expand Up @@ -97,8 +101,10 @@ async def _fetch_last_update_block() -> BlockNumber | None:
return None


async def _fetch_outdated_indexes(oracles: Oracles, update_block: BlockNumber | None) -> list[int]:
endpoints = [endpoint for replicas in oracles.endpoints for endpoint in replicas]
async def _fetch_outdated_indexes(
oracles: list[Oracle], update_block: BlockNumber | None
) -> list[int]:
endpoints = list(chain.from_iterable([oracle.endpoints for oracle in oracles]))
shuffle(endpoints)

for oracle_endpoint in endpoints:
Expand All @@ -118,7 +124,7 @@ async def _fetch_outdated_indexes(oracles: Oracles, update_block: BlockNumber |

async def _update_exit_signatures(
keystore: BaseKeystore,
oracles: Oracles,
protocol_config: ProtocolConfig,
outdated_indexes: list[int],
) -> None:
"""Fetches update signature requests from oracles."""
Expand All @@ -133,9 +139,9 @@ async def _update_exit_signatures(

current_timestamp = get_current_timestamp()
if not deadline or deadline <= current_timestamp:
deadline = current_timestamp + oracles.signature_validity_period
deadline = current_timestamp + protocol_config.signature_validity_period
oracles_request = await _get_oracles_request(
oracles=oracles,
protocol_config=protocol_config,
keystore=keystore,
validators=validators,
)
Expand All @@ -145,7 +151,9 @@ async def _update_exit_signatures(
return
try:
# send approval request to oracles
oracles_approval = await send_signature_rotation_requests(oracles, oracles_request)
oracles_approval = await send_signature_rotation_requests(
protocol_config, oracles_request
)
break
except NotEnoughOracleApprovalsError as e:
logger.error(
Expand Down Expand Up @@ -176,7 +184,7 @@ async def _fetch_exit_signature_block(oracle_endpoint: str) -> BlockNumber | Non


async def _get_oracles_request(
oracles: Oracles,
protocol_config: ProtocolConfig,
keystore: BaseKeystore,
validators: dict[int, HexStr],
) -> SignatureRotationRequest:
Expand All @@ -187,10 +195,10 @@ async def _get_oracles_request(
public_keys=[],
public_key_shards=[],
exit_signature_shards=[],
deadline=get_current_timestamp() + oracles.signature_validity_period,
deadline=get_current_timestamp() + protocol_config.signature_validity_period,
)
failed_indexes = []
exit_rotation_batch_limit = oracles.validators_exit_rotation_batch_limit
exit_rotation_batch_limit = protocol_config.validators_exit_rotation_batch_limit

for validator_index, public_key in validators.items():
if len(request.public_keys) >= exit_rotation_batch_limit:
Expand All @@ -201,7 +209,7 @@ async def _get_oracles_request(
keystore=keystore,
public_key=public_key,
validator_index=validator_index,
oracles=oracles,
protocol_config=protocol_config,
)
else:
failed_indexes.append(validator_index)
Expand Down
Loading

0 comments on commit 8a40922

Please sign in to comment.