Skip to content

Commit

Permalink
Mempool multifetch (#17139)
Browse files Browse the repository at this point in the history
* update type annotation for CoinStore.get_coin_records to support both List and Set

* update the mempool to fetch multiple coin records per query

* optimize the slow-path of updating the mempool by fetching all coin records up-front, in a single sql query
  • Loading branch information
arvidn authored Dec 22, 2023
1 parent a4cab82 commit 507899f
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 65 deletions.
11 changes: 8 additions & 3 deletions benchmarks/mempool-long-lived.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from dataclasses import dataclass
from time import monotonic
from typing import Dict, Optional
from typing import Collection, Dict, List, Optional

from chia_rs import G2Element
from clvm.casts import int_to_bytes
Expand Down Expand Up @@ -81,8 +81,13 @@ def fake_block_record(block_height: uint32, timestamp: uint64) -> BenchBlockReco
async def run_mempool_benchmark() -> None:
coin_records: Dict[bytes32, CoinRecord] = {}

async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:
return coin_records.get(coin_id)
async def get_coin_record(coin_ids: Collection[bytes32]) -> List[CoinRecord]:
ret: List[CoinRecord] = []
for name in coin_ids:
r = coin_records.get(name)
if r is not None:
ret.append(r)
return ret

timestamp = uint64(1631794488)

Expand Down
15 changes: 10 additions & 5 deletions benchmarks/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from subprocess import check_call
from time import monotonic
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Collection, Dict, Iterator, List, Optional, Tuple

from chia.consensus.coinbase import create_farmer_coin, create_pool_coin
from chia.consensus.default_constants import DEFAULT_CONSTANTS
Expand Down Expand Up @@ -78,8 +78,13 @@ def fake_block_record(block_height: uint32, timestamp: uint64) -> BenchBlockReco
async def run_mempool_benchmark() -> None:
all_coins: Dict[bytes32, CoinRecord] = {}

async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:
return all_coins.get(coin_id)
async def get_coin_records(coin_ids: Collection[bytes32]) -> List[CoinRecord]:
ret: List[CoinRecord] = []
for name in coin_ids:
r = all_coins.get(name)
if r is not None:
ret.append(r)
return ret

wt = WalletTool(DEFAULT_CONSTANTS)

Expand Down Expand Up @@ -156,7 +161,7 @@ async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:
else:
print("\n== Multi-threaded")

mempool = MempoolManager(get_coin_record, DEFAULT_CONSTANTS, single_threaded=single_threaded)
mempool = MempoolManager(get_coin_records, DEFAULT_CONSTANTS, single_threaded=single_threaded)

height = start_height
rec = fake_block_record(height, timestamp)
Expand Down Expand Up @@ -186,7 +191,7 @@ async def add_spend_bundles(spend_bundles: List[SpendBundle]) -> None:
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / total_bundles * 1000:0.2f}ms")

mempool = MempoolManager(get_coin_record, DEFAULT_CONSTANTS, single_threaded=single_threaded)
mempool = MempoolManager(get_coin_records, DEFAULT_CONSTANTS, single_threaded=single_threaded)

height = start_height
rec = fake_block_record(height, timestamp)
Expand Down
2 changes: 1 addition & 1 deletion chia/clvm/spend_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def create(
self.db_wrapper = await DBWrapper2.create(database=uri, uri=True, reader_count=1, db_version=2)

self.coin_store = await CoinStore.create(self.db_wrapper)
self.mempool_manager = MempoolManager(self.coin_store.get_coin_record, defaults)
self.mempool_manager = MempoolManager(self.coin_store.get_coin_records, defaults)
self.defaults = defaults

# Load the next data if there is any
Expand Down
2 changes: 1 addition & 1 deletion chia/full_node/coin_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def get_coin_record(self, coin_name: bytes32) -> Optional[CoinRecord]:
return CoinRecord(coin, row[0], row[1], row[2], row[6])
return None

async def get_coin_records(self, names: List[bytes32]) -> List[CoinRecord]:
async def get_coin_records(self, names: Collection[bytes32]) -> List[CoinRecord]:
if len(names) == 0:
return []

Expand Down
2 changes: 1 addition & 1 deletion chia/full_node/full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ async def manage(self) -> AsyncIterator[None]:
)

self._mempool_manager = MempoolManager(
get_coin_record=self.coin_store.get_coin_record,
get_coin_records=self.coin_store.get_coin_records,
consensus_constants=self.constants,
multiprocessing_context=self.multiprocessing_context,
single_threaded=single_threaded,
Expand Down
73 changes: 59 additions & 14 deletions chia/full_node/mempool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass
from multiprocessing.context import BaseContext
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from typing import Awaitable, Callable, Collection, Dict, List, Optional, Set, Tuple, TypeVar

from chia_rs import ELIGIBLE_FOR_DEDUP, GTElement
from chiabip158 import PyBIP158
Expand Down Expand Up @@ -146,7 +146,7 @@ class MempoolManager:
pool: Executor
constants: ConsensusConstants
seen_bundle_hashes: Dict[bytes32, bytes32]
get_coin_record: Callable[[bytes32], Awaitable[Optional[CoinRecord]]]
get_coin_records: Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]]
nonzero_fee_minimum_fpc: int
mempool_max_total_cost: int
# a cache of MempoolItems that conflict with existing items in the pool
Expand All @@ -159,7 +159,7 @@ class MempoolManager:

def __init__(
self,
get_coin_record: Callable[[bytes32], Awaitable[Optional[CoinRecord]]],
get_coin_records: Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]],
consensus_constants: ConsensusConstants,
multiprocessing_context: Optional[BaseContext] = None,
*,
Expand All @@ -170,7 +170,7 @@ def __init__(
# Keep track of seen spend_bundles
self.seen_bundle_hashes: Dict[bytes32, bytes32] = {}

self.get_coin_record = get_coin_record
self.get_coin_records = get_coin_records

# The fee per cost must be above this amount to consider the fee "nonzero", and thus able to kick out other
# transactions. This prevents spam. This is equivalent to 0.055 XCH per block, or about 0.00005 XCH for two
Expand Down Expand Up @@ -303,7 +303,12 @@ async def pre_validate_spendbundle(
return ret

async def add_spend_bundle(
self, new_spend: SpendBundle, npc_result: NPCResult, spend_name: bytes32, first_added_height: uint32
self,
new_spend: SpendBundle,
npc_result: NPCResult,
spend_name: bytes32,
first_added_height: uint32,
get_coin_records: Optional[Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]]] = None,
) -> Tuple[Optional[uint64], MempoolInclusionStatus, Optional[Err]]:
"""
Validates and adds to mempool a new_spend with the given NPCResult, and spend_name, and the current mempool.
Expand All @@ -327,8 +332,14 @@ async def add_spend_bundle(
if existing_item is not None:
return existing_item.cost, MempoolInclusionStatus.SUCCESS, None

if get_coin_records is None:
get_coin_records = self.get_coin_records
err, item, remove_items = await self.validate_spend_bundle(
new_spend, npc_result, spend_name, first_added_height
new_spend,
npc_result,
spend_name,
first_added_height,
get_coin_records,
)
if err is None:
# No error, immediately add to mempool, after removing conflicting TXs.
Expand Down Expand Up @@ -358,6 +369,7 @@ async def validate_spend_bundle(
npc_result: NPCResult,
spend_name: bytes32,
first_added_height: uint32,
get_coin_records: Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]],
) -> Tuple[Optional[Err], Optional[MempoolItem], List[bytes32]]:
"""
Validates new_spend with the given NPCResult, and spend_name, and the current mempool. The mempool should
Expand Down Expand Up @@ -420,11 +432,14 @@ async def validate_spend_bundle(

removal_record_dict: Dict[bytes32, CoinRecord] = {}
removal_amount: int = 0
removal_records = await get_coin_records(removal_names)
for record in removal_records:
removal_record_dict[record.coin.name()] = record

for name in removal_names:
removal_record = await self.get_coin_record(name)
if removal_record is None and name not in additions_dict:
if name not in removal_record_dict and name not in additions_dict:
return Err.UNKNOWN_UNSPENT, None, []
elif name in additions_dict:
if name in additions_dict:
removal_coin = additions_dict[name]
# The timestamp and block-height of this coin being spent needs
# to be consistent with what we use to check time-lock
Expand All @@ -440,10 +455,10 @@ async def validate_spend_bundle(
False,
self.peak.timestamp,
)

assert removal_record is not None
removal_record_dict[name] = removal_record
else:
removal_record = removal_record_dict[name]
removal_amount = removal_amount + removal_record.coin.amount
removal_record_dict[name] = removal_record

fees = uint64(removal_amount - addition_amount)

Expand Down Expand Up @@ -623,9 +638,35 @@ async def new_peak(
old_pool = self.mempool
self.mempool = Mempool(old_pool.mempool_info, old_pool.fee_estimator)
self.seen_bundle_hashes = {}

# in order to make this a bit quicker, we look-up all the spends in
# a single query, rather than one at a time.
coin_records: Dict[bytes32, CoinRecord] = {}

removals: Set[bytes32] = set()
for item in old_pool.all_items():
for s in item.spend_bundle.coin_spends:
removals.add(s.coin.name())

for record in await self.get_coin_records(removals):
name = record.coin.name()
coin_records[name] = record

async def local_get_coin_records(names: Collection[bytes32]) -> List[CoinRecord]:
ret: List[CoinRecord] = []
for name in names:
r = coin_records.get(name)
if r is not None:
ret.append(r)
return ret

for item in old_pool.all_items():
_, result, err = await self.add_spend_bundle(
item.spend_bundle, item.npc_result, item.spend_bundle_name, item.height_added_to_mempool
item.spend_bundle,
item.npc_result,
item.spend_bundle_name,
item.height_added_to_mempool,
local_get_coin_records,
)
# Only add to `seen` if inclusion worked, so it can be resubmitted in case of a reorg
if result == MempoolInclusionStatus.SUCCESS:
Expand All @@ -642,7 +683,11 @@ async def new_peak(
txs_added = []
for item in potential_txs.values():
cost, status, error = await self.add_spend_bundle(
item.spend_bundle, item.npc_result, item.spend_bundle_name, item.height_added_to_mempool
item.spend_bundle,
item.npc_result,
item.spend_bundle_name,
item.height_added_to_mempool,
self.get_coin_records,
)
if status == MempoolInclusionStatus.SUCCESS:
txs_added.append((item.spend_bundle, item.npc_result, item.spend_bundle_name))
Expand Down
2 changes: 1 addition & 1 deletion tests/core/mempool/test_mempool_fee_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def test_basics() -> None:
async def test_fee_increase() -> None:
async with DBConnection(db_version=2) as db_wrapper:
coin_store = await CoinStore.create(db_wrapper)
mempool_manager = MempoolManager(coin_store.get_coin_record, test_constants)
mempool_manager = MempoolManager(coin_store.get_coin_records, test_constants)
assert test_constants.MAX_BLOCK_COST_CLVM == mempool_manager.constants.MAX_BLOCK_COST_CLVM
btc_fee_estimator: BitcoinFeeEstimator = mempool_manager.mempool.fee_estimator # type: ignore
fee_tracker = btc_fee_estimator.get_tracker()
Expand Down
Loading

0 comments on commit 507899f

Please sign in to comment.