Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-4860 - Async client should use asyncio.Lock and asyncio.Condition #1934

Merged
merged 11 commits into from
Oct 17, 2024
4 changes: 2 additions & 2 deletions pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.lock import _ALock, _create_lock
from pymongo.lock import _async_create_lock
from pymongo.message import (
_CursorAddress,
_GetMore,
Expand Down Expand Up @@ -77,7 +77,7 @@ class _ConnectionManager:
def __init__(self, conn: AsyncConnection, more_to_come: bool):
    # NOTE(review): this span is a unified-diff view — it shows BOTH the
    # pre-change lock line (`self._alock = _ALock(_create_lock())`) and its
    # replacement (`self._lock = _async_create_lock()`). Only the replacement
    # exists in the merged code; the two lines are not meant to coexist.
    # Connection pinned by this manager; Optional per the annotation
    # (presumably cleared when released — TODO confirm against full class).
    self.conn: Optional[AsyncConnection] = conn
    # Whether the server signalled more data to come on this stream.
    self.more_to_come = more_to_come
    self._alock = _ALock(_create_lock())
    self._lock = _async_create_lock()

def update_exhaust(self, more_to_come: bool) -> None:
    """Record whether the server reported more data to come."""
    self.more_to_come = more_to_come
Expand Down
12 changes: 8 additions & 4 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@
WaitQueueTimeoutError,
WriteConcernError,
)
from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks
from pymongo.lock import (
_HAS_REGISTER_AT_FORK,
_async_create_lock,
_release_locks,
)
from pymongo.logger import _CLIENT_LOGGER, _log_or_warn
from pymongo.message import _CursorAddress, _GetMore, _Query
from pymongo.monitoring import ConnectionClosedReason
Expand Down Expand Up @@ -842,7 +846,7 @@ def __init__(
self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC)

self._default_database_name = dbase
self._lock = _ALock(_create_lock())
self._lock = _async_create_lock()
self._kill_cursors_queue: list = []

self._event_listeners = options.pool_options._event_listeners
Expand Down Expand Up @@ -1721,7 +1725,7 @@ async def _run_operation(
address=address,
)

async with operation.conn_mgr._alock:
async with operation.conn_mgr._lock:
async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type]
err_handler.contribute_socket(operation.conn_mgr.conn)
return await server.run_operation(
Expand Down Expand Up @@ -1969,7 +1973,7 @@ async def _close_cursor_now(

try:
if conn_mgr:
async with conn_mgr._alock:
async with conn_mgr._lock:
# Cursor is pinned to LB outside of a transaction.
assert address is not None
assert conn_mgr.conn is not None
Expand Down
24 changes: 12 additions & 12 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@
_CertificateError,
)
from pymongo.hello import Hello, HelloCompat
from pymongo.lock import _ACondition, _ALock, _create_lock
from pymongo.lock import (
_async_cond_wait,
_async_create_condition,
_async_create_lock,
)
from pymongo.logger import (
_CONNECTION_LOGGER,
_ConnectionStatusMessage,
Expand Down Expand Up @@ -208,11 +212,6 @@ def _raise_connection_failure(
raise AutoReconnect(msg) from error


async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool:
timeout = deadline - time.monotonic() if deadline else None
return await condition.wait(timeout)


def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
details = {}
timeout = _csot.get_timeout()
Expand Down Expand Up @@ -992,8 +991,9 @@ def __init__(
# from the right side.
self.conns: collections.deque = collections.deque()
self.active_contexts: set[_CancellationContext] = set()
_lock = _create_lock()
self.lock = _ALock(_lock)
self.lock = _async_create_lock()
self.size_cond = _async_create_condition(self.lock, threading.Condition)
self._max_connecting_cond = _async_create_condition(self.lock, threading.Condition)
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
Expand All @@ -1019,15 +1019,13 @@ def __init__(
# The first portion of the wait queue.
# Enforces: maxPoolSize
# Also used for: clearing the wait queue
self.size_cond = _ACondition(threading.Condition(_lock))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move these back since the Condition vars are defined alongside the variables they protect as well as the explanatory comments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if not self.max_pool_size:
self.max_pool_size = float("inf")
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = _ACondition(threading.Condition(_lock))
self._max_connecting = self.opts.max_connecting
self._pending = 0
self._client_id = client_id
Expand Down Expand Up @@ -1456,7 +1454,8 @@ async def _get_conn(
async with self.size_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=True)
while not (self.requests < self.max_pool_size):
if not await _cond_wait(self.size_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self.size_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.requests < self.max_pool_size:
Expand All @@ -1479,7 +1478,8 @@ async def _get_conn(
async with self._max_connecting_cond:
self._raise_if_not_ready(checkout_started_time, emit_event=False)
while not (self.conns or self._pending < self._max_connecting):
if not await _cond_wait(self._max_connecting_cond, deadline):
timeout = deadline - time.monotonic() if deadline else None
if not await _async_cond_wait(self._max_connecting_cond, timeout):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
if self.conns or self._pending < self._max_connecting:
Expand Down
15 changes: 9 additions & 6 deletions pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
WriteError,
)
from pymongo.hello import Hello
from pymongo.lock import _ACondition, _ALock, _create_lock
from pymongo.lock import (
_async_cond_wait,
_async_create_condition,
_async_create_lock,
)
from pymongo.logger import (
_SDAM_LOGGER,
_SERVER_SELECTION_LOGGER,
Expand Down Expand Up @@ -169,9 +173,8 @@ def __init__(self, topology_settings: TopologySettings):
self._seed_addresses = list(topology_description.server_descriptions())
self._opened = False
self._closed = False
_lock = _create_lock()
self._lock = _ALock(_lock)
self._condition = _ACondition(self._settings.condition_class(_lock))
self._lock = _async_create_lock()
self._condition = _async_create_condition(self._lock, self._settings.condition_class)
self._servers: dict[_Address, Server] = {}
self._pid: Optional[int] = None
self._max_cluster_time: Optional[ClusterTime] = None
Expand Down Expand Up @@ -353,7 +356,7 @@ async def _select_servers_loop(
# change, or for a timeout. We won't miss any changes that
# came after our most recent apply_selector call, since we've
# held the lock until now.
await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL)
await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL)
self._description.check_compatible()
now = time.monotonic()
server_descriptions = self._description.apply_selector(
Expand Down Expand Up @@ -653,7 +656,7 @@ async def request_check_all(self, wait_time: int = 5) -> None:
"""Wake all monitors, wait for at least one to check its server."""
async with self._lock:
self._request_check_all()
await self._condition.wait(wait_time)
await _async_cond_wait(self._condition, wait_time)

def data_bearing_servers(self) -> list[ServerDescription]:
"""Return a list of all data-bearing servers.
Expand Down
Loading
Loading