From f2d4b1c0e42d9b6046922a4655b2942d9f6db4f8 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 14 Oct 2024 11:48:56 -0400 Subject: [PATCH 1/9] PYTHON-4860 - Async client should use asyncio.Lock and asyncio.Condition --- pymongo/asynchronous/cursor.py | 4 +- pymongo/asynchronous/mongo_client.py | 12 +- pymongo/asynchronous/pool.py | 24 +- pymongo/asynchronous/topology.py | 15 +- pymongo/lock.py | 223 +--- pymongo/synchronous/cursor.py | 2 +- pymongo/synchronous/mongo_client.py | 10 +- pymongo/synchronous/pool.py | 24 +- pymongo/synchronous/topology.py | 15 +- test/asynchronous/test_locks.py | 1026 ++++++++--------- test/conftest.py | 3 +- test/pymongo_mocks.py | 1 + test/test_auth.py | 2 +- test/test_auth_spec.py | 2 +- test/test_bulk.py | 2 +- test/test_change_stream.py | 2 +- test/test_client.py | 4 +- test/test_client_bulk_write.py | 2 +- test/test_client_context.py | 2 +- test/test_collation.py | 2 +- test/test_collection.py | 4 +- test/test_common.py | 2 +- ...nnections_survive_primary_stepdown_spec.py | 4 +- test/test_cursor.py | 2 +- test/test_database.py | 3 +- test/test_encryption.py | 8 +- test/test_grid_file.py | 2 +- test/test_logger.py | 3 +- test/test_monitoring.py | 2 +- test/test_raw_bson.py | 2 +- test/test_retryable_reads.py | 2 +- test/test_retryable_writes.py | 4 +- test/test_session.py | 2 +- test/test_transactions.py | 4 +- test/utils_spec_runner.py | 2 +- tools/synchro.py | 46 +- 36 files changed, 635 insertions(+), 834 deletions(-) diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 4b4bb52a8e..7d7ae4a5db 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -45,7 +45,7 @@ ) from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.lock import _ALock, _create_lock +from pymongo.lock import _async_create_lock from pymongo.message import ( _CursorAddress, _GetMore, @@ -77,7 +77,7 @@ class _ConnectionManager: def __init__(self, conn: AsyncConnection, more_to_come: bool): self.conn: Optional[AsyncConnection] = conn self.more_to_come = more_to_come - self._alock = _ALock(_create_lock()) + self._lock = _async_create_lock() def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index d2b45fd64a..b788043ea4 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -82,7 +82,11 @@ WaitQueueTimeoutError, WriteConcernError, ) -from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks +from pymongo.lock import ( + _HAS_REGISTER_AT_FORK, + _async_create_lock, + _release_locks, +) from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason @@ -842,7 +846,7 @@ def __init__( self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase - self._lock = _ALock(_create_lock()) + self._lock = _async_create_lock() self._kill_cursors_queue: list = [] self._event_listeners = options.pool_options._event_listeners @@ -1728,7 +1732,7 @@ async def _run_operation( address=address, ) - async with operation.conn_mgr._alock: + async with operation.conn_mgr._lock: async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: 
ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return await server.run_operation( @@ -1976,7 +1980,7 @@ async def _close_cursor_now( try: if conn_mgr: - async with conn_mgr._alock: + async with conn_mgr._lock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a9f02d650a..532c20aa63 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -65,7 +65,11 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.lock import ( + _async_cond_wait, + _async_create_condition, + _async_create_lock, +) from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -208,11 +212,6 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool: - timeout = deadline - time.monotonic() if deadline else None - return await condition.wait(timeout) - - def _get_timeout_details(options: PoolOptions) -> dict[str, float]: details = {} timeout = _csot.get_timeout() @@ -992,8 +991,9 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - _lock = _create_lock() - self.lock = _ALock(_lock) + self.lock = _async_create_lock() + self.size_cond = _async_create_condition(self.lock, threading.Condition) + self._max_connecting_cond = _async_create_condition(self.lock, threading.Condition) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1019,7 +1019,6 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _ACondition(threading.Condition(_lock)) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1027,7 +1026,6 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _ACondition(threading.Condition(_lock)) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id @@ -1456,7 +1454,8 @@ async def _get_conn( async with self.size_cond: self._raise_if_not_ready(checkout_started_time, emit_event=True) while not (self.requests < self.max_pool_size): - if not await _cond_wait(self.size_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not await _async_cond_wait(self.size_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.requests < self.max_pool_size: @@ -1479,7 +1478,8 @@ async def _get_conn( async with self._max_connecting_cond: self._raise_if_not_ready(checkout_started_time, emit_event=False) while not (self.conns or self._pending < self._max_connecting): - if not await _cond_wait(self._max_connecting_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not await _async_cond_wait(self._max_connecting_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. 
if self.conns or self._pending < self._max_connecting: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index f0cb56cbf1..460fa15837 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -43,7 +43,11 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.lock import ( + _async_cond_wait, + _async_create_condition, + _async_create_lock, +) from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -169,9 +173,8 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - _lock = _create_lock() - self._lock = _ALock(_lock) - self._condition = _ACondition(self._settings.condition_class(_lock)) + self._lock = _async_create_lock() + self._condition = _async_create_condition(self._lock, self._settings.condition_class) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None @@ -353,7 +356,7 @@ async def _select_servers_loop( # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. - await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = time.monotonic() server_descriptions = self._description.apply_selector( @@ -653,7 +656,7 @@ async def request_check_all(self, wait_time: int = 5) -> None: """Wake all monitors, wait for at least one to check its server.""" async with self._lock: self._request_check_all() - await self._condition.wait(wait_time) + await _async_cond_wait(self._condition, wait_time) def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. diff --git a/pymongo/lock.py b/pymongo/lock.py index 0cbfb4a57e..52a400893e 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -14,12 +14,10 @@ from __future__ import annotations import asyncio -import collections import os import threading -import time import weakref -from typing import Any, Callable, Optional, TypeVar +from typing import Any, Optional, TypeVar _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") @@ -39,6 +37,21 @@ def _create_lock() -> threading.Lock: return lock +def _async_create_lock() -> asyncio.Lock: + """Represents an asyncio.Lock.""" + return asyncio.Lock() + + +def _create_condition(lock: threading.Lock, dummy: Any) -> threading.Condition: + """Represents a threading.Condition.""" + return threading.Condition(lock) + + +def _async_create_condition(lock: asyncio.Lock, dummy: Any) -> asyncio.Condition: + """Represents an asyncio.Condition.""" + return asyncio.Condition(lock) + + def _release_locks() -> None: # Completed the fork, reset all the locks in the child. for lock in _forkable_locks: @@ -46,202 +59,12 @@ def _release_locks() -> None: lock.release() -# Needed only for synchro.py compat. 
-def _Lock(lock: threading.Lock) -> threading.Lock: - return lock +async def _async_cond_wait(condition: asyncio.Condition, timeout: Optional[float]) -> bool: + try: + return await asyncio.wait_for(condition.wait(), timeout) + except TimeoutError: + return False -class _ALock: - __slots__ = ("_lock",) - - def __init__(self, lock: threading.Lock) -> None: - self._lock = lock - - def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - return self._lock.acquire(blocking=blocking, timeout=timeout) - - async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if timeout > 0: - tstart = time.monotonic() - while True: - acquired = self._lock.acquire(blocking=False) - if acquired: - return True - if timeout > 0 and (time.monotonic() - tstart) > timeout: - return False - if not blocking: - return False - await asyncio.sleep(0) - - def release(self) -> None: - self._lock.release() - - async def __aenter__(self) -> _ALock: - await self.a_acquire() - return self - - def __enter__(self) -> _ALock: - self._lock.acquire() - return self - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - -def _safe_set_result(fut: asyncio.Future) -> None: - # Ensure the future hasn't been cancelled before calling set_result. - if not fut.done(): - fut.set_result(False) - - -class _ACondition: - __slots__ = ("_condition", "_waiters") - - def __init__(self, condition: threading.Condition) -> None: - self._condition = condition - self._waiters: collections.deque = collections.deque() - - async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if timeout > 0: - tstart = time.monotonic() - while True: - acquired = self._condition.acquire(blocking=False) - if acquired: - return True - if timeout > 0 and (time.monotonic() - tstart) > timeout: - return False - if not blocking: - return False - await asyncio.sleep(0) - - async def wait(self, timeout: Optional[float] = None) -> bool: - """Wait until notified. - - If the calling task has not acquired the lock when this - method is called, a RuntimeError is raised. - - This method releases the underlying lock, and then blocks - until it is awakened by a notify() or notify_all() call for - the same condition variable in another task. Once - awakened, it re-acquires the lock and returns True. - - This method may return spuriously, - which is why the caller should always - re-check the state and be prepared to wait() again. - """ - loop = asyncio.get_running_loop() - fut = loop.create_future() - self._waiters.append((loop, fut)) - self.release() - try: - try: - try: - await asyncio.wait_for(fut, timeout) - return True - except asyncio.TimeoutError: - return False # Return false on timeout for sync pool compat. - finally: - # Must re-acquire lock even if wait is cancelled. - # We only catch CancelledError here, since we don't want any - # other (fatal) errors with the future to cause us to spin. - err = None - while True: - try: - await self.acquire() - break - except asyncio.exceptions.CancelledError as e: - err = e - - self._waiters.remove((loop, fut)) - if err is not None: - try: - raise err # Re-raise most recent exception instance. - finally: - err = None # Break reference cycles. - except BaseException: - # Any error raised out of here _may_ have occurred after this Task - # believed to have been successfully notified. - # Make sure to notify another Task instead. 
This may result - # in a "spurious wakeup", which is allowed as part of the - # Condition Variable protocol. - self.notify(1) - raise - - async def wait_for(self, predicate: Callable[[], _T]) -> _T: - """Wait until a predicate becomes true. - - The predicate should be a callable whose result will be - interpreted as a boolean value. The method will repeatedly - wait() until it evaluates to true. The final predicate value is - the return value. - """ - result = predicate() - while not result: - await self.wait() - result = predicate() - return result - - def notify(self, n: int = 1) -> None: - """By default, wake up one coroutine waiting on this condition, if any. - If the calling coroutine has not acquired the lock when this method - is called, a RuntimeError is raised. - - This method wakes up at most n of the coroutines waiting for the - condition variable; it is a no-op if no coroutines are waiting. - - Note: an awakened coroutine does not actually return from its - wait() call until it can reacquire the lock. Since notify() does - not release the lock, its caller should. - """ - idx = 0 - to_remove = [] - for loop, fut in self._waiters: - if idx >= n: - break - - if fut.done(): - continue - - try: - loop.call_soon_threadsafe(_safe_set_result, fut) - except RuntimeError: - # Loop was closed, ignore. - to_remove.append((loop, fut)) - continue - - idx += 1 - - for waiter in to_remove: - self._waiters.remove(waiter) - - def notify_all(self) -> None: - """Wake up all threads waiting on this condition. This method acts - like notify(), but wakes up all waiting threads instead of one. If the - calling thread has not acquired the lock when this method is called, - a RuntimeError is raised. - """ - self.notify(len(self._waiters)) - - def locked(self) -> bool: - """Only needed for tests in test_locks.""" - return self._condition._lock.locked() # type: ignore[attr-defined] - - def release(self) -> None: - self._condition.release() - - async def __aenter__(self) -> _ACondition: - await self.acquire() - return self - - def __enter__(self) -> _ACondition: - self._condition.acquire() - return self - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() +def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool: + return condition.wait(timeout) diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 27a76cf91d..9a7637704f 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -77,7 +77,7 @@ class _ConnectionManager: def __init__(self, conn: Connection, more_to_come: bool): self.conn: Optional[Connection] = conn self.more_to_come = more_to_come - self._alock = _create_lock() + self._lock = _create_lock() def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 8f4d9cacf2..840df452a9 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -74,7 +74,11 @@ WaitQueueTimeoutError, WriteConcernError, ) -from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks +from pymongo.lock import ( + _HAS_REGISTER_AT_FORK, + _create_lock, + _release_locks, +) from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason @@ -1722,7 +1726,7 @@ def 
_run_operation( address=address, ) - with operation.conn_mgr._alock: + with operation.conn_mgr._lock: with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return server.run_operation( @@ -1970,7 +1974,7 @@ def _close_cursor_now( try: if conn_mgr: - with conn_mgr._alock: + with conn_mgr._lock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index eb007a3471..8172a9f846 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -62,7 +62,11 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _create_lock, _Lock +from pymongo.lock import ( + _cond_wait, + _create_condition, + _create_lock, +) from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -208,11 +212,6 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool: - timeout = deadline - time.monotonic() if deadline else None - return condition.wait(timeout) - - def _get_timeout_details(options: PoolOptions) -> dict[str, float]: details = {} timeout = _csot.get_timeout() @@ -988,8 +987,9 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - _lock = _create_lock() - self.lock = _Lock(_lock) + self.lock = _create_lock() + self.size_cond = _create_condition(self.lock, threading.Condition) + self._max_connecting_cond = _create_condition(self.lock, threading.Condition) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1015,7 +1015,6 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = threading.Condition(_lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1023,7 +1022,6 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = threading.Condition(_lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id @@ -1450,7 +1448,8 @@ def _get_conn( with self.size_cond: self._raise_if_not_ready(checkout_started_time, emit_event=True) while not (self.requests < self.max_pool_size): - if not _cond_wait(self.size_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not _cond_wait(self.size_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.requests < self.max_pool_size: @@ -1473,7 +1472,8 @@ def _get_conn( with self._max_connecting_cond: self._raise_if_not_ready(checkout_started_time, emit_event=False) while not (self.conns or self._pending < self._max_connecting): - if not _cond_wait(self._max_connecting_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not _cond_wait(self._max_connecting_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. 
if self.conns or self._pending < self._max_connecting: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index e34de6bc50..48c8b81590 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -39,7 +39,11 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _create_lock, _Lock +from pymongo.lock import ( + _cond_wait, + _create_condition, + _create_lock, +) from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -169,9 +173,8 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - _lock = _create_lock() - self._lock = _Lock(_lock) - self._condition = self._settings.condition_class(_lock) + self._lock = _create_lock() + self._condition = _create_condition(self._lock, self._settings.condition_class) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None @@ -353,7 +356,7 @@ def _select_servers_loop( # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. - self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + _cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = time.monotonic() server_descriptions = self._description.apply_selector( @@ -651,7 +654,7 @@ def request_check_all(self, wait_time: int = 5) -> None: """Wake all monitors, wait for at least one to check its server.""" with self._lock: self._request_check_all() - self._condition.wait(wait_time) + _cond_wait(self._condition, wait_time) def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py index e0e7f2fc8d..a029f0792b 100644 --- a/test/asynchronous/test_locks.py +++ b/test/asynchronous/test_locks.py @@ -1,513 +1,513 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Tests for lock.py""" -from __future__ import annotations - -import asyncio -import sys -import threading -import unittest - -sys.path[0:0] = [""] - -from pymongo.lock import _ACondition - - -# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py -# Includes tests for: -# - https://github.com/python/cpython/issues/111693 -# - https://github.com/python/cpython/issues/112202 -class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): - async def test_wait(self): - cond = _ACondition(threading.Condition(threading.Lock())) - result = [] - - async def c1(result): - await cond.acquire() - if await cond.wait(): - result.append(1) - return True - - async def c2(result): - await cond.acquire() - if await cond.wait(): - result.append(2) - return True - - async def c3(result): - await cond.acquire() - if await cond.wait(): - result.append(3) - return True - - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertFalse(cond.locked()) - - self.assertTrue(await cond.acquire()) - cond.notify() - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.notify(2) - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2, 3], result) - self.assertTrue(cond.locked()) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - - async def test_wait_cancel(self): - cond = _ACondition(threading.Condition(threading.Lock())) - await cond.acquire() - - wait = asyncio.create_task(cond.wait()) - asyncio.get_running_loop().call_soon(wait.cancel) - with self.assertRaises(asyncio.CancelledError): - await wait - self.assertFalse(cond._waiters) - self.assertTrue(cond.locked()) - - async def test_wait_cancel_contested(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - await cond.acquire() - self.assertTrue(cond.locked()) - - wait_task = asyncio.create_task(cond.wait()) - await asyncio.sleep(0) - self.assertFalse(cond.locked()) - - # Notify, but contest the lock before cancelling - await cond.acquire() - self.assertTrue(cond.locked()) - cond.notify() - asyncio.get_running_loop().call_soon(wait_task.cancel) - asyncio.get_running_loop().call_soon(cond.release) - - try: - await wait_task - except asyncio.CancelledError: - # Should not happen, since no cancellation points - pass - - self.assertTrue(cond.locked()) - - async def test_wait_cancel_after_notify(self): - # See bpo-32841 - waited = False - - cond = _ACondition(threading.Condition(threading.Lock())) - - async def wait_on_cond(): - nonlocal waited - async with cond: - waited = True # Make sure this area was reached - await cond.wait() - - waiter = asyncio.create_task(wait_on_cond()) - await asyncio.sleep(0) # Start waiting - - await cond.acquire() - cond.notify() - await asyncio.sleep(0) # Get to acquire() - waiter.cancel() - await asyncio.sleep(0) # Activate cancellation - cond.release() - await asyncio.sleep(0) # Cancellation should occur - - self.assertTrue(waiter.cancelled()) - 
self.assertTrue(waited) - - async def test_wait_unacquired(self): - cond = _ACondition(threading.Condition(threading.Lock())) - with self.assertRaises(RuntimeError): - await cond.wait() - - async def test_wait_for(self): - cond = _ACondition(threading.Condition(threading.Lock())) - presult = False - - def predicate(): - return presult - - result = [] - - async def c1(result): - await cond.acquire() - if await cond.wait_for(predicate): - result.append(1) - cond.release() - return True - - t = asyncio.create_task(c1(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([], result) - - presult = True - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) - - self.assertTrue(t.done()) - self.assertTrue(t.result()) - - async def test_wait_for_unacquired(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - # predicate can return true immediately - res = await cond.wait_for(lambda: [1, 2, 3]) - self.assertEqual([1, 2, 3], res) - - with self.assertRaises(RuntimeError): - await cond.wait_for(lambda: False) - - async def test_notify(self): - cond = _ACondition(threading.Condition(threading.Lock())) - result = [] - - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True - - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True - - async def c3(result): - async with cond: - if await cond.wait(): - result.append(3) - return True - - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - - async with cond: - cond.notify(1) - await asyncio.sleep(1) - self.assertEqual([1], result) - - async with cond: - cond.notify(1) - cond.notify(2048) - await asyncio.sleep(1) - self.assertEqual([1, 2, 3], result) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - - async def test_notify_all(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - result = [] - - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True - - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True - - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - - async with cond: - cond.notify_all() - await asyncio.sleep(1) - self.assertEqual([1, 2], result) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - - async def test_context_manager(self): - cond = _ACondition(threading.Condition(threading.Lock())) - self.assertFalse(cond.locked()) - async with cond: - self.assertTrue(cond.locked()) - self.assertFalse(cond.locked()) - - async def test_timeout_in_block(self): - condition = _ACondition(threading.Condition(threading.Lock())) - async with condition: - with self.assertRaises(asyncio.TimeoutError): - await asyncio.wait_for(condition.wait(), timeout=0.5) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_wakeup(self): - # Test that a cancelled error, received when awaiting wakeup, - 
# will be re-raised un-modified. - wake = False - raised = None - cond = _ACondition(threading.Condition(threading.Lock())) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_re_aquire(self): - # Test that a cancelled error, received when re-aquiring lock, - # will be re-raised un-modified. - wake = False - raised = None - cond = _ACondition(threading.Condition(threading.Lock())) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition - await cond.acquire() - wake = True - cond.notify() - await asyncio.sleep(0) - # Task is now trying to re-acquire the lock, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - cond.release() - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is awaiting initial - # wakeup on the wakeup queue. - condition = _ACondition(threading.Condition(threading.Lock())) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # Cancel it while it is awaiting to be run. - # This cancellation could come from the outside - c[0].cancel() - - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) - - # clean up - state = -1 - condition.notify_all() - await c[1] - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup_relock(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is acquiring the lock - # again. 
- condition = _ACondition(threading.Condition(threading.Lock())) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # now we sleep for a bit. This allows the target task to wake up and - # settle on re-aquiring the lock - await asyncio.sleep(0) - - # Cancel it while awaiting the lock - # This cancel could come the outside. - c[0].cancel() - - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) - - # clean up - state = -1 - condition.notify_all() - await c[1] - - -class TestCondition(unittest.IsolatedAsyncioTestCase): - async def test_multiple_loops_notify(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - def tmain(cond): - async def atmain(cond): - await asyncio.sleep(1) - async with cond: - cond.notify(1) - - asyncio.run(atmain(cond)) - - t = threading.Thread(target=tmain, args=(cond,)) - t.start() - - async with cond: - self.assertTrue(await cond.wait(30)) - t.join() - - async def test_multiple_loops_notify_all(self): - cond = _ACondition(threading.Condition(threading.Lock())) - results = [] - - def tmain(cond, results): - async def atmain(cond, results): - await asyncio.sleep(1) - async with cond: - res = await cond.wait(30) - results.append(res) - - asyncio.run(atmain(cond, results)) - - nthreads = 5 - threads = [] - for _ in range(nthreads): - threads.append(threading.Thread(target=tmain, args=(cond, results))) - for t in threads: - t.start() - - await asyncio.sleep(2) - async with cond: - cond.notify_all() - - for t in threads: - t.join() - - self.assertEqual(results, [True] * nthreads) - - -if __name__ == "__main__": - unittest.main() +# # Copyright 2024-present MongoDB, Inc. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. 
+# """Tests for lock.py""" +# from __future__ import annotations +# +# import asyncio +# import sys +# import threading +# import unittest +# +# sys.path[0:0] = [""] +# +# from pymongo.lock import _async_create_lock +# +# +# # Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py +# # Includes tests for: +# # - https://github.com/python/cpython/issues/111693 +# # - https://github.com/python/cpython/issues/112202 +# class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): +# async def test_wait(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# result = [] +# +# async def c1(result): +# await cond.acquire() +# if await cond.wait(): +# result.append(1) +# return True +# +# async def c2(result): +# await cond.acquire() +# if await cond.wait(): +# result.append(2) +# return True +# +# async def c3(result): +# await cond.acquire() +# if await cond.wait(): +# result.append(3) +# return True +# +# t1 = asyncio.create_task(c1(result)) +# t2 = asyncio.create_task(c2(result)) +# t3 = asyncio.create_task(c3(result)) +# +# await asyncio.sleep(0) +# self.assertEqual([], result) +# self.assertFalse(cond.locked()) +# +# self.assertTrue(await cond.acquire()) +# cond.notify() +# await asyncio.sleep(0) +# self.assertEqual([], result) +# self.assertTrue(cond.locked()) +# +# cond.release() +# await asyncio.sleep(0) +# self.assertEqual([1], result) +# self.assertTrue(cond.locked()) +# +# cond.notify(2) +# await asyncio.sleep(0) +# self.assertEqual([1], result) +# self.assertTrue(cond.locked()) +# +# cond.release() +# await asyncio.sleep(0) +# self.assertEqual([1, 2], result) +# self.assertTrue(cond.locked()) +# +# cond.release() +# await asyncio.sleep(0) +# self.assertEqual([1, 2, 3], result) +# self.assertTrue(cond.locked()) +# +# self.assertTrue(t1.done()) +# self.assertTrue(t1.result()) +# self.assertTrue(t2.done()) +# self.assertTrue(t2.result()) +# self.assertTrue(t3.done()) +# self.assertTrue(t3.result()) +# +# async def test_wait_cancel(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# await cond.acquire() +# +# wait = asyncio.create_task(cond.wait()) +# asyncio.get_running_loop().call_soon(wait.cancel) +# with self.assertRaises(asyncio.CancelledError): +# await wait +# self.assertFalse(cond._waiters) +# self.assertTrue(cond.locked()) +# +# async def test_wait_cancel_contested(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# +# await cond.acquire() +# self.assertTrue(cond.locked()) +# +# wait_task = asyncio.create_task(cond.wait()) +# await asyncio.sleep(0) +# self.assertFalse(cond.locked()) +# +# # Notify, but contest the lock before cancelling +# await cond.acquire() +# self.assertTrue(cond.locked()) +# cond.notify() +# asyncio.get_running_loop().call_soon(wait_task.cancel) +# asyncio.get_running_loop().call_soon(cond.release) +# +# try: +# await wait_task +# except asyncio.CancelledError: +# # Should not happen, since no cancellation points +# pass +# +# self.assertTrue(cond.locked()) +# +# async def test_wait_cancel_after_notify(self): +# # See bpo-32841 +# waited = False +# +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# +# async def wait_on_cond(): +# nonlocal waited +# async with cond: +# waited = True # Make sure this area was reached +# await cond.wait() +# +# waiter = asyncio.create_task(wait_on_cond()) +# await asyncio.sleep(0) # Start waiting +# +# await cond.acquire() +# cond.notify() +# await asyncio.sleep(0) # Get to acquire() 
+# waiter.cancel() +# await asyncio.sleep(0) # Activate cancellation +# cond.release() +# await asyncio.sleep(0) # Cancellation should occur +# +# self.assertTrue(waiter.cancelled()) +# self.assertTrue(waited) +# +# async def test_wait_unacquired(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# with self.assertRaises(RuntimeError): +# await cond.wait() +# +# async def test_wait_for(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# presult = False +# +# def predicate(): +# return presult +# +# result = [] +# +# async def c1(result): +# await cond.acquire() +# if await cond.wait_for(predicate): +# result.append(1) +# cond.release() +# return True +# +# t = asyncio.create_task(c1(result)) +# +# await asyncio.sleep(0) +# self.assertEqual([], result) +# +# await cond.acquire() +# cond.notify() +# cond.release() +# await asyncio.sleep(0) +# self.assertEqual([], result) +# +# presult = True +# await cond.acquire() +# cond.notify() +# cond.release() +# await asyncio.sleep(0) +# self.assertEqual([1], result) +# +# self.assertTrue(t.done()) +# self.assertTrue(t.result()) +# +# async def test_wait_for_unacquired(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# +# # predicate can return true immediately +# res = await cond.wait_for(lambda: [1, 2, 3]) +# self.assertEqual([1, 2, 3], res) +# +# with self.assertRaises(RuntimeError): +# await cond.wait_for(lambda: False) +# +# async def test_notify(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# result = [] +# +# async def c1(result): +# async with cond: +# if await cond.wait(): +# result.append(1) +# return True +# +# async def c2(result): +# async with cond: +# if await cond.wait(): +# result.append(2) +# return True +# +# async def c3(result): +# async with cond: +# if await cond.wait(): +# result.append(3) +# return True +# +# t1 = asyncio.create_task(c1(result)) +# t2 = asyncio.create_task(c2(result)) +# t3 = asyncio.create_task(c3(result)) +# +# await asyncio.sleep(0) +# self.assertEqual([], result) +# +# async with cond: +# cond.notify(1) +# await asyncio.sleep(1) +# self.assertEqual([1], result) +# +# async with cond: +# cond.notify(1) +# cond.notify(2048) +# await asyncio.sleep(1) +# self.assertEqual([1, 2, 3], result) +# +# self.assertTrue(t1.done()) +# self.assertTrue(t1.result()) +# self.assertTrue(t2.done()) +# self.assertTrue(t2.result()) +# self.assertTrue(t3.done()) +# self.assertTrue(t3.result()) +# +# async def test_notify_all(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# +# result = [] +# +# async def c1(result): +# async with cond: +# if await cond.wait(): +# result.append(1) +# return True +# +# async def c2(result): +# async with cond: +# if await cond.wait(): +# result.append(2) +# return True +# +# t1 = asyncio.create_task(c1(result)) +# t2 = asyncio.create_task(c2(result)) +# +# await asyncio.sleep(0) +# self.assertEqual([], result) +# +# async with cond: +# cond.notify_all() +# await asyncio.sleep(1) +# self.assertEqual([1, 2], result) +# +# self.assertTrue(t1.done()) +# self.assertTrue(t1.result()) +# self.assertTrue(t2.done()) +# self.assertTrue(t2.result()) +# +# async def test_context_manager(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# self.assertFalse(cond.locked()) +# async with cond: +# self.assertTrue(cond.locked()) +# self.assertFalse(cond.locked()) +# +# async def test_timeout_in_block(self): +# condition = 
_async_create_lock(threading.Condition(threading.Lock())) +# async with condition: +# with self.assertRaises(asyncio.TimeoutError): +# await asyncio.wait_for(condition.wait(), timeout=0.5) +# +# @unittest.skipIf( +# sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" +# ) +# async def test_cancelled_error_wakeup(self): +# # Test that a cancelled error, received when awaiting wakeup, +# # will be re-raised un-modified. +# wake = False +# raised = None +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# +# async def func(): +# nonlocal raised +# async with cond: +# with self.assertRaises(asyncio.CancelledError) as err: +# await cond.wait_for(lambda: wake) +# raised = err.exception +# raise raised +# +# task = asyncio.create_task(func()) +# await asyncio.sleep(0) +# # Task is waiting on the condition, cancel it there. +# task.cancel(msg="foo") # type: ignore[call-arg] +# with self.assertRaises(asyncio.CancelledError) as err: +# await task +# self.assertEqual(err.exception.args, ("foo",)) +# # We should have got the _same_ exception instance as the one +# # originally raised. +# self.assertIs(err.exception, raised) +# +# @unittest.skipIf( +# sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" +# ) +# async def test_cancelled_error_re_aquire(self): +# # Test that a cancelled error, received when re-aquiring lock, +# # will be re-raised un-modified. +# wake = False +# raised = None +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# +# async def func(): +# nonlocal raised +# async with cond: +# with self.assertRaises(asyncio.CancelledError) as err: +# await cond.wait_for(lambda: wake) +# raised = err.exception +# raise raised +# +# task = asyncio.create_task(func()) +# await asyncio.sleep(0) +# # Task is waiting on the condition +# await cond.acquire() +# wake = True +# cond.notify() +# await asyncio.sleep(0) +# # Task is now trying to re-acquire the lock, cancel it there. +# task.cancel(msg="foo") # type: ignore[call-arg] +# cond.release() +# with self.assertRaises(asyncio.CancelledError) as err: +# await task +# self.assertEqual(err.exception.args, ("foo",)) +# # We should have got the _same_ exception instance as the one +# # originally raised. +# self.assertIs(err.exception, raised) +# +# @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") +# async def test_cancelled_wakeup(self): +# # Test that a task cancelled at the "same" time as it is woken +# # up as part of a Condition.notify() does not result in a lost wakeup. +# # This test simulates a cancel while the target task is awaiting initial +# # wakeup on the wakeup queue. +# condition = _async_create_lock(threading.Condition(threading.Lock())) +# state = 0 +# +# async def consumer(): +# nonlocal state +# async with condition: +# while True: +# await condition.wait_for(lambda: state != 0) +# if state < 0: +# return +# state -= 1 +# +# # create two consumers +# c = [asyncio.create_task(consumer()) for _ in range(2)] +# # wait for them to settle +# await asyncio.sleep(0.1) +# async with condition: +# # produce one item and wake up one +# state += 1 +# condition.notify(1) +# +# # Cancel it while it is awaiting to be run. +# # This cancellation could come from the outside +# c[0].cancel() +# +# # now wait for the item to be consumed +# # if it doesn't means that our "notify" didn"t take hold. 
+# # because it raced with a cancel() +# try: +# async with asyncio.timeout(1): +# await condition.wait_for(lambda: state == 0) +# except TimeoutError: +# pass +# self.assertEqual(state, 0) +# +# # clean up +# state = -1 +# condition.notify_all() +# await c[1] +# +# @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") +# async def test_cancelled_wakeup_relock(self): +# # Test that a task cancelled at the "same" time as it is woken +# # up as part of a Condition.notify() does not result in a lost wakeup. +# # This test simulates a cancel while the target task is acquiring the lock +# # again. +# condition = _async_create_lock(threading.Condition(threading.Lock())) +# state = 0 +# +# async def consumer(): +# nonlocal state +# async with condition: +# while True: +# await condition.wait_for(lambda: state != 0) +# if state < 0: +# return +# state -= 1 +# +# # create two consumers +# c = [asyncio.create_task(consumer()) for _ in range(2)] +# # wait for them to settle +# await asyncio.sleep(0.1) +# async with condition: +# # produce one item and wake up one +# state += 1 +# condition.notify(1) +# +# # now we sleep for a bit. This allows the target task to wake up and +# # settle on re-aquiring the lock +# await asyncio.sleep(0) +# +# # Cancel it while awaiting the lock +# # This cancel could come the outside. +# c[0].cancel() +# +# # now wait for the item to be consumed +# # if it doesn't means that our "notify" didn"t take hold. +# # because it raced with a cancel() +# try: +# async with asyncio.timeout(1): +# await condition.wait_for(lambda: state == 0) +# except TimeoutError: +# pass +# self.assertEqual(state, 0) +# +# # clean up +# state = -1 +# condition.notify_all() +# await c[1] +# +# +# class TestCondition(unittest.IsolatedAsyncioTestCase): +# async def test_multiple_loops_notify(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# +# def tmain(cond): +# async def atmain(cond): +# await asyncio.sleep(1) +# async with cond: +# cond.notify(1) +# +# asyncio.run(atmain(cond)) +# +# t = threading.Thread(target=tmain, args=(cond,)) +# t.start() +# +# async with cond: +# self.assertTrue(await cond.wait(30)) +# t.join() +# +# async def test_multiple_loops_notify_all(self): +# cond = _async_create_lock(threading.Condition(threading.Lock())) +# results = [] +# +# def tmain(cond, results): +# async def atmain(cond, results): +# await asyncio.sleep(1) +# async with cond: +# res = await cond.wait(30) +# results.append(res) +# +# asyncio.run(atmain(cond, results)) +# +# nthreads = 5 +# threads = [] +# for _ in range(nthreads): +# threads.append(threading.Thread(target=tmain, args=(cond, results))) +# for t in threads: +# t.start() +# +# await asyncio.sleep(2) +# async with cond: +# cond.notify_all() +# +# for t in threads: +# t.join() +# +# self.assertEqual(results, [True] * nthreads) +# +# +# if __name__ == "__main__": +# unittest.main() diff --git a/test/conftest.py b/test/conftest.py index 91fad28d0a..013e9c4e24 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,7 +2,8 @@ import asyncio import sys -from test import pytest_conf, setup, teardown +from test import pytest_conf +from test.synchronous import setup, teardown import pytest diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 7662dc9682..51c90a3884 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -19,6 +19,7 @@ import weakref from functools import partial from test import client_context +from test.synchronous import client_context from pymongo 
import MongoClient, common from pymongo.errors import AutoReconnect, NetworkTimeout diff --git a/test/test_auth.py b/test/test_auth.py index b311d330bc..355d7d1833 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( IntegrationTest, PyMongoTestCase, SkipTest, diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 3c3a1a67ae..fa14b82289 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -20,7 +20,7 @@ import os import sys import warnings -from test import PyMongoTestCase +from test.synchronous import PyMongoTestCase sys.path[0:0] = [""] diff --git a/test/test_bulk.py b/test/test_bulk.py index 64fd48e8cd..76c20d692a 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, remove_all_users, unittest +from test.synchronous import IntegrationTest, client_context, remove_all_users, unittest from test.utils import wait_until from bson.binary import Binary, UuidRepresentation diff --git a/test/test_change_stream.py b/test/test_change_stream.py index dae224c5e0..63c76b6f33 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( IntegrationTest, PyMongoTestCase, Version, diff --git a/test/test_client.py b/test/test_client.py index a4c521157b..1dcf83abba 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -44,7 +44,7 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( HAVE_IPADDRESS, IntegrationTest, MockClientTest, @@ -58,7 +58,7 @@ remove_all_users, unittest, ) -from test.pymongo_mocks import MockClient +from test.synchronous.pymongo_mocks import MockClient from test.test_binary import BinaryData from test.utils import ( NTHREADS, diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index 58b5015dd2..908cf95337 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( IntegrationTest, client_context, unittest, diff --git a/test/test_client_context.py b/test/test_client_context.py index 5996f9243b..44dadda2cc 100644 --- a/test/test_client_context.py +++ b/test/test_client_context.py @@ -18,7 +18,7 @@ sys.path[0:0] = [""] -from test import SkipTest, UnitTest, client_context, unittest +from test.synchronous import SkipTest, UnitTest, client_context, unittest _IS_SYNC = True diff --git a/test/test_collation.py b/test/test_collation.py index e5c1c7eb11..a63d2c68f4 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -17,7 +17,7 @@ import functools import warnings -from test import IntegrationTest, client_context, unittest +from test.synchronous import IntegrationTest, client_context, unittest from test.utils import EventListener from typing import Any diff --git a/test/test_collection.py b/test/test_collection.py index f2f01ac686..4a22ba4e65 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -27,11 +27,11 @@ sys.path[0:0] = [""] -from test import ( # TODO: fix sync imports in PYTHON-4528 +from test import unittest +from test.synchronous import ( # TODO: fix sync imports in PYTHON-4528 IntegrationTest, UnitTest, client_context, - unittest, ) from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, diff --git a/test/test_common.py b/test/test_common.py index e69b421c9f..8e26003f1f 100644 --- 
a/test/test_common.py +++ b/test/test_common.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, connected, unittest +from test.synchronous import IntegrationTest, client_context, connected, unittest from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation from bson.codec_options import CodecOptions diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 984d700fb3..5ef569f7a6 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -19,13 +19,13 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( IntegrationTest, client_context, reset_client_context, unittest, ) -from test.helpers import repl_set_step_down +from test.synchronous.helpers import repl_set_step_down from test.utils import ( CMAPListener, ensure_all_connected, diff --git a/test/test_cursor.py b/test/test_cursor.py index 7c073bf351..8ac67835fc 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -29,7 +29,7 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test.synchronous import IntegrationTest, client_context, unittest from test.utils import ( AllowListEventListener, EventListener, diff --git a/test/test_database.py b/test/test_database.py index 4973ed0134..064e34162f 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -23,7 +23,8 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import unittest +from test.synchronous import IntegrationTest, client_context from test.test_custom_types import DECIMAL_CODECOPTS from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, diff --git a/test/test_encryption.py b/test/test_encryption.py index 43c85e2c5b..d74514f093 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -28,8 +28,8 @@ import traceback import uuid import warnings -from test import IntegrationTest, PyMongoTestCase, client_context -from test.test_bulk import BulkTestBase +from test.synchronous import IntegrationTest, PyMongoTestCase, client_context +from test.synchronous.test_bulk import BulkTestBase from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -53,7 +53,8 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.test_bulk import BulkTestBase +from test.synchronous.test_bulk import BulkTestBase +from test.synchronous.utils_spec_runner import SpecRunner from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, @@ -64,7 +65,6 @@ is_greenthread_patched, wait_until, ) -from test.utils_spec_runner import SpecRunner from bson import DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation diff --git a/test/test_grid_file.py b/test/test_grid_file.py index fe88aec5ff..bf49750c25 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -21,7 +21,7 @@ import sys import zipfile from io import BytesIO -from test import ( +from test.synchronous import ( IntegrationTest, UnitTest, client_context, diff --git a/test/test_logger.py b/test/test_logger.py index b3c8e6d176..8a89330195 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -14,7 +14,8 @@ from __future__ import annotations import os -from test import IntegrationTest, unittest +from test import unittest +from test.synchronous import IntegrationTest from unittest.mock import patch from bson import 
json_util diff --git a/test/test_monitoring.py b/test/test_monitoring.py index a0c520ed27..4bb174355a 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( IntegrationTest, client_context, client_knobs, diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 4d9a3ceb05..4f14e63ea0 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -19,7 +19,7 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test.synchronous import IntegrationTest, client_context, unittest from bson import Code, DBRef, decode, encode from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index d4951db5ee..a66fbd04ca 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( IntegrationTest, PyMongoTestCase, client_context, diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 5df6c41f7a..9d9d28c9f9 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -23,13 +23,13 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( IntegrationTest, SkipTest, client_context, unittest, ) -from test.helpers import client_knobs +from test.synchronous.helpers import client_knobs from test.utils import ( CMAPListener, DeprecationFilter, diff --git a/test/test_session.py b/test/test_session.py index 9f94ded927..881a9a9d23 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] -from test import ( +from test.synchronous import ( IntegrationTest, PyMongoTestCase, SkipTest, diff --git a/test/test_transactions.py b/test/test_transactions.py index 3cecbe9d38..f361852e61 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -17,13 +17,13 @@ import sys from io import BytesIO -from test.utils_spec_runner import SpecRunner +from test.synchronous.utils_spec_runner import SpecRunner from gridfs.synchronous.grid_file import GridFS, GridFSBucket sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test.synchronous import IntegrationTest, client_context, unittest from test.utils import ( OvertCommandListener, wait_until, diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 06a40351cd..32a5244e16 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -18,7 +18,7 @@ import functools import threading from collections import abc -from test import IntegrationTest, client_context, client_knobs +from test.synchronous import IntegrationTest, client_context, client_knobs from test.utils import ( CMAPListener, CompareType, diff --git a/tools/synchro.py b/tools/synchro.py index c3c0b568ed..969e6801d2 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -112,6 +112,9 @@ "async_wait_for_event": "wait_for_event", "pymongo_server_monitor_task": "pymongo_server_monitor_thread", "pymongo_server_rtt_task": "pymongo_server_rtt_thread", + "_async_create_lock": "_create_lock", + "_async_create_condition": "_create_condition", + "_async_cond_wait": "_cond_wait", } docstring_replacements: dict[tuple[str, str], str] = { @@ -132,8 +135,6 @@ ".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release." 
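# Illustrative sketch (not part of the patch): the synchro.py replacement tables above
# are plain string substitutions applied line by line when the synchronous sources are
# generated. The `replacements` dict mirrors the three entries added in this patch;
# `apply_replacements` is a hypothetical helper shown only to make the mechanism concrete.
replacements = {
    "_async_create_lock": "_create_lock",
    "_async_create_condition": "_create_condition",
    "_async_cond_wait": "_cond_wait",
}


def apply_replacements(lines: list[str]) -> list[str]:
    out = []
    for line in lines:
        # Rewrite each async-only helper name to its synchronous counterpart.
        for async_name, sync_name in replacements.items():
            line = line.replace(async_name, sync_name)
        out.append(line)
    return out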
} -type_replacements = {"_Condition": "threading.Condition"} - import_replacements = {"test.synchronous": "test"} _pymongo_base = "./pymongo/asynchronous/" @@ -227,10 +228,6 @@ def process_files(files: list[str]) -> None: lines = translate_async_sleeps(lines) if file in docstring_translate_files: lines = translate_docstrings(lines) - translate_locks(lines) - translate_types(lines) - if file in sync_test_files: - translate_imports(lines) f.seek(0) f.writelines(lines) f.truncate() @@ -262,43 +259,6 @@ def translate_coroutine_types(lines: list[str]) -> list[str]: return lines -def translate_locks(lines: list[str]) -> list[str]: - lock_lines = [line for line in lines if "_Lock(" in line] - cond_lines = [line for line in lines if "_Condition(" in line] - for line in lock_lines: - res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line) - if res: - old = res[0] - index = lines.index(line) - lines[index] = line.replace(old, res[1]) - for line in cond_lines: - res = re.search(r"_Condition\(([^()]*\([^()]*\))\)", line) - if res: - old = res[0] - index = lines.index(line) - lines[index] = line.replace(old, res[1]) - - return lines - - -def translate_types(lines: list[str]) -> list[str]: - for k, v in type_replacements.items(): - matches = [line for line in lines if k in line and "import" not in line] - for line in matches: - index = lines.index(line) - lines[index] = line.replace(k, v) - return lines - - -def translate_imports(lines: list[str]) -> list[str]: - for k, v in import_replacements.items(): - matches = [line for line in lines if k in line and "import" in line] - for line in matches: - index = lines.index(line) - lines[index] = line.replace(k, v) - return lines - - def translate_async_sleeps(lines: list[str]) -> list[str]: blocking_sleeps = [line for line in lines if "asyncio.sleep(0)" in line] lines = [line for line in lines if line not in blocking_sleeps] From 81ea0bcb2bf4b3e480b5ce4a941d27a651620f80 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 16 Oct 2024 09:18:16 -0400 Subject: [PATCH 2/9] Fix test import --- test/conftest.py | 3 +-- test/pymongo_mocks.py | 1 - test/test_auth.py | 2 +- test/test_auth_spec.py | 2 +- test/test_bulk.py | 2 +- test/test_change_stream.py | 2 +- test/test_client.py | 4 ++-- test/test_client_bulk_write.py | 2 +- test/test_client_context.py | 2 +- test/test_collation.py | 2 +- test/test_collection.py | 4 ++-- test/test_common.py | 2 +- .../test_connections_survive_primary_stepdown_spec.py | 4 ++-- test/test_cursor.py | 2 +- test/test_database.py | 3 +-- test/test_encryption.py | 8 ++++---- test/test_grid_file.py | 2 +- test/test_logger.py | 3 +-- test/test_monitoring.py | 2 +- test/test_raw_bson.py | 2 +- test/test_retryable_reads.py | 2 +- test/test_retryable_writes.py | 4 ++-- test/test_session.py | 2 +- test/test_transactions.py | 4 ++-- test/utils_spec_runner.py | 2 +- tools/synchro.py | 11 +++++++++++ 26 files changed, 43 insertions(+), 36 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 013e9c4e24..91fad28d0a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,8 +2,7 @@ import asyncio import sys -from test import pytest_conf -from test.synchronous import setup, teardown +from test import pytest_conf, setup, teardown import pytest diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 51c90a3884..7662dc9682 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -19,7 +19,6 @@ import weakref from functools import partial from test import client_context -from test.synchronous import client_context from 
pymongo import MongoClient, common from pymongo.errors import AutoReconnect, NetworkTimeout diff --git a/test/test_auth.py b/test/test_auth.py index 355d7d1833..b311d330bc 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( IntegrationTest, PyMongoTestCase, SkipTest, diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index fa14b82289..3c3a1a67ae 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -20,7 +20,7 @@ import os import sys import warnings -from test.synchronous import PyMongoTestCase +from test import PyMongoTestCase sys.path[0:0] = [""] diff --git a/test/test_bulk.py b/test/test_bulk.py index 76c20d692a..64fd48e8cd 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -23,7 +23,7 @@ sys.path[0:0] = [""] -from test.synchronous import IntegrationTest, client_context, remove_all_users, unittest +from test import IntegrationTest, client_context, remove_all_users, unittest from test.utils import wait_until from bson.binary import Binary, UuidRepresentation diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 63c76b6f33..dae224c5e0 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( IntegrationTest, PyMongoTestCase, Version, diff --git a/test/test_client.py b/test/test_client.py index 1dcf83abba..a4c521157b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -44,7 +44,7 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( HAVE_IPADDRESS, IntegrationTest, MockClientTest, @@ -58,7 +58,7 @@ remove_all_users, unittest, ) -from test.synchronous.pymongo_mocks import MockClient +from test.pymongo_mocks import MockClient from test.test_binary import BinaryData from test.utils import ( NTHREADS, diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index 908cf95337..58b5015dd2 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( IntegrationTest, client_context, unittest, diff --git a/test/test_client_context.py b/test/test_client_context.py index 44dadda2cc..5996f9243b 100644 --- a/test/test_client_context.py +++ b/test/test_client_context.py @@ -18,7 +18,7 @@ sys.path[0:0] = [""] -from test.synchronous import SkipTest, UnitTest, client_context, unittest +from test import SkipTest, UnitTest, client_context, unittest _IS_SYNC = True diff --git a/test/test_collation.py b/test/test_collation.py index a63d2c68f4..e5c1c7eb11 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -17,7 +17,7 @@ import functools import warnings -from test.synchronous import IntegrationTest, client_context, unittest +from test import IntegrationTest, client_context, unittest from test.utils import EventListener from typing import Any diff --git a/test/test_collection.py b/test/test_collection.py index 4a22ba4e65..f2f01ac686 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -27,11 +27,11 @@ sys.path[0:0] = [""] -from test import unittest -from test.synchronous import ( # TODO: fix sync imports in PYTHON-4528 +from test import ( # TODO: fix sync imports in PYTHON-4528 IntegrationTest, UnitTest, client_context, + unittest, ) from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, diff --git a/test/test_common.py b/test/test_common.py index 8e26003f1f..e69b421c9f 100644 
--- a/test/test_common.py +++ b/test/test_common.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] -from test.synchronous import IntegrationTest, client_context, connected, unittest +from test import IntegrationTest, client_context, connected, unittest from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation from bson.codec_options import CodecOptions diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 5ef569f7a6..984d700fb3 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -19,13 +19,13 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( IntegrationTest, client_context, reset_client_context, unittest, ) -from test.synchronous.helpers import repl_set_step_down +from test.helpers import repl_set_step_down from test.utils import ( CMAPListener, ensure_all_connected, diff --git a/test/test_cursor.py b/test/test_cursor.py index 8ac67835fc..7c073bf351 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -29,7 +29,7 @@ sys.path[0:0] = [""] -from test.synchronous import IntegrationTest, client_context, unittest +from test import IntegrationTest, client_context, unittest from test.utils import ( AllowListEventListener, EventListener, diff --git a/test/test_database.py b/test/test_database.py index 064e34162f..4973ed0134 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -23,8 +23,7 @@ sys.path[0:0] = [""] -from test import unittest -from test.synchronous import IntegrationTest, client_context +from test import IntegrationTest, client_context, unittest from test.test_custom_types import DECIMAL_CODECOPTS from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, diff --git a/test/test_encryption.py b/test/test_encryption.py index d74514f093..43c85e2c5b 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -28,8 +28,8 @@ import traceback import uuid import warnings -from test.synchronous import IntegrationTest, PyMongoTestCase, client_context -from test.synchronous.test_bulk import BulkTestBase +from test import IntegrationTest, PyMongoTestCase, client_context +from test.test_bulk import BulkTestBase from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -53,8 +53,7 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.synchronous.test_bulk import BulkTestBase -from test.synchronous.utils_spec_runner import SpecRunner +from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, @@ -65,6 +64,7 @@ is_greenthread_patched, wait_until, ) +from test.utils_spec_runner import SpecRunner from bson import DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation diff --git a/test/test_grid_file.py b/test/test_grid_file.py index bf49750c25..fe88aec5ff 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -21,7 +21,7 @@ import sys import zipfile from io import BytesIO -from test.synchronous import ( +from test import ( IntegrationTest, UnitTest, client_context, diff --git a/test/test_logger.py b/test/test_logger.py index 8a89330195..b3c8e6d176 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -14,8 +14,7 @@ from __future__ import annotations import os -from test import unittest -from test.synchronous import IntegrationTest +from test import IntegrationTest, unittest from unittest.mock import patch from bson import 
json_util diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 4bb174355a..a0c520ed27 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( IntegrationTest, client_context, client_knobs, diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 4f14e63ea0..4d9a3ceb05 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -19,7 +19,7 @@ sys.path[0:0] = [""] -from test.synchronous import IntegrationTest, client_context, unittest +from test import IntegrationTest, client_context, unittest from bson import Code, DBRef, decode, encode from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index a66fbd04ca..d4951db5ee 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -24,7 +24,7 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( IntegrationTest, PyMongoTestCase, client_context, diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 9d9d28c9f9..5df6c41f7a 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -23,13 +23,13 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( IntegrationTest, SkipTest, client_context, unittest, ) -from test.synchronous.helpers import client_knobs +from test.helpers import client_knobs from test.utils import ( CMAPListener, DeprecationFilter, diff --git a/test/test_session.py b/test/test_session.py index 881a9a9d23..9f94ded927 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] -from test.synchronous import ( +from test import ( IntegrationTest, PyMongoTestCase, SkipTest, diff --git a/test/test_transactions.py b/test/test_transactions.py index f361852e61..3cecbe9d38 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -17,13 +17,13 @@ import sys from io import BytesIO -from test.synchronous.utils_spec_runner import SpecRunner +from test.utils_spec_runner import SpecRunner from gridfs.synchronous.grid_file import GridFS, GridFSBucket sys.path[0:0] = [""] -from test.synchronous import IntegrationTest, client_context, unittest +from test import IntegrationTest, client_context, unittest from test.utils import ( OvertCommandListener, wait_until, diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 32a5244e16..06a40351cd 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -18,7 +18,7 @@ import functools import threading from collections import abc -from test.synchronous import IntegrationTest, client_context, client_knobs +from test import IntegrationTest, client_context, client_knobs from test.utils import ( CMAPListener, CompareType, diff --git a/tools/synchro.py b/tools/synchro.py index 969e6801d2..25d5ee4c30 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -228,6 +228,8 @@ def process_files(files: list[str]) -> None: lines = translate_async_sleeps(lines) if file in docstring_translate_files: lines = translate_docstrings(lines) + if file in sync_test_files: + translate_imports(lines) f.seek(0) f.writelines(lines) f.truncate() @@ -259,6 +261,15 @@ def translate_coroutine_types(lines: list[str]) -> list[str]: return lines +def translate_imports(lines: list[str]) -> list[str]: + for k, v in import_replacements.items(): + matches = [line for line in lines if k in line and "import" in line] + for line in matches: 
+ index = lines.index(line) + lines[index] = line.replace(k, v) + return lines + + def translate_async_sleeps(lines: list[str]) -> list[str]: blocking_sleeps = [line for line in lines if "asyncio.sleep(0)" in line] lines = [line for line in lines if line not in blocking_sleeps] From 6880fafa934abb6ed54be36d16734e0b7cb8ee11 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 16 Oct 2024 11:43:47 -0400 Subject: [PATCH 3/9] Remove unneeded test_locks --- pymongo/asynchronous/pool.py | 5 +- pymongo/asynchronous/topology.py | 4 +- pymongo/lock.py | 12 +- pymongo/synchronous/pool.py | 5 +- pymongo/synchronous/topology.py | 4 +- test/__init__.py | 2 +- test/asynchronous/__init__.py | 2 +- test/asynchronous/test_locks.py | 513 ------------------------------- 8 files changed, 22 insertions(+), 525 deletions(-) delete mode 100644 test/asynchronous/test_locks.py diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 532c20aa63..4363fc3370 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -23,7 +23,6 @@ import socket import ssl import sys -import threading import time import weakref from typing import ( @@ -992,8 +991,8 @@ def __init__( self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() self.lock = _async_create_lock() - self.size_cond = _async_create_condition(self.lock, threading.Condition) - self._max_connecting_cond = _async_create_condition(self.lock, threading.Condition) + self.size_cond = _async_create_condition(self.lock) + self._max_connecting_cond = _async_create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 460fa15837..6d67710a7e 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -174,7 +174,9 @@ def __init__(self, topology_settings: TopologySettings): self._opened = False self._closed = False self._lock = _async_create_lock() - self._condition = _async_create_condition(self._lock, self._settings.condition_class) + self._condition = _async_create_condition( + self._lock, self._settings.condition_class if _IS_SYNC else None + ) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/pymongo/lock.py b/pymongo/lock.py index 52a400893e..26244ddefd 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -42,13 +42,21 @@ def _async_create_lock() -> asyncio.Lock: return asyncio.Lock() -def _create_condition(lock: threading.Lock, dummy: Any) -> threading.Condition: +def _create_condition( + lock: threading.Lock, condition_class: Optional[Any] = None +) -> threading.Condition: """Represents a threading.Condition.""" + if condition_class: + return condition_class(lock) return threading.Condition(lock) -def _async_create_condition(lock: asyncio.Lock, dummy: Any) -> asyncio.Condition: +def _async_create_condition( + lock: asyncio.Lock, condition_class: Optional[Any] = None +) -> asyncio.Condition: """Represents an asyncio.Condition.""" + if condition_class: + return condition_class(lock) return asyncio.Condition(lock) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 8172a9f846..f68bfe2002 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -23,7 +23,6 @@ import socket import ssl import sys -import threading import time import weakref from 
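# Illustrative sketch (not part of the patch): the restored translate_imports above only
# touches lines that both mention "test.synchronous" and contain "import", so mentions of
# the module path outside import statements are left alone. The `sample` list below is
# hypothetical input used purely for demonstration.
import_replacements = {"test.synchronous": "test"}
sample = [
    "from test.synchronous import IntegrationTest\n",
    "# see test.synchronous for the async-aware fixtures\n",
]
for key, value in import_replacements.items():
    for i, line in enumerate(sample):
        if key in line and "import" in line:
            sample[i] = line.replace(key, value)
# sample[0] -> "from test import IntegrationTest\n"; sample[1] is unchanged.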
typing import ( @@ -988,8 +987,8 @@ def __init__( self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() self.lock = _create_lock() - self.size_cond = _create_condition(self.lock, threading.Condition) - self._max_connecting_cond = _create_condition(self.lock, threading.Condition) + self.size_cond = _create_condition(self.lock) + self._max_connecting_cond = _create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 48c8b81590..b03269ae43 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -174,7 +174,9 @@ def __init__(self, topology_settings: TopologySettings): self._opened = False self._closed = False self._lock = _create_lock() - self._condition = _create_condition(self._lock, self._settings.condition_class) + self._condition = _create_condition( + self._lock, self._settings.condition_class if _IS_SYNC else None + ) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/test/__init__.py b/test/__init__.py index 940518c2c5..c1944f5870 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1131,7 +1131,7 @@ class IntegrationTest(PyMongoTestCase): @client_context.require_connection def setUp(self) -> None: - if not _IS_SYNC: + if not _IS_SYNC and client_context.client is not None: reset_client_context() if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 8d1e3e1911..9ca5a32ffc 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1149,7 +1149,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): @async_client_context.require_connection async def asyncSetUp(self) -> None: - if not _IS_SYNC: + if not _IS_SYNC and async_client_context.client is not None: await reset_client_context() if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py deleted file mode 100644 index a029f0792b..0000000000 --- a/test/asynchronous/test_locks.py +++ /dev/null @@ -1,513 +0,0 @@ -# # Copyright 2024-present MongoDB, Inc. -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. 
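# Illustrative sketch (not part of the patch): with condition_class now optional, the
# pool and topology call sites above can simply pass the lock, and a custom condition
# type is only supplied when one is actually configured. At this point in the series the
# async helpers still return asyncio primitives.
from pymongo.lock import _async_create_condition, _async_create_lock

_demo_lock = _async_create_lock()
_demo_cond = _async_create_condition(_demo_lock)  # defaults to asyncio.Condition(_demo_lock)
# _custom = _async_create_condition(_demo_lock, SomeConditionClass)  # hypothetical custom class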
-# """Tests for lock.py""" -# from __future__ import annotations -# -# import asyncio -# import sys -# import threading -# import unittest -# -# sys.path[0:0] = [""] -# -# from pymongo.lock import _async_create_lock -# -# -# # Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py -# # Includes tests for: -# # - https://github.com/python/cpython/issues/111693 -# # - https://github.com/python/cpython/issues/112202 -# class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): -# async def test_wait(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# result = [] -# -# async def c1(result): -# await cond.acquire() -# if await cond.wait(): -# result.append(1) -# return True -# -# async def c2(result): -# await cond.acquire() -# if await cond.wait(): -# result.append(2) -# return True -# -# async def c3(result): -# await cond.acquire() -# if await cond.wait(): -# result.append(3) -# return True -# -# t1 = asyncio.create_task(c1(result)) -# t2 = asyncio.create_task(c2(result)) -# t3 = asyncio.create_task(c3(result)) -# -# await asyncio.sleep(0) -# self.assertEqual([], result) -# self.assertFalse(cond.locked()) -# -# self.assertTrue(await cond.acquire()) -# cond.notify() -# await asyncio.sleep(0) -# self.assertEqual([], result) -# self.assertTrue(cond.locked()) -# -# cond.release() -# await asyncio.sleep(0) -# self.assertEqual([1], result) -# self.assertTrue(cond.locked()) -# -# cond.notify(2) -# await asyncio.sleep(0) -# self.assertEqual([1], result) -# self.assertTrue(cond.locked()) -# -# cond.release() -# await asyncio.sleep(0) -# self.assertEqual([1, 2], result) -# self.assertTrue(cond.locked()) -# -# cond.release() -# await asyncio.sleep(0) -# self.assertEqual([1, 2, 3], result) -# self.assertTrue(cond.locked()) -# -# self.assertTrue(t1.done()) -# self.assertTrue(t1.result()) -# self.assertTrue(t2.done()) -# self.assertTrue(t2.result()) -# self.assertTrue(t3.done()) -# self.assertTrue(t3.result()) -# -# async def test_wait_cancel(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# await cond.acquire() -# -# wait = asyncio.create_task(cond.wait()) -# asyncio.get_running_loop().call_soon(wait.cancel) -# with self.assertRaises(asyncio.CancelledError): -# await wait -# self.assertFalse(cond._waiters) -# self.assertTrue(cond.locked()) -# -# async def test_wait_cancel_contested(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# -# await cond.acquire() -# self.assertTrue(cond.locked()) -# -# wait_task = asyncio.create_task(cond.wait()) -# await asyncio.sleep(0) -# self.assertFalse(cond.locked()) -# -# # Notify, but contest the lock before cancelling -# await cond.acquire() -# self.assertTrue(cond.locked()) -# cond.notify() -# asyncio.get_running_loop().call_soon(wait_task.cancel) -# asyncio.get_running_loop().call_soon(cond.release) -# -# try: -# await wait_task -# except asyncio.CancelledError: -# # Should not happen, since no cancellation points -# pass -# -# self.assertTrue(cond.locked()) -# -# async def test_wait_cancel_after_notify(self): -# # See bpo-32841 -# waited = False -# -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# -# async def wait_on_cond(): -# nonlocal waited -# async with cond: -# waited = True # Make sure this area was reached -# await cond.wait() -# -# waiter = asyncio.create_task(wait_on_cond()) -# await asyncio.sleep(0) # Start waiting -# -# await cond.acquire() -# cond.notify() -# await asyncio.sleep(0) # Get to acquire() 
-# waiter.cancel() -# await asyncio.sleep(0) # Activate cancellation -# cond.release() -# await asyncio.sleep(0) # Cancellation should occur -# -# self.assertTrue(waiter.cancelled()) -# self.assertTrue(waited) -# -# async def test_wait_unacquired(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# with self.assertRaises(RuntimeError): -# await cond.wait() -# -# async def test_wait_for(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# presult = False -# -# def predicate(): -# return presult -# -# result = [] -# -# async def c1(result): -# await cond.acquire() -# if await cond.wait_for(predicate): -# result.append(1) -# cond.release() -# return True -# -# t = asyncio.create_task(c1(result)) -# -# await asyncio.sleep(0) -# self.assertEqual([], result) -# -# await cond.acquire() -# cond.notify() -# cond.release() -# await asyncio.sleep(0) -# self.assertEqual([], result) -# -# presult = True -# await cond.acquire() -# cond.notify() -# cond.release() -# await asyncio.sleep(0) -# self.assertEqual([1], result) -# -# self.assertTrue(t.done()) -# self.assertTrue(t.result()) -# -# async def test_wait_for_unacquired(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# -# # predicate can return true immediately -# res = await cond.wait_for(lambda: [1, 2, 3]) -# self.assertEqual([1, 2, 3], res) -# -# with self.assertRaises(RuntimeError): -# await cond.wait_for(lambda: False) -# -# async def test_notify(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# result = [] -# -# async def c1(result): -# async with cond: -# if await cond.wait(): -# result.append(1) -# return True -# -# async def c2(result): -# async with cond: -# if await cond.wait(): -# result.append(2) -# return True -# -# async def c3(result): -# async with cond: -# if await cond.wait(): -# result.append(3) -# return True -# -# t1 = asyncio.create_task(c1(result)) -# t2 = asyncio.create_task(c2(result)) -# t3 = asyncio.create_task(c3(result)) -# -# await asyncio.sleep(0) -# self.assertEqual([], result) -# -# async with cond: -# cond.notify(1) -# await asyncio.sleep(1) -# self.assertEqual([1], result) -# -# async with cond: -# cond.notify(1) -# cond.notify(2048) -# await asyncio.sleep(1) -# self.assertEqual([1, 2, 3], result) -# -# self.assertTrue(t1.done()) -# self.assertTrue(t1.result()) -# self.assertTrue(t2.done()) -# self.assertTrue(t2.result()) -# self.assertTrue(t3.done()) -# self.assertTrue(t3.result()) -# -# async def test_notify_all(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# -# result = [] -# -# async def c1(result): -# async with cond: -# if await cond.wait(): -# result.append(1) -# return True -# -# async def c2(result): -# async with cond: -# if await cond.wait(): -# result.append(2) -# return True -# -# t1 = asyncio.create_task(c1(result)) -# t2 = asyncio.create_task(c2(result)) -# -# await asyncio.sleep(0) -# self.assertEqual([], result) -# -# async with cond: -# cond.notify_all() -# await asyncio.sleep(1) -# self.assertEqual([1, 2], result) -# -# self.assertTrue(t1.done()) -# self.assertTrue(t1.result()) -# self.assertTrue(t2.done()) -# self.assertTrue(t2.result()) -# -# async def test_context_manager(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# self.assertFalse(cond.locked()) -# async with cond: -# self.assertTrue(cond.locked()) -# self.assertFalse(cond.locked()) -# -# async def test_timeout_in_block(self): -# condition = 
_async_create_lock(threading.Condition(threading.Lock())) -# async with condition: -# with self.assertRaises(asyncio.TimeoutError): -# await asyncio.wait_for(condition.wait(), timeout=0.5) -# -# @unittest.skipIf( -# sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" -# ) -# async def test_cancelled_error_wakeup(self): -# # Test that a cancelled error, received when awaiting wakeup, -# # will be re-raised un-modified. -# wake = False -# raised = None -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# -# async def func(): -# nonlocal raised -# async with cond: -# with self.assertRaises(asyncio.CancelledError) as err: -# await cond.wait_for(lambda: wake) -# raised = err.exception -# raise raised -# -# task = asyncio.create_task(func()) -# await asyncio.sleep(0) -# # Task is waiting on the condition, cancel it there. -# task.cancel(msg="foo") # type: ignore[call-arg] -# with self.assertRaises(asyncio.CancelledError) as err: -# await task -# self.assertEqual(err.exception.args, ("foo",)) -# # We should have got the _same_ exception instance as the one -# # originally raised. -# self.assertIs(err.exception, raised) -# -# @unittest.skipIf( -# sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" -# ) -# async def test_cancelled_error_re_aquire(self): -# # Test that a cancelled error, received when re-aquiring lock, -# # will be re-raised un-modified. -# wake = False -# raised = None -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# -# async def func(): -# nonlocal raised -# async with cond: -# with self.assertRaises(asyncio.CancelledError) as err: -# await cond.wait_for(lambda: wake) -# raised = err.exception -# raise raised -# -# task = asyncio.create_task(func()) -# await asyncio.sleep(0) -# # Task is waiting on the condition -# await cond.acquire() -# wake = True -# cond.notify() -# await asyncio.sleep(0) -# # Task is now trying to re-acquire the lock, cancel it there. -# task.cancel(msg="foo") # type: ignore[call-arg] -# cond.release() -# with self.assertRaises(asyncio.CancelledError) as err: -# await task -# self.assertEqual(err.exception.args, ("foo",)) -# # We should have got the _same_ exception instance as the one -# # originally raised. -# self.assertIs(err.exception, raised) -# -# @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") -# async def test_cancelled_wakeup(self): -# # Test that a task cancelled at the "same" time as it is woken -# # up as part of a Condition.notify() does not result in a lost wakeup. -# # This test simulates a cancel while the target task is awaiting initial -# # wakeup on the wakeup queue. -# condition = _async_create_lock(threading.Condition(threading.Lock())) -# state = 0 -# -# async def consumer(): -# nonlocal state -# async with condition: -# while True: -# await condition.wait_for(lambda: state != 0) -# if state < 0: -# return -# state -= 1 -# -# # create two consumers -# c = [asyncio.create_task(consumer()) for _ in range(2)] -# # wait for them to settle -# await asyncio.sleep(0.1) -# async with condition: -# # produce one item and wake up one -# state += 1 -# condition.notify(1) -# -# # Cancel it while it is awaiting to be run. -# # This cancellation could come from the outside -# c[0].cancel() -# -# # now wait for the item to be consumed -# # if it doesn't means that our "notify" didn"t take hold. 
-# # because it raced with a cancel() -# try: -# async with asyncio.timeout(1): -# await condition.wait_for(lambda: state == 0) -# except TimeoutError: -# pass -# self.assertEqual(state, 0) -# -# # clean up -# state = -1 -# condition.notify_all() -# await c[1] -# -# @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") -# async def test_cancelled_wakeup_relock(self): -# # Test that a task cancelled at the "same" time as it is woken -# # up as part of a Condition.notify() does not result in a lost wakeup. -# # This test simulates a cancel while the target task is acquiring the lock -# # again. -# condition = _async_create_lock(threading.Condition(threading.Lock())) -# state = 0 -# -# async def consumer(): -# nonlocal state -# async with condition: -# while True: -# await condition.wait_for(lambda: state != 0) -# if state < 0: -# return -# state -= 1 -# -# # create two consumers -# c = [asyncio.create_task(consumer()) for _ in range(2)] -# # wait for them to settle -# await asyncio.sleep(0.1) -# async with condition: -# # produce one item and wake up one -# state += 1 -# condition.notify(1) -# -# # now we sleep for a bit. This allows the target task to wake up and -# # settle on re-aquiring the lock -# await asyncio.sleep(0) -# -# # Cancel it while awaiting the lock -# # This cancel could come the outside. -# c[0].cancel() -# -# # now wait for the item to be consumed -# # if it doesn't means that our "notify" didn"t take hold. -# # because it raced with a cancel() -# try: -# async with asyncio.timeout(1): -# await condition.wait_for(lambda: state == 0) -# except TimeoutError: -# pass -# self.assertEqual(state, 0) -# -# # clean up -# state = -1 -# condition.notify_all() -# await c[1] -# -# -# class TestCondition(unittest.IsolatedAsyncioTestCase): -# async def test_multiple_loops_notify(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# -# def tmain(cond): -# async def atmain(cond): -# await asyncio.sleep(1) -# async with cond: -# cond.notify(1) -# -# asyncio.run(atmain(cond)) -# -# t = threading.Thread(target=tmain, args=(cond,)) -# t.start() -# -# async with cond: -# self.assertTrue(await cond.wait(30)) -# t.join() -# -# async def test_multiple_loops_notify_all(self): -# cond = _async_create_lock(threading.Condition(threading.Lock())) -# results = [] -# -# def tmain(cond, results): -# async def atmain(cond, results): -# await asyncio.sleep(1) -# async with cond: -# res = await cond.wait(30) -# results.append(res) -# -# asyncio.run(atmain(cond, results)) -# -# nthreads = 5 -# threads = [] -# for _ in range(nthreads): -# threads.append(threading.Thread(target=tmain, args=(cond, results))) -# for t in threads: -# t.start() -# -# await asyncio.sleep(2) -# async with cond: -# cond.notify_all() -# -# for t in threads: -# t.join() -# -# self.assertEqual(results, [True] * nthreads) -# -# -# if __name__ == "__main__": -# unittest.main() From 037f3b4f0a5464c8342782708d2460a23273412d Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 16 Oct 2024 12:08:05 -0400 Subject: [PATCH 4/9] Vendor 3.13 Lock + Condition classes --- pymongo/lock.py | 626 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 617 insertions(+), 9 deletions(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index 26244ddefd..be0dcf4f0a 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -11,12 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py +to port 3.13 fixes to older versions of Python.""" + from __future__ import annotations -import asyncio +import collections +import enum import os import threading import weakref +from asyncio import exceptions, mixins, wait_for from typing import Any, Optional, TypeVar _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") @@ -27,6 +32,611 @@ _T = TypeVar("_T") +class _ContextManagerMixin: + async def __aenter__(self): + await self.acquire() + # We have no use for the "as ..." clause in the with + # statement for locks. + return + + async def __aexit__(self, exc_type, exc, tb): + self.release() + + +class Lock(_ContextManagerMixin, mixins._LoopBoundMixin): + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular task when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another task changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one task is blocked in acquire() waiting for + the state to turn to unlocked, only one task proceeds when a + release() call resets the state to unlocked; successive release() + calls will unblock tasks in FIFO order. + + Locks also support the asynchronous context management protocol. + 'async with lock' statement should be used. + + Usage: + + lock = Lock() + ... + await lock.acquire() + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + async with lock: + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + await lock.acquire() + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = None + self._locked = False + + def __repr__(self): + res = super().__repr__() + extra = "locked" if self._locked else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + def locked(self): + """Return True if lock is acquired.""" + return self._locked + + async def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + # Implement fair scheduling, where thread always waits + # its turn. Jumping the queue if all are cancelled is an optimization. + if not self._locked and ( + self._waiters is None or all(w.cancelled() for w in self._waiters) + ): + self._locked = True + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + try: + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + # Currently the only exception designed be able to occur here. 
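# Illustrative sketch (not part of the patch): the vendored Lock is meant as a drop-in
# for asyncio.Lock and is normally used through the async context manager protocol
# provided by _ContextManagerMixin above.
import asyncio
from pymongo.lock import _async_create_lock


async def _lock_demo() -> None:
    lock = _async_create_lock()  # returns the vendored Lock after this patch
    async with lock:
        # Held here; on release, waiters are woken in FIFO order.
        assert lock.locked()


asyncio.run(_lock_demo())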
+ + # Ensure the lock invariant: If lock is not claimed (or about + # to be claimed by us) and there is a Task in waiters, + # ensure that the Task at the head will run. + if not self._locked: + self._wake_up_first() + raise + + # assert self._locked is False + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other tasks are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + self._wake_up_first() + else: + raise RuntimeError("Lock is not acquired.") + + def _wake_up_first(self): + """Ensure that the first waiter will wake up.""" + if not self._waiters: + return + try: + fut = next(iter(self._waiters)) + except StopIteration: + return + + # .done() means that the waiter is already set to wake up. + if not fut.done(): + fut.set_result(True) + + +class Event(mixins._LoopBoundMixin): + """Asynchronous equivalent to threading.Event. + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + + def __repr__(self): + res = super().__repr__() + extra = "set" if self._value else "unset" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + def is_set(self): + """Return True if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All tasks waiting for it to + become true are awakened. Tasks that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, tasks calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + async def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another task calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = self._get_loop().create_future() + self._waiters.append(fut) + try: + await fut + return True + finally: + self._waiters.remove(fut) + + +class Condition(_ContextManagerMixin, mixins._LoopBoundMixin): + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more tasks to wait until they are notified by another + task. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock=None): + if lock is None: + lock = Lock() + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. 
+ self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters = collections.deque() + + def __repr__(self): + res = super().__repr__() + extra = "locked" if self.locked() else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + async def wait(self): + """Wait until notified. + + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. + """ + if not self.locked(): + raise RuntimeError("cannot wait on un-acquired lock") + + fut = self._get_loop().create_future() + self.release() + try: + try: + self._waiters.append(fut) + try: + await fut + return True + finally: + self._waiters.remove(fut) + + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. + err = None + while True: + try: + await self.acquire() + break + except exceptions.CancelledError as e: + err = e + + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. + except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. + self._notify(1) + raise + + async def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one task waiting on this condition, if any. + If the calling task has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up n of the tasks waiting for the condition + variable; if fewer than n are waiting, they are all awoken. + + Note: an awakened task does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError("cannot notify on un-acquired lock") + self._notify(n) + + def _notify(self, n): + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all tasks waiting on this condition. This method acts + like notify(), but wakes up all waiting tasks instead of one. If the + calling task has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) + + +class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin): + """A Semaphore implementation. 
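# Illustrative sketch (not part of the patch): the vendored Condition is used the same
# way as asyncio.Condition. `state`, `_waiter`, and `_notifier` below are hypothetical
# names for a minimal wait_for/notify round trip.
import asyncio
from pymongo.lock import Condition, _async_create_lock


async def _condition_demo() -> None:
    cond = Condition(_async_create_lock())
    state = {"ready": False}

    async def _waiter() -> None:
        async with cond:
            # Releases the lock while blocked, re-acquires it before returning.
            await cond.wait_for(lambda: state["ready"])

    async def _notifier() -> None:
        async with cond:
            state["ready"] = True
            cond.notify_all()

    await asyncio.gather(_waiter(), _notifier())


asyncio.run(_condition_demo())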
+ + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context management protocol. + + The optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + """ + + def __init__(self, value=1): + if value < 0: + raise ValueError("Semaphore initial value must be >= 0") + self._waiters = None + self._value = value + + def __repr__(self): + res = super().__repr__() + extra = "locked" if self.locked() else f"unlocked, value:{self._value}" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + def locked(self): + """Returns True if semaphore cannot be acquired immediately.""" + # Due to state, or FIFO rules (must allow others to run first). + return self._value == 0 or (any(not w.cancelled() for w in (self._waiters or ()))) + + async def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other task has + called release() to make it larger than 0, and then return + True. + """ + if not self.locked(): + # Maintain FIFO, wait for others to start even if _value > 0. + self._value -= 1 + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + try: + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + # Currently the only exception designed be able to occur here. + if fut.done() and not fut.cancelled(): + # Our Future was successfully set to True via _wake_up_next(), + # but we are not about to successfully acquire(). Therefore we + # must undo the bookkeeping already done and attempt to wake + # up someone else. + self._value += 1 + raise + + finally: + # New waiters may have arrived but had to wait due to FIFO. + # Wake up as many as are allowed. + while self._value > 0: + if not self._wake_up_next(): + break # There was no-one to wake up. + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + + When it was zero on entry and another task is waiting for it to + become larger than zero again, wake up that task. + """ + self._value += 1 + self._wake_up_next() + + def _wake_up_next(self): + """Wake up the first waiter that isn't done.""" + if not self._waiters: + return False + + for fut in self._waiters: + if not fut.done(): + self._value -= 1 + fut.set_result(True) + # `fut` is now `done()` and not `cancelled()`. + return True + return False + + +class BoundedSemaphore(Semaphore): + """A bounded semaphore implementation. + + This raises ValueError in release() if it would increase the value + above the initial value. 
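# Illustrative sketch (not part of the patch): the vendored Semaphore counts concurrent
# holders; acquire() decrements the counter and blocks at zero, release() increments it
# and wakes the next FIFO waiter. (PyMongo itself does not use it, and a later patch in
# this series removes the unused classes again.)
import asyncio
from pymongo.lock import Semaphore


async def _semaphore_demo() -> None:
    sem = Semaphore(2)  # at most two concurrent holders
    async with sem:
        async with sem:
            assert sem.locked()  # counter is now zero; a third acquire() would block


asyncio.run(_semaphore_demo())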
+ """ + + def __init__(self, value=1): + self._bound_value = value + super().__init__(value) + + def release(self): + if self._value >= self._bound_value: + raise ValueError("BoundedSemaphore released too many times") + super().release() + + +class _BarrierState(enum.Enum): + FILLING = "filling" + DRAINING = "draining" + RESETTING = "resetting" + BROKEN = "broken" + + +class Barrier(mixins._LoopBoundMixin): + """Asyncio equivalent to threading.Barrier + + Implements a Barrier primitive. + Useful for synchronizing a fixed number of tasks at known synchronization + points. Tasks block on 'wait()' and are simultaneously awoken once they + have all made their call. + """ + + def __init__(self, parties): + """Create a barrier, initialised to 'parties' tasks.""" + if parties < 1: + raise ValueError("parties must be > 0") + + self._cond = Condition() # notify all tasks when state changes + + self._parties = parties + self._state = _BarrierState.FILLING + self._count = 0 # count tasks in Barrier + + def __repr__(self): + res = super().__repr__() + extra = f"{self._state.value}" + if not self.broken: + extra += f", waiters:{self.n_waiting}/{self.parties}" + return f"<{res[1:-1]} [{extra}]>" + + async def __aenter__(self): + # wait for the barrier reaches the parties number + # when start draining release and return index of waited task + return await self.wait() + + async def __aexit__(self, *args): + pass + + async def wait(self): + """Wait for the barrier. + + When the specified number of tasks have started waiting, they are all + simultaneously awoken. + Returns an unique and individual index number from 0 to 'parties-1'. + """ + async with self._cond: + await self._block() # Block while the barrier drains or resets. + try: + index = self._count + self._count += 1 + if index + 1 == self._parties: + # We release the barrier + await self._release() + else: + await self._wait() + return index + finally: + self._count -= 1 + # Wake up any tasks waiting for barrier to drain. + self._exit() + + async def _block(self): + # Block until the barrier is ready for us, + # or raise an exception if it is broken. + # + # It is draining or resetting, wait until done + # unless a CancelledError occurs + await self._cond.wait_for( + lambda: self._state not in (_BarrierState.DRAINING, _BarrierState.RESETTING) + ) + + # see if the barrier is in a broken state + if self._state is _BarrierState.BROKEN: + raise exceptions.BrokenBarrierError("Barrier aborted") + + async def _release(self): + # Release the tasks waiting in the barrier. + + # Enter draining state. + # Next waiting tasks will be blocked until the end of draining. + self._state = _BarrierState.DRAINING + self._cond.notify_all() + + async def _wait(self): + # Wait in the barrier until we are released. Raise an exception + # if the barrier is reset or broken. + + # wait for end of filling + # unless a CancelledError occurs + await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING) + + if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING): + raise exceptions.BrokenBarrierError("Abort or reset of barrier") + + def _exit(self): + # If we are the last tasks to exit the barrier, signal any tasks + # waiting for the barrier to drain. + if self._count == 0: + if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING): + self._state = _BarrierState.FILLING + self._cond.notify_all() + + async def reset(self): + """Reset the barrier to the initial state. 
+ + Any tasks currently waiting will get the BrokenBarrier exception + raised. + """ + async with self._cond: + if self._count > 0: + if self._state is not _BarrierState.RESETTING: + # reset the barrier, waking up tasks + self._state = _BarrierState.RESETTING + else: + self._state = _BarrierState.FILLING + self._cond.notify_all() + + async def abort(self): + """Place the barrier into a 'broken' state. + + Useful in case of error. Any currently waiting tasks and tasks + attempting to 'wait()' will have BrokenBarrierError raised. + """ + async with self._cond: + self._state = _BarrierState.BROKEN + self._cond.notify_all() + + @property + def parties(self): + """Return the number of tasks required to trip the barrier.""" + return self._parties + + @property + def n_waiting(self): + """Return the number of tasks currently waiting at the barrier.""" + if self._state is _BarrierState.FILLING: + return self._count + return 0 + + @property + def broken(self): + """Return True if the barrier is in a broken state.""" + return self._state is _BarrierState.BROKEN + + def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and reset by pymongo upon forking. @@ -37,9 +647,9 @@ def _create_lock() -> threading.Lock: return lock -def _async_create_lock() -> asyncio.Lock: +def _async_create_lock() -> Lock: """Represents an asyncio.Lock.""" - return asyncio.Lock() + return Lock() def _create_condition( @@ -51,13 +661,11 @@ def _create_condition( return threading.Condition(lock) -def _async_create_condition( - lock: asyncio.Lock, condition_class: Optional[Any] = None -) -> asyncio.Condition: +def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition: """Represents an asyncio.Condition.""" if condition_class: return condition_class(lock) - return asyncio.Condition(lock) + return Condition(lock) def _release_locks() -> None: @@ -67,9 +675,9 @@ def _release_locks() -> None: lock.release() -async def _async_cond_wait(condition: asyncio.Condition, timeout: Optional[float]) -> bool: +async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool: try: - return await asyncio.wait_for(condition.wait(), timeout) + return await wait_for(condition.wait(), timeout) except TimeoutError: return False From a96b8b6005ae24b3eec8af8413debac0db8979cd Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 16 Oct 2024 12:35:39 -0400 Subject: [PATCH 5/9] Fix lock typing and imports --- pymongo/lock.py | 389 +++++------------------------------------------- 1 file changed, 40 insertions(+), 349 deletions(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index be0dcf4f0a..4c6c210575 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -17,12 +17,11 @@ from __future__ import annotations import collections -import enum import os import threading import weakref -from asyncio import exceptions, mixins, wait_for -from typing import Any, Optional, TypeVar +from asyncio import events, exceptions, wait_for +from typing import Any, Coroutine, Optional, TypeVar _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") @@ -31,19 +30,36 @@ _T = TypeVar("_T") +_global_lock = threading.Lock() + + +class _LoopBoundMixin: + _loop = None + + def _get_loop(self) -> Any: + loop = events._get_running_loop() + + if self._loop is None: + with _global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + class _ContextManagerMixin: - async def 
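# Illustrative sketch (not part of the patch): _async_cond_wait above is the async
# counterpart of a timed threading.Condition.wait(): a wait that times out is reported
# as a plain False return instead of propagating the timeout to the caller.
import asyncio
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock


async def _timeout_demo() -> bool:
    cond = _async_create_condition(_async_create_lock())
    async with cond:
        # Nothing ever notifies this condition, so the helper gives up after ~50ms.
        return await _async_cond_wait(cond, 0.05)


result = asyncio.run(_timeout_demo())  # False: the wait timed out without a notify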
__aenter__(self): - await self.acquire() + async def __aenter__(self) -> None: + await self.acquire() # type: ignore[attr-defined] # We have no use for the "as ..." clause in the with # statement for locks. return - async def __aexit__(self, exc_type, exc, tb): - self.release() + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + self.release() # type: ignore[attr-defined] -class Lock(_ContextManagerMixin, mixins._LoopBoundMixin): +class Lock(_ContextManagerMixin, _LoopBoundMixin): """Primitive lock objects. A primitive lock is a synchronization primitive that is not owned @@ -95,22 +111,22 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin): """ - def __init__(self): - self._waiters = None + def __init__(self) -> None: + self._waiters: Optional[collections.deque] = None self._locked = False - def __repr__(self): + def __repr__(self) -> str: res = super().__repr__() extra = "locked" if self._locked else "unlocked" if self._waiters: extra = f"{extra}, waiters:{len(self._waiters)}" return f"<{res[1:-1]} [{extra}]>" - def locked(self): + def locked(self) -> bool: """Return True if lock is acquired.""" return self._locked - async def acquire(self): + async def acquire(self) -> bool: """Acquire a lock. This method blocks until the lock is unlocked, then sets it to @@ -148,7 +164,7 @@ async def acquire(self): self._locked = True return True - def release(self): + def release(self) -> None: """Release a lock. When the lock is locked, reset it to unlocked, and return. @@ -165,7 +181,7 @@ def release(self): else: raise RuntimeError("Lock is not acquired.") - def _wake_up_first(self): + def _wake_up_first(self) -> None: """Ensure that the first waiter will wake up.""" if not self._waiters: return @@ -179,68 +195,7 @@ def _wake_up_first(self): fut.set_result(True) -class Event(mixins._LoopBoundMixin): - """Asynchronous equivalent to threading.Event. - - Class implementing event objects. An event manages a flag that can be set - to true with the set() method and reset to false with the clear() method. - The wait() method blocks until the flag is true. The flag is initially - false. - """ - - def __init__(self): - self._waiters = collections.deque() - self._value = False - - def __repr__(self): - res = super().__repr__() - extra = "set" if self._value else "unset" - if self._waiters: - extra = f"{extra}, waiters:{len(self._waiters)}" - return f"<{res[1:-1]} [{extra}]>" - - def is_set(self): - """Return True if and only if the internal flag is true.""" - return self._value - - def set(self): - """Set the internal flag to true. All tasks waiting for it to - become true are awakened. Tasks that call wait() once the flag is - true will not block at all. - """ - if not self._value: - self._value = True - - for fut in self._waiters: - if not fut.done(): - fut.set_result(True) - - def clear(self): - """Reset the internal flag to false. Subsequently, tasks calling - wait() will block until set() is called to set the internal flag - to true again.""" - self._value = False - - async def wait(self): - """Block until the internal flag is true. - - If the internal flag is true on entry, return True - immediately. Otherwise, block until another task calls - set() to set the flag to true, then return True. 
- """ - if self._value: - return True - - fut = self._get_loop().create_future() - self._waiters.append(fut) - try: - await fut - return True - finally: - self._waiters.remove(fut) - - -class Condition(_ContextManagerMixin, mixins._LoopBoundMixin): +class Condition(_ContextManagerMixin, _LoopBoundMixin): """Asynchronous equivalent to threading.Condition. This class implements condition variable objects. A condition variable @@ -250,7 +205,7 @@ class Condition(_ContextManagerMixin, mixins._LoopBoundMixin): A new Lock object is created and used as the underlying lock. """ - def __init__(self, lock=None): + def __init__(self, lock: Optional[Lock] = None) -> None: if lock is None: lock = Lock() @@ -260,16 +215,16 @@ def __init__(self, lock=None): self.acquire = lock.acquire self.release = lock.release - self._waiters = collections.deque() + self._waiters: collections.deque = collections.deque() - def __repr__(self): + def __repr__(self) -> str: res = super().__repr__() extra = "locked" if self.locked() else "unlocked" if self._waiters: extra = f"{extra}, waiters:{len(self._waiters)}" return f"<{res[1:-1]} [{extra}]>" - async def wait(self): + async def wait(self) -> bool: """Wait until notified. If the calling task has not acquired the lock when this @@ -324,7 +279,7 @@ async def wait(self): self._notify(1) raise - async def wait_for(self, predicate): + async def wait_for(self, predicate: Any) -> Coroutine: """Wait until a predicate becomes true. The predicate should be a callable whose result will be @@ -338,7 +293,7 @@ async def wait_for(self, predicate): result = predicate() return result - def notify(self, n=1): + def notify(self, n: int = 1) -> None: """By default, wake up one task waiting on this condition, if any. If the calling task has not acquired the lock when this method is called, a RuntimeError is raised. @@ -354,7 +309,7 @@ def notify(self, n=1): raise RuntimeError("cannot notify on un-acquired lock") self._notify(n) - def _notify(self, n): + def _notify(self, n: int) -> None: idx = 0 for fut in self._waiters: if idx >= n: @@ -364,7 +319,7 @@ def _notify(self, n): idx += 1 fut.set_result(False) - def notify_all(self): + def notify_all(self) -> None: """Wake up all tasks waiting on this condition. This method acts like notify(), but wakes up all waiting tasks instead of one. If the calling task has not acquired the lock when this method is called, @@ -373,270 +328,6 @@ def notify_all(self): self.notify(len(self._waiters)) -class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin): - """A Semaphore implementation. - - A semaphore manages an internal counter which is decremented by each - acquire() call and incremented by each release() call. The counter - can never go below zero; when acquire() finds that it is zero, it blocks, - waiting until some other thread calls release(). - - Semaphores also support the context management protocol. - - The optional argument gives the initial value for the internal - counter; it defaults to 1. If the value given is less than 0, - ValueError is raised. 
- """ - - def __init__(self, value=1): - if value < 0: - raise ValueError("Semaphore initial value must be >= 0") - self._waiters = None - self._value = value - - def __repr__(self): - res = super().__repr__() - extra = "locked" if self.locked() else f"unlocked, value:{self._value}" - if self._waiters: - extra = f"{extra}, waiters:{len(self._waiters)}" - return f"<{res[1:-1]} [{extra}]>" - - def locked(self): - """Returns True if semaphore cannot be acquired immediately.""" - # Due to state, or FIFO rules (must allow others to run first). - return self._value == 0 or (any(not w.cancelled() for w in (self._waiters or ()))) - - async def acquire(self): - """Acquire a semaphore. - - If the internal counter is larger than zero on entry, - decrement it by one and return True immediately. If it is - zero on entry, block, waiting until some other task has - called release() to make it larger than 0, and then return - True. - """ - if not self.locked(): - # Maintain FIFO, wait for others to start even if _value > 0. - self._value -= 1 - return True - - if self._waiters is None: - self._waiters = collections.deque() - fut = self._get_loop().create_future() - self._waiters.append(fut) - - try: - try: - await fut - finally: - self._waiters.remove(fut) - except exceptions.CancelledError: - # Currently the only exception designed be able to occur here. - if fut.done() and not fut.cancelled(): - # Our Future was successfully set to True via _wake_up_next(), - # but we are not about to successfully acquire(). Therefore we - # must undo the bookkeeping already done and attempt to wake - # up someone else. - self._value += 1 - raise - - finally: - # New waiters may have arrived but had to wait due to FIFO. - # Wake up as many as are allowed. - while self._value > 0: - if not self._wake_up_next(): - break # There was no-one to wake up. - return True - - def release(self): - """Release a semaphore, incrementing the internal counter by one. - - When it was zero on entry and another task is waiting for it to - become larger than zero again, wake up that task. - """ - self._value += 1 - self._wake_up_next() - - def _wake_up_next(self): - """Wake up the first waiter that isn't done.""" - if not self._waiters: - return False - - for fut in self._waiters: - if not fut.done(): - self._value -= 1 - fut.set_result(True) - # `fut` is now `done()` and not `cancelled()`. - return True - return False - - -class BoundedSemaphore(Semaphore): - """A bounded semaphore implementation. - - This raises ValueError in release() if it would increase the value - above the initial value. - """ - - def __init__(self, value=1): - self._bound_value = value - super().__init__(value) - - def release(self): - if self._value >= self._bound_value: - raise ValueError("BoundedSemaphore released too many times") - super().release() - - -class _BarrierState(enum.Enum): - FILLING = "filling" - DRAINING = "draining" - RESETTING = "resetting" - BROKEN = "broken" - - -class Barrier(mixins._LoopBoundMixin): - """Asyncio equivalent to threading.Barrier - - Implements a Barrier primitive. - Useful for synchronizing a fixed number of tasks at known synchronization - points. Tasks block on 'wait()' and are simultaneously awoken once they - have all made their call. 
- """ - - def __init__(self, parties): - """Create a barrier, initialised to 'parties' tasks.""" - if parties < 1: - raise ValueError("parties must be > 0") - - self._cond = Condition() # notify all tasks when state changes - - self._parties = parties - self._state = _BarrierState.FILLING - self._count = 0 # count tasks in Barrier - - def __repr__(self): - res = super().__repr__() - extra = f"{self._state.value}" - if not self.broken: - extra += f", waiters:{self.n_waiting}/{self.parties}" - return f"<{res[1:-1]} [{extra}]>" - - async def __aenter__(self): - # wait for the barrier reaches the parties number - # when start draining release and return index of waited task - return await self.wait() - - async def __aexit__(self, *args): - pass - - async def wait(self): - """Wait for the barrier. - - When the specified number of tasks have started waiting, they are all - simultaneously awoken. - Returns an unique and individual index number from 0 to 'parties-1'. - """ - async with self._cond: - await self._block() # Block while the barrier drains or resets. - try: - index = self._count - self._count += 1 - if index + 1 == self._parties: - # We release the barrier - await self._release() - else: - await self._wait() - return index - finally: - self._count -= 1 - # Wake up any tasks waiting for barrier to drain. - self._exit() - - async def _block(self): - # Block until the barrier is ready for us, - # or raise an exception if it is broken. - # - # It is draining or resetting, wait until done - # unless a CancelledError occurs - await self._cond.wait_for( - lambda: self._state not in (_BarrierState.DRAINING, _BarrierState.RESETTING) - ) - - # see if the barrier is in a broken state - if self._state is _BarrierState.BROKEN: - raise exceptions.BrokenBarrierError("Barrier aborted") - - async def _release(self): - # Release the tasks waiting in the barrier. - - # Enter draining state. - # Next waiting tasks will be blocked until the end of draining. - self._state = _BarrierState.DRAINING - self._cond.notify_all() - - async def _wait(self): - # Wait in the barrier until we are released. Raise an exception - # if the barrier is reset or broken. - - # wait for end of filling - # unless a CancelledError occurs - await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING) - - if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING): - raise exceptions.BrokenBarrierError("Abort or reset of barrier") - - def _exit(self): - # If we are the last tasks to exit the barrier, signal any tasks - # waiting for the barrier to drain. - if self._count == 0: - if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING): - self._state = _BarrierState.FILLING - self._cond.notify_all() - - async def reset(self): - """Reset the barrier to the initial state. - - Any tasks currently waiting will get the BrokenBarrier exception - raised. - """ - async with self._cond: - if self._count > 0: - if self._state is not _BarrierState.RESETTING: - # reset the barrier, waking up tasks - self._state = _BarrierState.RESETTING - else: - self._state = _BarrierState.FILLING - self._cond.notify_all() - - async def abort(self): - """Place the barrier into a 'broken' state. - - Useful in case of error. Any currently waiting tasks and tasks - attempting to 'wait()' will have BrokenBarrierError raised. 
- """ - async with self._cond: - self._state = _BarrierState.BROKEN - self._cond.notify_all() - - @property - def parties(self): - """Return the number of tasks required to trip the barrier.""" - return self._parties - - @property - def n_waiting(self): - """Return the number of tasks currently waiting at the barrier.""" - if self._state is _BarrierState.FILLING: - return self._count - return 0 - - @property - def broken(self): - """Return True if the barrier is in a broken state.""" - return self._state is _BarrierState.BROKEN - - def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and reset by pymongo upon forking. From f26a82bcb1855ee355c2fdebc425ff511ae45f42 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 16 Oct 2024 14:22:49 -0400 Subject: [PATCH 6/9] Fix asyncio.TimeoutError import for locks --- pymongo/lock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index 4c6c210575..8d80ef1777 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -20,7 +20,7 @@ import os import threading import weakref -from asyncio import events, exceptions, wait_for +from asyncio import TimeoutError, events, exceptions, wait_for from typing import Any, Coroutine, Optional, TypeVar _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") From 67d1a539c0669566ae73b705048eba9fac99a9f4 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 16 Oct 2024 14:49:17 -0400 Subject: [PATCH 7/9] Add license notice --- THIRD-PARTY-NOTICES | 59 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/THIRD-PARTY-NOTICES b/THIRD-PARTY-NOTICES index 0b9fc738ed..2ba43abde4 100644 --- a/THIRD-PARTY-NOTICES +++ b/THIRD-PARTY-NOTICES @@ -71,3 +71,62 @@ OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +3) License Notice for lock.py +----------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved" +are retained in Python alone or in any derivative version prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. 
BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND
+DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
+FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT
+INFRINGE ANY THIRD PARTY RIGHTS.
+
+5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON
+FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS
+A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON,
+OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.
+
+6. This License Agreement will automatically terminate upon a material
+breach of its terms and conditions.
+
+7. Nothing in this License Agreement shall be deemed to create any
+relationship of agency, partnership, or joint venture between PSF and
+Licensee. This License Agreement does not grant permission to use PSF
+trademarks or trade name in a trademark sense to endorse or promote
+products or services of Licensee, or any third party.
+
+8. By copying, installing or otherwise using Python, Licensee
+agrees to be bound by the terms and conditions of this License
+Agreement.
+
+Permission to use, copy, modify, and/or distribute this software for any
+purpose with or without fee is hereby granted.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
+REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
+INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
+OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+PERFORMANCE OF THIS SOFTWARE.

From 51ffa4a9c95bcc25c411199591903a28389d6507 Mon Sep 17 00:00:00 2001
From: Noah Stapp
Date: Wed, 16 Oct 2024 14:55:24 -0400
Subject: [PATCH 8/9] Move vendored lock and condition classes into separate
 file

---
 THIRD-PARTY-NOTICES             |   2 +-
 pymongo/_asyncio_lock.py        | 309 +++++++++++++++++++++
 pymongo/asynchronous/pool.py    |   3 +-
 pymongo/lock.py                 | 316 ++--------------------
 pymongo/synchronous/pool.py     |   3 +-
 test/asynchronous/test_locks.py | 463 ++++++++++++++++++++++++++++++++
 6 files changed, 792 insertions(+), 304 deletions(-)
 create mode 100644 pymongo/_asyncio_lock.py
 create mode 100644 test/asynchronous/test_locks.py

diff --git a/THIRD-PARTY-NOTICES b/THIRD-PARTY-NOTICES
index 2ba43abde4..7e20a6f2bd 100644
--- a/THIRD-PARTY-NOTICES
+++ b/THIRD-PARTY-NOTICES
@@ -73,7 +73,7 @@ OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
-3) License Notice for lock.py
+3) License Notice for _asyncio_lock.py
 -----------------------------------------
 
 1. This LICENSE AGREEMENT is between the Python Software Foundation
diff --git a/pymongo/_asyncio_lock.py b/pymongo/_asyncio_lock.py
new file mode 100644
index 0000000000..669b0f63a7
--- /dev/null
+++ b/pymongo/_asyncio_lock.py
@@ -0,0 +1,309 @@
+# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved
+
+"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py
+to port 3.13 fixes to older versions of Python.
+
+Can be removed once we drop Python 3.12 support.""" + +from __future__ import annotations + +import collections +import threading +from asyncio import events, exceptions +from typing import Any, Coroutine, Optional + +_global_lock = threading.Lock() + + +class _LoopBoundMixin: + _loop = None + + def _get_loop(self) -> Any: + loop = events._get_running_loop() + + if self._loop is None: + with _global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class _ContextManagerMixin: + async def __aenter__(self) -> None: + await self.acquire() # type: ignore[attr-defined] + # We have no use for the "as ..." clause in the with + # statement for locks. + return + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + self.release() # type: ignore[attr-defined] + + +class Lock(_ContextManagerMixin, _LoopBoundMixin): + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular task when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another task changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one task is blocked in acquire() waiting for + the state to turn to unlocked, only one task proceeds when a + release() call resets the state to unlocked; successive release() + calls will unblock tasks in FIFO order. + + Locks also support the asynchronous context management protocol. + 'async with lock' statement should be used. + + Usage: + + lock = Lock() + ... + await lock.acquire() + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + async with lock: + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + await lock.acquire() + else: + # lock is acquired + ... + + """ + + def __init__(self) -> None: + self._waiters: Optional[collections.deque] = None + self._locked = False + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self._locked else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + def locked(self) -> bool: + """Return True if lock is acquired.""" + return self._locked + + async def acquire(self) -> bool: + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + # Implement fair scheduling, where thread always waits + # its turn. Jumping the queue if all are cancelled is an optimization. + if not self._locked and ( + self._waiters is None or all(w.cancelled() for w in self._waiters) + ): + self._locked = True + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + try: + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + # Currently the only exception designed be able to occur here. 
+ + # Ensure the lock invariant: If lock is not claimed (or about + # to be claimed by us) and there is a Task in waiters, + # ensure that the Task at the head will run. + if not self._locked: + self._wake_up_first() + raise + + # assert self._locked is False + self._locked = True + return True + + def release(self) -> None: + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other tasks are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + self._wake_up_first() + else: + raise RuntimeError("Lock is not acquired.") + + def _wake_up_first(self) -> None: + """Ensure that the first waiter will wake up.""" + if not self._waiters: + return + try: + fut = next(iter(self._waiters)) + except StopIteration: + return + + # .done() means that the waiter is already set to wake up. + if not fut.done(): + fut.set_result(True) + + +class Condition(_ContextManagerMixin, _LoopBoundMixin): + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more tasks to wait until they are notified by another + task. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock: Optional[Lock] = None) -> None: + if lock is None: + lock = Lock() + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters: collections.deque = collections.deque() + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self.locked() else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + async def wait(self) -> bool: + """Wait until notified. + + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. + """ + if not self.locked(): + raise RuntimeError("cannot wait on un-acquired lock") + + fut = self._get_loop().create_future() + self.release() + try: + try: + self._waiters.append(fut) + try: + await fut + return True + finally: + self._waiters.remove(fut) + + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. + err = None + while True: + try: + await self.acquire() + break + except exceptions.CancelledError as e: + err = e + + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. + except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. 
+ self._notify(1) + raise + + async def wait_for(self, predicate: Any) -> Coroutine: + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result + + def notify(self, n: int = 1) -> None: + """By default, wake up one task waiting on this condition, if any. + If the calling task has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up n of the tasks waiting for the condition + variable; if fewer than n are waiting, they are all awoken. + + Note: an awakened task does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError("cannot notify on un-acquired lock") + self._notify(n) + + def _notify(self, n: int) -> None: + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self) -> None: + """Wake up all tasks waiting on this condition. This method acts + like notify(), but wakes up all waiting tasks instead of one. If the + calling task has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 4363fc3370..2fe9579aef 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -991,7 +991,6 @@ def __init__( self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() self.lock = _async_create_lock() - self.size_cond = _async_create_condition(self.lock) self._max_connecting_cond = _async_create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. @@ -1018,6 +1017,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue + self.size_cond = _async_create_condition(self.lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1025,6 +1025,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue + self._max_connecting_cond = _async_create_condition(self.lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/lock.py b/pymongo/lock.py index 8d80ef1777..f7e4b4ae3c 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -11,17 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py -to port 3.13 fixes to older versions of Python.""" + +"""Internal helpers for lock and condition coordination primitives.""" from __future__ import annotations -import collections +import asyncio import os +import sys import threading import weakref -from asyncio import TimeoutError, events, exceptions, wait_for -from typing import Any, Coroutine, Optional, TypeVar +from asyncio import TimeoutError, wait_for +from typing import Any, Optional, TypeVar + +import pymongo._asyncio_lock _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") @@ -30,302 +33,13 @@ _T = TypeVar("_T") -_global_lock = threading.Lock() - - -class _LoopBoundMixin: - _loop = None - - def _get_loop(self) -> Any: - loop = events._get_running_loop() - - if self._loop is None: - with _global_lock: - if self._loop is None: - self._loop = loop - if loop is not self._loop: - raise RuntimeError(f"{self!r} is bound to a different event loop") - return loop - - -class _ContextManagerMixin: - async def __aenter__(self) -> None: - await self.acquire() # type: ignore[attr-defined] - # We have no use for the "as ..." clause in the with - # statement for locks. - return - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() # type: ignore[attr-defined] - - -class Lock(_ContextManagerMixin, _LoopBoundMixin): - """Primitive lock objects. - - A primitive lock is a synchronization primitive that is not owned - by a particular task when locked. A primitive lock is in one - of two states, 'locked' or 'unlocked'. - - It is created in the unlocked state. It has two basic methods, - acquire() and release(). When the state is unlocked, acquire() - changes the state to locked and returns immediately. When the - state is locked, acquire() blocks until a call to release() in - another task changes it to unlocked, then the acquire() call - resets it to locked and returns. The release() method should only - be called in the locked state; it changes the state to unlocked - and returns immediately. If an attempt is made to release an - unlocked lock, a RuntimeError will be raised. - - When more than one task is blocked in acquire() waiting for - the state to turn to unlocked, only one task proceeds when a - release() call resets the state to unlocked; successive release() - calls will unblock tasks in FIFO order. - - Locks also support the asynchronous context management protocol. - 'async with lock' statement should be used. - - Usage: - - lock = Lock() - ... - await lock.acquire() - try: - ... - finally: - lock.release() - - Context manager usage: - - lock = Lock() - ... - async with lock: - ... - - Lock objects can be tested for locking state: - - if not lock.locked(): - await lock.acquire() - else: - # lock is acquired - ... - - """ - - def __init__(self) -> None: - self._waiters: Optional[collections.deque] = None - self._locked = False - - def __repr__(self) -> str: - res = super().__repr__() - extra = "locked" if self._locked else "unlocked" - if self._waiters: - extra = f"{extra}, waiters:{len(self._waiters)}" - return f"<{res[1:-1]} [{extra}]>" - - def locked(self) -> bool: - """Return True if lock is acquired.""" - return self._locked - - async def acquire(self) -> bool: - """Acquire a lock. - - This method blocks until the lock is unlocked, then sets it to - locked and returns True. - """ - # Implement fair scheduling, where thread always waits - # its turn. 
Jumping the queue if all are cancelled is an optimization. - if not self._locked and ( - self._waiters is None or all(w.cancelled() for w in self._waiters) - ): - self._locked = True - return True - - if self._waiters is None: - self._waiters = collections.deque() - fut = self._get_loop().create_future() - self._waiters.append(fut) - - try: - try: - await fut - finally: - self._waiters.remove(fut) - except exceptions.CancelledError: - # Currently the only exception designed be able to occur here. - - # Ensure the lock invariant: If lock is not claimed (or about - # to be claimed by us) and there is a Task in waiters, - # ensure that the Task at the head will run. - if not self._locked: - self._wake_up_first() - raise - - # assert self._locked is False - self._locked = True - return True - - def release(self) -> None: - """Release a lock. - - When the lock is locked, reset it to unlocked, and return. - If any other tasks are blocked waiting for the lock to become - unlocked, allow exactly one of them to proceed. - - When invoked on an unlocked lock, a RuntimeError is raised. - - There is no return value. - """ - if self._locked: - self._locked = False - self._wake_up_first() - else: - raise RuntimeError("Lock is not acquired.") - - def _wake_up_first(self) -> None: - """Ensure that the first waiter will wake up.""" - if not self._waiters: - return - try: - fut = next(iter(self._waiters)) - except StopIteration: - return - - # .done() means that the waiter is already set to wake up. - if not fut.done(): - fut.set_result(True) - - -class Condition(_ContextManagerMixin, _LoopBoundMixin): - """Asynchronous equivalent to threading.Condition. - - This class implements condition variable objects. A condition variable - allows one or more tasks to wait until they are notified by another - task. - - A new Lock object is created and used as the underlying lock. - """ - - def __init__(self, lock: Optional[Lock] = None) -> None: - if lock is None: - lock = Lock() - - self._lock = lock - # Export the lock's locked(), acquire() and release() methods. - self.locked = lock.locked - self.acquire = lock.acquire - self.release = lock.release - - self._waiters: collections.deque = collections.deque() - - def __repr__(self) -> str: - res = super().__repr__() - extra = "locked" if self.locked() else "unlocked" - if self._waiters: - extra = f"{extra}, waiters:{len(self._waiters)}" - return f"<{res[1:-1]} [{extra}]>" - - async def wait(self) -> bool: - """Wait until notified. - - If the calling task has not acquired the lock when this - method is called, a RuntimeError is raised. - - This method releases the underlying lock, and then blocks - until it is awakened by a notify() or notify_all() call for - the same condition variable in another task. Once - awakened, it re-acquires the lock and returns True. - - This method may return spuriously, - which is why the caller should always - re-check the state and be prepared to wait() again. - """ - if not self.locked(): - raise RuntimeError("cannot wait on un-acquired lock") - - fut = self._get_loop().create_future() - self.release() - try: - try: - self._waiters.append(fut) - try: - await fut - return True - finally: - self._waiters.remove(fut) - - finally: - # Must re-acquire lock even if wait is cancelled. - # We only catch CancelledError here, since we don't want any - # other (fatal) errors with the future to cause us to spin. 
- err = None - while True: - try: - await self.acquire() - break - except exceptions.CancelledError as e: - err = e - - if err is not None: - try: - raise err # Re-raise most recent exception instance. - finally: - err = None # Break reference cycles. - except BaseException: - # Any error raised out of here _may_ have occurred after this Task - # believed to have been successfully notified. - # Make sure to notify another Task instead. This may result - # in a "spurious wakeup", which is allowed as part of the - # Condition Variable protocol. - self._notify(1) - raise - - async def wait_for(self, predicate: Any) -> Coroutine: - """Wait until a predicate becomes true. - - The predicate should be a callable whose result will be - interpreted as a boolean value. The method will repeatedly - wait() until it evaluates to true. The final predicate value is - the return value. - """ - result = predicate() - while not result: - await self.wait() - result = predicate() - return result - - def notify(self, n: int = 1) -> None: - """By default, wake up one task waiting on this condition, if any. - If the calling task has not acquired the lock when this method - is called, a RuntimeError is raised. - - This method wakes up n of the tasks waiting for the condition - variable; if fewer than n are waiting, they are all awoken. - - Note: an awakened task does not actually return from its - wait() call until it can reacquire the lock. Since notify() does - not release the lock, its caller should. - """ - if not self.locked(): - raise RuntimeError("cannot notify on un-acquired lock") - self._notify(n) - - def _notify(self, n: int) -> None: - idx = 0 - for fut in self._waiters: - if idx >= n: - break - - if not fut.done(): - idx += 1 - fut.set_result(False) - - def notify_all(self) -> None: - """Wake up all tasks waiting on this condition. This method acts - like notify(), but wakes up all waiting tasks instead of one. If the - calling task has not acquired the lock when this method is called, - a RuntimeError is raised. - """ - self.notify(len(self._waiters)) +# Needed to support 3.13 asyncio fixes in older versions of Python +if sys.version_info >= (3, 13): + Lock = asyncio.Lock + Condition = asyncio.Condition +else: + Lock = pymongo._asyncio_lock.Lock + Condition = pymongo._asyncio_lock.Condition def _create_lock() -> threading.Lock: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index f68bfe2002..6ac7b4eca9 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -987,7 +987,6 @@ def __init__( self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() self.lock = _create_lock() - self.size_cond = _create_condition(self.lock) self._max_connecting_cond = _create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. @@ -1014,6 +1013,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue + self.size_cond = _create_condition(self.lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1021,6 +1021,7 @@ def __init__( # The second portion of the wait queue. 
# Enforces: maxConnecting # Also used for: clearing the wait queue + self._max_connecting_cond = _create_condition(self.lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py new file mode 100644 index 0000000000..ae9f7771dd --- /dev/null +++ b/test/asynchronous/test_locks.py @@ -0,0 +1,463 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for lock.py""" +from __future__ import annotations + +import asyncio +import sys +import unittest + +from pymongo.lock import _async_create_condition, _async_create_lock + +sys.path[0:0] = [""] + + +# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py +# Includes tests for: +# - https://github.com/python/cpython/issues/111693 +# - https://github.com/python/cpython/issues/112202 +class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _async_create_condition(_async_create_lock()) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_wait_cancel(self): + cond = _async_create_condition(_async_create_lock()) + await cond.acquire() + + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + async def test_wait_cancel_contested(self): + cond = _async_create_condition(_async_create_lock()) + + await cond.acquire() + self.assertTrue(cond.locked()) + + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) + + # Notify, but 
contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) + + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass + + self.assertTrue(cond.locked()) + + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False + + cond = _async_create_condition(_async_create_lock()) + + async def wait_on_cond(): + nonlocal waited + async with cond: + waited = True # Make sure this area was reached + await cond.wait() + + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting + + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = _async_create_condition(_async_create_lock()) + with self.assertRaises(RuntimeError): + await cond.wait() + + async def test_wait_for(self): + cond = _async_create_condition(_async_create_lock()) + presult = False + + def predicate(): + return presult + + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() + return True + + t = asyncio.create_task(c1(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) + + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + async def test_wait_for_unacquired(self): + cond = _async_create_condition(_async_create_lock()) + + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) + + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) + + async def test_notify(self): + cond = _async_create_condition(_async_create_lock()) + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) + + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_notify_all(self): + cond = _async_create_condition(_async_create_lock()) + + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = 
asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _async_create_condition(_async_create_lock()) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) + + async def test_timeout_in_block(self): + condition = _async_create_condition(_async_create_lock()) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. 
+        condition = _async_create_condition(_async_create_lock())
+        state = 0
+
+        async def consumer():
+            nonlocal state
+            async with condition:
+                while True:
+                    await condition.wait_for(lambda: state != 0)
+                    if state < 0:
+                        return
+                    state -= 1
+
+        # create two consumers
+        c = [asyncio.create_task(consumer()) for _ in range(2)]
+        # wait for them to settle
+        await asyncio.sleep(0.1)
+        async with condition:
+            # produce one item and wake up one
+            state += 1
+            condition.notify(1)
+
+            # Cancel it while it is awaiting to be run.
+            # This cancellation could come from the outside
+            c[0].cancel()
+
+            # now wait for the item to be consumed
+            # if it doesn't, it means that our "notify" didn't take hold,
+            # because it raced with a cancel()
+            try:
+                async with asyncio.timeout(1):
+                    await condition.wait_for(lambda: state == 0)
+            except TimeoutError:
+                pass
+            self.assertEqual(state, 0)
+
+            # clean up
+            state = -1
+            condition.notify_all()
+        await c[1]
+
+    @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11")
+    async def test_cancelled_wakeup_relock(self):
+        # Test that a task cancelled at the "same" time as it is woken
+        # up as part of a Condition.notify() does not result in a lost wakeup.
+        # This test simulates a cancel while the target task is acquiring the lock
+        # again.
+        condition = _async_create_condition(_async_create_lock())
+        state = 0
+
+        async def consumer():
+            nonlocal state
+            async with condition:
+                while True:
+                    await condition.wait_for(lambda: state != 0)
+                    if state < 0:
+                        return
+                    state -= 1
+
+        # create two consumers
+        c = [asyncio.create_task(consumer()) for _ in range(2)]
+        # wait for them to settle
+        await asyncio.sleep(0.1)
+        async with condition:
+            # produce one item and wake up one
+            state += 1
+            condition.notify(1)
+
+            # now we sleep for a bit. This allows the target task to wake up and
+            # settle on re-acquiring the lock
+            await asyncio.sleep(0)
+
+            # Cancel it while awaiting the lock
+            # This cancel could come from the outside.
+            c[0].cancel()
+
+            # now wait for the item to be consumed
+            # if it doesn't, it means that our "notify" didn't take hold,
+ # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + +if __name__ == "__main__": + unittest.main() From 4626c82fbabf766614e3d0d1b85ad401451a2b60 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 16 Oct 2024 16:12:17 -0400 Subject: [PATCH 9/9] Address review --- pymongo/lock.py | 7 +- test/asynchronous/test_locks.py | 785 ++++++++++++++++---------------- 2 files changed, 396 insertions(+), 396 deletions(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index f7e4b4ae3c..6bf7138017 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -21,7 +21,7 @@ import sys import threading import weakref -from asyncio import TimeoutError, wait_for +from asyncio import wait_for from typing import Any, Optional, TypeVar import pymongo._asyncio_lock @@ -33,7 +33,8 @@ _T = TypeVar("_T") -# Needed to support 3.13 asyncio fixes in older versions of Python +# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202) +# in older versions of Python if sys.version_info >= (3, 13): Lock = asyncio.Lock Condition = asyncio.Condition @@ -83,7 +84,7 @@ def _release_locks() -> None: async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool: try: return await wait_for(condition.wait(), timeout) - except TimeoutError: + except asyncio.TimeoutError: return False diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py index ae9f7771dd..e5a0adfee6 100644 --- a/test/asynchronous/test_locks.py +++ b/test/asynchronous/test_locks.py @@ -22,442 +22,441 @@ sys.path[0:0] = [""] +if sys.version_info < (3, 13): + # Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py + # Includes tests for: + # - https://github.com/python/cpython/issues/111693 + # - https://github.com/python/cpython/issues/112202 + class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _async_create_condition(_async_create_lock()) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True -# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py -# Includes tests for: -# - https://github.com/python/cpython/issues/111693 -# - https://github.com/python/cpython/issues/112202 -class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): - async def test_wait(self): - cond = _async_create_condition(_async_create_lock()) - result = [] + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True - async def c1(result): - await cond.acquire() - if await cond.wait(): - result.append(1) - return True + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True - async def c2(result): - await cond.acquire() - if await cond.wait(): - result.append(2) - return True + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) - async def c3(result): - await cond.acquire() - if await cond.wait(): - result.append(3) - return True - - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertFalse(cond.locked()) - - self.assertTrue(await 
cond.acquire()) - cond.notify() - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.notify(2) - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2, 3], result) - self.assertTrue(cond.locked()) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - - async def test_wait_cancel(self): - cond = _async_create_condition(_async_create_lock()) - await cond.acquire() - - wait = asyncio.create_task(cond.wait()) - asyncio.get_running_loop().call_soon(wait.cancel) - with self.assertRaises(asyncio.CancelledError): - await wait - self.assertFalse(cond._waiters) - self.assertTrue(cond.locked()) - - async def test_wait_cancel_contested(self): - cond = _async_create_condition(_async_create_lock()) - - await cond.acquire() - self.assertTrue(cond.locked()) - - wait_task = asyncio.create_task(cond.wait()) - await asyncio.sleep(0) - self.assertFalse(cond.locked()) - - # Notify, but contest the lock before cancelling - await cond.acquire() - self.assertTrue(cond.locked()) - cond.notify() - asyncio.get_running_loop().call_soon(wait_task.cancel) - asyncio.get_running_loop().call_soon(cond.release) - - try: - await wait_task - except asyncio.CancelledError: - # Should not happen, since no cancellation points - pass - - self.assertTrue(cond.locked()) - - async def test_wait_cancel_after_notify(self): - # See bpo-32841 - waited = False - - cond = _async_create_condition(_async_create_lock()) - - async def wait_on_cond(): - nonlocal waited - async with cond: - waited = True # Make sure this area was reached - await cond.wait() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) - waiter = asyncio.create_task(wait_on_cond()) - await asyncio.sleep(0) # Start waiting + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) - await cond.acquire() - cond.notify() - await asyncio.sleep(0) # Get to acquire() - waiter.cancel() - await asyncio.sleep(0) # Activate cancellation - cond.release() - await asyncio.sleep(0) # Cancellation should occur + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) - self.assertTrue(waiter.cancelled()) - self.assertTrue(waited) + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) - async def test_wait_unacquired(self): - cond = _async_create_condition(_async_create_lock()) - with self.assertRaises(RuntimeError): - await cond.wait() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) - async def test_wait_for(self): - cond = _async_create_condition(_async_create_lock()) - presult = False + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) - def predicate(): - return presult + 
async def test_wait_cancel(self): + cond = _async_create_condition(_async_create_lock()) + await cond.acquire() - result = [] + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + async def test_wait_cancel_contested(self): + cond = _async_create_condition(_async_create_lock()) - async def c1(result): await cond.acquire() - if await cond.wait_for(predicate): - result.append(1) - cond.release() - return True + self.assertTrue(cond.locked()) - t = asyncio.create_task(c1(result)) + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) - await asyncio.sleep(0) - self.assertEqual([], result) + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([], result) + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass - presult = True - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) + self.assertTrue(cond.locked()) - self.assertTrue(t.done()) - self.assertTrue(t.result()) + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False - async def test_wait_for_unacquired(self): - cond = _async_create_condition(_async_create_lock()) + cond = _async_create_condition(_async_create_lock()) - # predicate can return true immediately - res = await cond.wait_for(lambda: [1, 2, 3]) - self.assertEqual([1, 2, 3], res) + async def wait_on_cond(): + nonlocal waited + async with cond: + waited = True # Make sure this area was reached + await cond.wait() - with self.assertRaises(RuntimeError): - await cond.wait_for(lambda: False) + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting - async def test_notify(self): - cond = _async_create_condition(_async_create_lock()) - result = [] + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = _async_create_condition(_async_create_lock()) + with self.assertRaises(RuntimeError): + await cond.wait() - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True + async def test_wait_for(self): + cond = _async_create_condition(_async_create_lock()) + presult = False - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True + def predicate(): + return presult - async def c3(result): - async with cond: - if await cond.wait(): - result.append(3) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() return True - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) + t = asyncio.create_task(c1(result)) - await asyncio.sleep(0) - self.assertEqual([], result) + await asyncio.sleep(0) + self.assertEqual([], result) - 
async with cond: - cond.notify(1) - await asyncio.sleep(1) - self.assertEqual([1], result) + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) - async with cond: - cond.notify(1) - cond.notify(2048) - await asyncio.sleep(1) - self.assertEqual([1, 2, 3], result) + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) + self.assertTrue(t.done()) + self.assertTrue(t.result()) - async def test_notify_all(self): - cond = _async_create_condition(_async_create_lock()) + async def test_wait_for_unacquired(self): + cond = _async_create_condition(_async_create_lock()) - result = [] + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True + async def test_notify(self): + cond = _async_create_condition(_async_create_lock()) + result = [] - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True - await asyncio.sleep(0) - self.assertEqual([], result) + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True - async with cond: - cond.notify_all() - await asyncio.sleep(1) - self.assertEqual([1, 2], result) + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) - async def test_context_manager(self): - cond = _async_create_condition(_async_create_lock()) - self.assertFalse(cond.locked()) - async with cond: - self.assertTrue(cond.locked()) - self.assertFalse(cond.locked()) - - async def test_timeout_in_block(self): - condition = _async_create_condition(_async_create_lock()) - async with condition: - with self.assertRaises(asyncio.TimeoutError): - await asyncio.wait_for(condition.wait(), timeout=0.5) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_wakeup(self): - # Test that a cancelled error, received when awaiting wakeup, - # will be re-raised un-modified. - wake = False - raised = None - cond = _async_create_condition(_async_create_lock()) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. 
- self.assertIs(err.exception, raised) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_re_aquire(self): - # Test that a cancelled error, received when re-aquiring lock, - # will be re-raised un-modified. - wake = False - raised = None - cond = _async_create_condition(_async_create_lock()) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition - await cond.acquire() - wake = True - cond.notify() - await asyncio.sleep(0) - # Task is now trying to re-acquire the lock, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - cond.release() - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is awaiting initial - # wakeup on the wakeup queue. - condition = _async_create_condition(_async_create_lock()) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # Cancel it while it is awaiting to be run. - # This cancellation could come from the outside - c[0].cancel() - - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) - - # clean up - state = -1 - condition.notify_all() - await c[1] - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup_relock(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is acquiring the lock - # again. - condition = _async_create_condition(_async_create_lock()) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # now we sleep for a bit. This allows the target task to wake up and - # settle on re-aquiring the lock await asyncio.sleep(0) + self.assertEqual([], result) - # Cancel it while awaiting the lock - # This cancel could come the outside. 
- c[0].cancel() + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_notify_all(self): + cond = _async_create_condition(_async_create_lock()) + + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) - # clean up - state = -1 - condition.notify_all() - await c[1] + await asyncio.sleep(0) + self.assertEqual([], result) + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _async_create_condition(_async_create_lock()) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) -if __name__ == "__main__": - unittest.main() + async def test_timeout_in_block(self): + condition = _async_create_condition(_async_create_lock()) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. 
+ wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. + condition = _async_create_condition(_async_create_lock()) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # Cancel it while it is awaiting to be run. + # This cancellation could come from the outside + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup_relock(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is acquiring the lock + # again. + condition = _async_create_condition(_async_create_lock()) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # now we sleep for a bit. This allows the target task to wake up and + # settle on re-aquiring the lock + await asyncio.sleep(0) + + # Cancel it while awaiting the lock + # This cancel could come the outside. + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. 
+ # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + if __name__ == "__main__": + unittest.main()
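
Reviewer note (illustrative, not part of the patch series): the pymongo/lock.py hunk above shows _async_cond_wait() only in isolation, so here is a minimal standalone sketch of how it combines with _async_create_lock() and _async_create_condition(). The helper names are taken from the hunks in this series; everything else (the notifier coroutine, main(), the chosen timeouts) is assumed demo scaffolding, not pymongo code. The point it illustrates is the contract visible in the hunk: a bounded condition wait reports False on timeout instead of letting asyncio.TimeoutError escape.

import asyncio

from pymongo.lock import (
    _async_cond_wait,
    _async_create_condition,
    _async_create_lock,
)


async def main() -> None:
    # Build a condition backed by the (possibly vendored) asyncio lock.
    cond = _async_create_condition(_async_create_lock())

    # Nothing will notify us, so the bounded wait gives up after 0.1s and
    # returns False rather than raising asyncio.TimeoutError.
    async with cond:
        notified = await _async_cond_wait(cond, 0.1)
    assert not notified

    async def notifier() -> None:
        # Runs once the waiter releases the lock inside condition.wait().
        async with cond:
            cond.notify()

    # With a concurrent notifier, the same call wakes up and returns True.
    async with cond:
        task = asyncio.create_task(notifier())
        notified = await _async_cond_wait(cond, 1.0)
    await task
    assert notified


if __name__ == "__main__":
    asyncio.run(main())

The demo deliberately mirrors the test pattern used throughout test_locks.py (a waiter blocked in a condition wait, a second coroutine acquiring the lock to notify it); it is a sketch of the helpers' intended use, not a replacement for those tests.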