task referencer #18914

Merged (17 commits) on Jan 2, 2025
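
This PR replaces bare asyncio.create_task calls with a new create_referenced_task helper from chia.util.task_referencer, so background tasks are no longer dropped on the floor (asyncio holds only a weak reference to a running task, and a task nobody references can be garbage collected before it finishes). The helper's implementation is not part of the hunks shown below, so the following is only a minimal sketch of the pattern, with the signature, including the known_unreferenced keyword seen in the daemon client hunk, assumed from the call sites:

from __future__ import annotations

import asyncio
from collections.abc import Coroutine
from typing import Any, TypeVar

T = TypeVar("T")

# asyncio only keeps a weak reference to running tasks, so a task that the
# caller does not store can be garbage collected mid-flight.  Holding tasks
# in a module-level set keeps them alive until they complete.
_referenced_tasks: set[asyncio.Task[Any]] = set()


def create_referenced_task(
    coroutine: Coroutine[Any, Any, T],
    *,
    known_unreferenced: bool = False,
) -> asyncio.Task[T]:
    # known_unreferenced marks call sites that deliberately discard the
    # returned task (see the daemon client hunk below); this sketch only
    # documents it, the real helper may use it for diagnostics.
    task = asyncio.create_task(coroutine)
    _referenced_tasks.add(task)
    # Drop our reference once the task finishes so the set does not grow
    # without bound.
    task.add_done_callback(_referenced_tasks.discard)
    return task

At the call sites the change is mechanical: tasks.append(asyncio.create_task(coro)) becomes tasks.append(create_referenced_task(coro)).
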
7 changes: 4 additions & 3 deletions benchmarks/mempool.py
@@ -21,6 +21,7 @@
from chia.types.spend_bundle import SpendBundle
from chia.util.batches import to_batches
from chia.util.ints import uint32, uint64
from chia.util.task_referencer import create_referenced_task

NUM_ITERS = 200
NUM_PEERS = 5
@@ -189,7 +190,7 @@ async def add_spend_bundles(spend_bundles: list[SpendBundle]) -> None:
start = monotonic()
for peer in range(NUM_PEERS):
total_bundles += len(large_spend_bundles[peer])
tasks.append(asyncio.create_task(add_spend_bundles(large_spend_bundles[peer])))
tasks.append(create_referenced_task(add_spend_bundles(large_spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
@@ -208,7 +209,7 @@ async def add_spend_bundles(spend_bundles: list[SpendBundle]) -> None:
start = monotonic()
for peer in range(NUM_PEERS):
total_bundles += len(spend_bundles[peer])
tasks.append(asyncio.create_task(add_spend_bundles(spend_bundles[peer])))
tasks.append(create_referenced_task(add_spend_bundles(spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
@@ -221,7 +222,7 @@ async def add_spend_bundles(spend_bundles: list[SpendBundle]) -> None:
start = monotonic()
for peer in range(NUM_PEERS):
total_bundles += len(replacement_spend_bundles[peer])
tasks.append(asyncio.create_task(add_spend_bundles(replacement_spend_bundles[peer])))
tasks.append(create_referenced_task(add_spend_bundles(replacement_spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
3 changes: 2 additions & 1 deletion chia/_tests/core/data_layer/test_data_rpc.py
@@ -67,6 +67,7 @@
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64
from chia.util.keychain import bytes_to_mnemonic
from chia.util.task_referencer import create_referenced_task
from chia.util.timing import adjusted_timeout, backoff_times
from chia.wallet.trading.offer import Offer as TradingOffer
from chia.wallet.transaction_record import TransactionRecord
@@ -2191,7 +2192,7 @@ async def test_issue_15955_deadlock(
while time.monotonic() < end:
with anyio.fail_after(adjusted_timeout(timeout)):
await asyncio.gather(
*(asyncio.create_task(data_layer.get_value(store_id=store_id, key=key)) for _ in range(10))
*(create_referenced_task(data_layer.get_value(store_id=store_id, key=key)) for _ in range(10))
)


5 changes: 3 additions & 2 deletions chia/_tests/core/farmer/test_farmer_api.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from asyncio import Task, create_task, gather, sleep
from asyncio import Task, gather, sleep
from collections.abc import Coroutine
from typing import Any, Optional, TypeVar

@@ -20,13 +20,14 @@
from chia.server.outbound_message import Message, NodeType
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint32, uint64
from chia.util.task_referencer import create_referenced_task

T = TypeVar("T")


async def begin_task(coro: Coroutine[Any, Any, T]) -> Task[T]:
"""Awaitable function that adds a coroutine to the event loop and sets it running."""
task = create_task(coro)
task = create_referenced_task(coro)
await sleep(0)

return task
5 changes: 3 additions & 2 deletions chia/_tests/core/full_node/stores/test_block_store.py
@@ -30,6 +30,7 @@
from chia.util.db_wrapper import get_host_parameter_limit
from chia.util.full_block_utils import GeneratorBlockInfo
from chia.util.ints import uint8, uint32, uint64
from chia.util.task_referencer import create_referenced_task

log = logging.getLogger(__name__)

@@ -242,12 +243,12 @@ async def test_deadlock(tmp_dir: Path, db_version: int, bt: BlockTools, use_cach
rand_i = random.randint(0, 9)
if random.random() < 0.5:
tasks.append(
asyncio.create_task(
create_referenced_task(
store.add_full_block(blocks[rand_i].header_hash, blocks[rand_i], block_records[rand_i])
)
)
if random.random() < 0.5:
tasks.append(asyncio.create_task(store.get_full_block(blocks[rand_i].header_hash)))
tasks.append(create_referenced_task(store.get_full_block(blocks[rand_i].header_hash)))
await asyncio.gather(*tasks)


13 changes: 6 additions & 7 deletions chia/_tests/core/full_node/test_full_node.py
@@ -84,6 +84,7 @@
from chia.util.ints import uint8, uint16, uint32, uint64, uint128
from chia.util.limited_semaphore import LimitedSemaphore
from chia.util.recursive_replace import recursive_replace
from chia.util.task_referencer import create_referenced_task
from chia.util.vdf_prover import get_vdf_info_and_proof
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG
from chia.wallet.wallet_spend_bundle import WalletSpendBundle
@@ -807,13 +808,13 @@ async def test_new_peak(self, wallet_nodes, self_hostname):
uint32(0),
block.reward_chain_block.get_unfinished().get_hash(),
)
task_1 = asyncio.create_task(full_node_1.new_peak(new_peak, dummy_peer))
task_1 = create_referenced_task(full_node_1.new_peak(new_peak, dummy_peer))
await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 1))
task_1.cancel()

await full_node_1.full_node.add_block(block, peer)
# Ignores, already have
task_2 = asyncio.create_task(full_node_1.new_peak(new_peak, dummy_peer))
task_2 = create_referenced_task(full_node_1.new_peak(new_peak, dummy_peer))
await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 0))
task_2.cancel()

@@ -829,8 +830,7 @@ async def suppress_value_error(coro: Coroutine) -> None:
uint32(0),
blocks_reorg[-2].reward_chain_block.get_unfinished().get_hash(),
)
# TODO: stop dropping tasks on the floor
asyncio.create_task(suppress_value_error(full_node_1.new_peak(new_peak, dummy_peer))) # noqa: RUF006
create_referenced_task(suppress_value_error(full_node_1.new_peak(new_peak, dummy_peer)))
await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 0))

# Does not ignore equal weight
@@ -841,8 +841,7 @@ async def suppress_value_error(coro: Coroutine) -> None:
uint32(0),
blocks_reorg[-1].reward_chain_block.get_unfinished().get_hash(),
)
# TODO: stop dropping tasks on the floor
asyncio.create_task(suppress_value_error(full_node_1.new_peak(new_peak, dummy_peer))) # noqa: RUF006
create_referenced_task(suppress_value_error(full_node_1.new_peak(new_peak, dummy_peer)))
await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 1))

@pytest.mark.anyio
@@ -1568,7 +1567,7 @@ async def test_double_blocks_same_pospace(self, wallet_nodes, self_hostname):
block_2 = recursive_replace(block_2, "foliage.foliage_transaction_block_signature", new_fbh_sig)
block_2 = recursive_replace(block_2, "transactions_generator", None)

rb_task = asyncio.create_task(full_node_2.full_node.add_block(block_2, dummy_peer))
rb_task = create_referenced_task(full_node_2.full_node.add_block(block_2, dummy_peer))

await time_out_assert(10, time_out_messages(incoming_queue, "request_block", 1))
rb_task.cancel()
3 changes: 2 additions & 1 deletion chia/_tests/core/full_node/test_tx_processing_queue.py
@@ -11,6 +11,7 @@
from chia.full_node.tx_processing_queue import TransactionQueue, TransactionQueueFull
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.transaction_queue_entry import TransactionQueueEntry
from chia.util.task_referencer import create_referenced_task

log = logging.getLogger(__name__)

@@ -76,7 +77,7 @@ async def test_one_peer_and_await(seeded_random: random.Random) -> None:
assert list_txs[i - 20] == resulting_txs[i]

# now we validate that the pop command is blocking
task = asyncio.create_task(transaction_queue.pop())
task = create_referenced_task(transaction_queue.pop())
with pytest.raises(asyncio.InvalidStateError): # task is not done, so we expect an error when getting result
task.result()
# add a tx to test task completion
5 changes: 3 additions & 2 deletions chia/_tests/core/server/flood.py
@@ -9,6 +9,7 @@
import time

from chia._tests.util.misc import create_logger
from chia.util.task_referencer import create_referenced_task

# TODO: CAMPid 0945094189459712842390t591
IP = "127.0.0.1"
@@ -62,15 +63,15 @@ async def dun() -> None:

task.cancel()

file_task = asyncio.create_task(dun())
file_task = create_referenced_task(dun())

with out_path.open(mode="w") as file:
logger = create_logger(file=file)

async def f() -> None:
await asyncio.gather(*[tcp_echo_client(task_counter=f"{i}", logger=logger) for i in range(0, NUM_CLIENTS)])

task = asyncio.create_task(f())
task = create_referenced_task(f())
try:
await task
except asyncio.CancelledError:
3 changes: 2 additions & 1 deletion chia/_tests/core/server/serve.py
@@ -14,6 +14,7 @@
from chia._tests.util.misc import create_logger
from chia.server.chia_policy import ChiaPolicy
from chia.server.start_service import async_run
from chia.util.task_referencer import create_referenced_task

if sys.platform == "win32":
import _winapi
@@ -86,7 +87,7 @@ async def dun() -> None:

thread_end_event.set()

file_task = asyncio.create_task(dun())
file_task = create_referenced_task(dun())

loop = asyncio.get_event_loop()
server = await loop.create_server(functools.partial(EchoServer, logger=logger), ip, port)
3 changes: 2 additions & 1 deletion chia/_tests/core/server/test_loop.py
@@ -18,6 +18,7 @@
from chia._tests.core.server import serve
from chia._tests.util.misc import create_logger
from chia.server import chia_policy
from chia.util.task_referencer import create_referenced_task
from chia.util.timing import adjusted_timeout

here = pathlib.Path(__file__).parent
@@ -123,7 +124,7 @@ def _run(self) -> None:
asyncio.set_event_loop_policy(original_event_loop_policy)

async def main(self) -> None:
self.server_task = asyncio.create_task(
self.server_task = create_referenced_task(
serve.async_main(
out_path=self.out_path,
ip=self.ip,
13 changes: 7 additions & 6 deletions chia/_tests/db/test_db_wrapper.py
@@ -14,6 +14,7 @@
from chia._tests.util.db_connection import DBConnection, PathDBConnection
from chia._tests.util.misc import Marks, boolean_datacases, datacases
from chia.util.db_wrapper import DBWrapper2, ForeignKeyError, InternalError, NestedForeignKeyDelayedRequestError
from chia.util.task_referencer import create_referenced_task

if TYPE_CHECKING:
ConnectionContextManager = contextlib.AbstractAsyncContextManager[aiosqlite.core.Connection]
@@ -119,7 +120,7 @@ async def test_concurrent_writers(acquire_outside: bool, get_reader_method: GetR

tasks = []
for index in range(concurrent_task_count):
task = asyncio.create_task(increment_counter(db_wrapper))
task = create_referenced_task(increment_counter(db_wrapper))
tasks.append(task)

await asyncio.wait_for(asyncio.gather(*tasks), timeout=None)
@@ -263,7 +264,7 @@ async def write() -> None:
async with get_reader() as reader:
assert await query_value(connection=reader) == 0

task = asyncio.create_task(write())
task = create_referenced_task(write())
await writer_committed.wait()

assert await query_value(connection=reader) == 0 if transactioned else 1
@@ -342,7 +343,7 @@ async def test_concurrent_readers(acquire_outside: bool, get_reader_method: GetR
tasks = []
values: list[int] = []
for index in range(concurrent_task_count):
task = asyncio.create_task(sum_counter(db_wrapper, values))
task = create_referenced_task(sum_counter(db_wrapper, values))
tasks.append(task)

await asyncio.wait_for(asyncio.gather(*tasks), timeout=None)
@@ -371,11 +372,11 @@ async def test_mixed_readers_writers(acquire_outside: bool, get_reader_method: G
tasks = []
values: list[int] = []
for index in range(concurrent_task_count):
task = asyncio.create_task(increment_counter(db_wrapper))
task = create_referenced_task(increment_counter(db_wrapper))
tasks.append(task)
task = asyncio.create_task(decrement_counter(db_wrapper))
task = create_referenced_task(decrement_counter(db_wrapper))
tasks.append(task)
task = asyncio.create_task(sum_counter(db_wrapper, values))
task = create_referenced_task(sum_counter(db_wrapper, values))
tasks.append(task)

await asyncio.wait_for(asyncio.gather(*tasks), timeout=None)
7 changes: 4 additions & 3 deletions chia/_tests/util/test_limited_semaphore.py
@@ -6,6 +6,7 @@
import pytest

from chia.util.limited_semaphore import LimitedSemaphore, LimitedSemaphoreFullError
from chia.util.task_referencer import create_referenced_task


@pytest.mark.anyio
@@ -27,16 +28,16 @@ async def acquire(entered_event: Optional[asyncio.Event] = None) -> None:
waiting_events = [asyncio.Event() for _ in range(waiting_limit)]
failed_events = [asyncio.Event() for _ in range(beyond_limit)]

entered_tasks = [asyncio.create_task(acquire(entered_event=event)) for event in entered_events]
waiting_tasks = [asyncio.create_task(acquire(entered_event=event)) for event in waiting_events]
entered_tasks = [create_referenced_task(acquire(entered_event=event)) for event in entered_events]
waiting_tasks = [create_referenced_task(acquire(entered_event=event)) for event in waiting_events]

await asyncio.gather(*(event.wait() for event in entered_events))
assert all(event.is_set() for event in entered_events)
assert all(not event.is_set() for event in waiting_events)

assert semaphore._available_count == 0

failure_tasks = [asyncio.create_task(acquire()) for _ in range(beyond_limit)]
failure_tasks = [create_referenced_task(acquire()) for _ in range(beyond_limit)]

failure_results = await asyncio.gather(*failure_tasks, return_exceptions=True)
assert [str(error) for error in failure_results] == [str(LimitedSemaphoreFullError())] * beyond_limit
17 changes: 9 additions & 8 deletions chia/_tests/util/test_priority_mutex.py
@@ -16,6 +16,7 @@
from chia._tests.util.misc import Marks, datacases
from chia._tests.util.time_out_assert import time_out_assert_custom_interval
from chia.util.priority_mutex import NestedLockUnsupportedError, PriorityMutex
from chia.util.task_referencer import create_referenced_task
from chia.util.timing import adjusted_timeout

log = logging.getLogger(__name__)
@@ -65,10 +66,10 @@ async def do_low(i: int) -> None:
log.warning(f"Spend {time.time() - t1} waiting for low {i}")
await kind_of_slow_func()

h = asyncio.create_task(do_high())
h = create_referenced_task(do_high())
l_tasks = []
for i in range(50):
l_tasks.append(asyncio.create_task(do_low(i)))
l_tasks.append(create_referenced_task(do_low(i)))

winner = None

@@ -334,13 +335,13 @@ async def queued_after() -> None:
async with mutex.acquire(priority=MutexPriority.high):
pass

block_task = asyncio.create_task(block())
block_task = create_referenced_task(block())
await blocker_acquired_event.wait()

cancel_task = asyncio.create_task(to_be_cancelled(mutex=mutex))
cancel_task = create_referenced_task(to_be_cancelled(mutex=mutex))
await wait_queued(mutex=mutex, task=cancel_task)

queued_after_task = asyncio.create_task(queued_after())
queued_after_task = create_referenced_task(queued_after())
await wait_queued(mutex=mutex, task=queued_after_task)

cancel_task.cancel()
@@ -441,7 +442,7 @@ async def create_acquire_tasks_in_controlled_order(
release_event = asyncio.Event()

for request in requests:
task = asyncio.create_task(request.acquire(mutex=mutex, wait_for=release_event))
task = create_referenced_task(request.acquire(mutex=mutex, wait_for=release_event))
tasks.append(task)
await wait_queued(mutex=mutex, task=task)

@@ -461,14 +462,14 @@ async def other_task_function() -> None:
await other_task_allow_release_event.wait()

async with mutex.acquire(priority=MutexPriority.high):
other_task = asyncio.create_task(other_task_function())
other_task = create_referenced_task(other_task_function())
await wait_queued(mutex=mutex, task=other_task)

async def another_task_function() -> None:
async with mutex.acquire(priority=MutexPriority.high):
pass

another_task = asyncio.create_task(another_task_function())
another_task = create_referenced_task(another_task_function())
await wait_queued(mutex=mutex, task=another_task)
other_task_allow_release_event.set()

7 changes: 3 additions & 4 deletions chia/daemon/client.py
@@ -12,6 +12,7 @@

from chia.util.ints import uint32
from chia.util.json_util import dict_to_json_str
from chia.util.task_referencer import create_referenced_task
from chia.util.ws_message import WsRpcMessage, create_payload_dict


@@ -67,8 +68,7 @@ async def listener_task() -> None:
finally:
await self.close()

# TODO: stop dropping tasks on the floor
asyncio.create_task(listener_task()) # noqa: RUF006
create_referenced_task(listener_task(), known_unreferenced=True)
await asyncio.sleep(1)

async def listener(self) -> None:
@@ -92,8 +92,7 @@ async def _get(self, request: WsRpcMessage) -> WsRpcMessage:
string = dict_to_json_str(request)
if self.websocket is None or self.websocket.closed:
raise Exception("Websocket is not connected")
# TODO: stop dropping tasks on the floor
asyncio.create_task(self.websocket.send_str(string)) # noqa: RUF006
create_referenced_task(self.websocket.send_str(string), known_unreferenced=True)
try:
await asyncio.wait_for(self._request_dict[request_id].wait(), timeout=30)
self._request_dict.pop(request_id)