diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 81a0fc3c7..57905774b 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -12,6 +12,8 @@ APScheduler, see the :doc:`migration section <migration>`.
   triggers and have no more associated jobs running. Previously, schedules were
   automatically deleted instantly once their triggers could no longer produce any
   fire times.
+- Fixed large parts of ``MongoDBDataStore`` still calling blocking functions in the
+  event loop thread
 
 **4.0.0a4**
 
diff --git a/src/apscheduler/datastores/mongodb.py b/src/apscheduler/datastores/mongodb.py
index cbf5c31de..87c6ee09f 100644
--- a/src/apscheduler/datastores/mongodb.py
+++ b/src/apscheduler/datastores/mongodb.py
@@ -1,12 +1,13 @@
 from __future__ import annotations
 
 import operator
+import sys
 from collections import defaultdict
-from collections.abc import Mapping
+from collections.abc import AsyncIterator, Mapping
 from contextlib import AsyncExitStack
 from datetime import datetime, timedelta, timezone
 from logging import Logger
-from typing import Any, Callable, ClassVar, Iterable
+from typing import Any, Callable, ClassVar, Generic, Iterable, TypeVar
 from uuid import UUID
 
 import attrs
@@ -17,6 +18,7 @@
 from bson.codec_options import TypeEncoder, TypeRegistry
 from pymongo import ASCENDING, DeleteOne, MongoClient, UpdateOne
 from pymongo.collection import Collection
+from pymongo.cursor import Cursor
 from pymongo.errors import ConnectionFailure, DuplicateKeyError
 
 from .._enums import CoalescePolicy, ConflictPolicy, JobOutcome
@@ -41,6 +43,13 @@
 from ..abc import EventBroker
 from .base import BaseExternalDataStore
 
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
+T = TypeVar("T", bound=Mapping[str, Any])
+
 
 class CustomEncoder(TypeEncoder):
     def __init__(self, python_type: type, encoder: Callable):
@@ -83,6 +92,40 @@ def unmarshal_timestamps(document: dict[str, Any]) -> None:
             document[key[:-10]] = datetime.fromtimestamp(time_micro, tzinfo)
 
 
+class AsyncCursor(Generic[T]):
+    sentinel: ClassVar[object] = object()
+
+    def __init__(self, cursor: Cursor[T]):
+        self._cursor = cursor
+
+    def __aiter__(self) -> AsyncIterator[T]:
+        return self
+
+    async def __aenter__(self) -> Self:
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+        await to_thread.run_sync(self._cursor.close)
+
+    def _get_next(self) -> T:
+        try:
+            return next(self._cursor)
+        except StopIteration:
+            return self.sentinel
+
+    async def __anext__(self) -> T:
+        obj = await to_thread.run_sync(self._get_next)
+        if obj is self.sentinel:
+            raise StopAsyncIteration
+
+        return obj
+
+    @classmethod
+    async def create(cls, func: Callable[..., Cursor[T]]) -> AsyncCursor[T]:
+        cursor = await to_thread.run_sync(func)
+        return AsyncCursor(cursor)
+
+
 @attrs.define(eq=False)
 class MongoDBDataStore(BaseExternalDataStore):
     """
@@ -222,11 +265,14 @@ async def get_tasks(self) -> list[Task]:
         async for attempt in self._retry():
             with attempt:
                 tasks: list[Task] = []
-                for document in self._tasks.find(
-                    projection=self._task_attrs, sort=[("_id", pymongo.ASCENDING)]
-                ):
-                    document["id"] = document.pop("_id")
-                    tasks.append(Task.unmarshal(self.serializer, document))
+                async with await AsyncCursor.create(
+                    lambda: self._tasks.find(
+                        projection=self._task_attrs, sort=[("_id", pymongo.ASCENDING)]
+                    )
+                ) as cursor:
+                    async for document in cursor:
+                        document["id"] = document.pop("_id")
+                        tasks.append(Task.unmarshal(self.serializer, document))
 
         return tasks
 
@@ -235,19 +281,21 @@ async def get_schedules(self, ids: set[str] | None = None) -> list[Schedule]:
         async for attempt in self._retry():
             with attempt:
                 schedules: list[Schedule] = []
-                cursor = self._schedules.find(filters).sort("_id")
-                for document in cursor:
-                    document["id"] = document.pop("_id")
-                    unmarshal_timestamps(document)
-                    try:
-                        schedule = Schedule.unmarshal(self.serializer, document)
-                    except DeserializationError:
-                        self._logger.warning(
-                            "Failed to deserialize schedule %r", document["_id"]
-                        )
-                        continue
+                async with await AsyncCursor.create(
+                    lambda: self._schedules.find(filters).sort("_id")
+                ) as cursor:
+                    async for document in cursor:
+                        document["id"] = document.pop("_id")
+                        unmarshal_timestamps(document)
+                        try:
+                            schedule = Schedule.unmarshal(self.serializer, document)
+                        except DeserializationError:
+                            self._logger.warning(
+                                "Failed to deserialize schedule %r", document["id"]
+                            )
+                            continue
 
-                    schedules.append(schedule)
+                        schedules.append(schedule)
 
         return schedules
 
@@ -260,15 +308,18 @@ async def add_schedule(
         try:
             async for attempt in self._retry():
                 with attempt:
-                    self._schedules.insert_one(document)
+                    await to_thread.run_sync(self._schedules.insert_one, document)
         except DuplicateKeyError:
             if conflict_policy is ConflictPolicy.exception:
                 raise ConflictingIdError(schedule.id) from None
             elif conflict_policy is ConflictPolicy.replace:
                 async for attempt in self._retry():
                     with attempt:
-                        self._schedules.replace_one(
-                            {"_id": schedule.id}, document, True
+                        await to_thread.run_sync(
+                            self._schedules.replace_one,
+                            {"_id": schedule.id},
+                            document,
+                            True,
                         )
 
                 event = ScheduleUpdated(
@@ -289,12 +340,14 @@ async def remove_schedules(self, ids: Iterable[str]) -> None:
         filters = {"_id": {"$in": list(ids)}} if ids is not None else {}
         async for attempt in self._retry():
             with attempt, self.client.start_session() as session:
-                cursor = self._schedules.find(
-                    filters, projection=["_id", "task_id"], session=session
-                )
-                ids = [(doc["_id"], doc["task_id"]) for doc in cursor]
-                if ids:
-                    self._schedules.delete_many(filters, session=session)
+                async with await AsyncCursor.create(
+                    lambda: self._schedules.find(
+                        filters, projection=["_id", "task_id"], session=session
+                    )
+                ) as cursor:
+                    ids = [(doc["_id"], doc["task_id"]) async for doc in cursor]
+                    if ids:
+                        self._schedules.delete_many(filters, session=session)
 
         for schedule_id, task_id in ids:
             await self._event_broker.publish(
@@ -308,8 +361,8 @@ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedul
             with attempt, self.client.start_session() as session:
                 schedules: list[Schedule] = []
                 now = datetime.now(timezone.utc).timestamp()
-                cursor = (
-                    self._schedules.find(
+                async with await AsyncCursor.create(
+                    lambda: self._schedules.find(
                         {
                             "next_fire_time": {"$lte": now},
                             "$or": [
@@ -321,12 +374,12 @@ async def acquire_schedules(self, scheduler_id: str, limit: int) -> list[Schedul
                     )
                     .sort("next_fire_time")
                     .limit(limit)
-                )
-                for document in cursor:
-                    document["id"] = document.pop("_id")
-                    unmarshal_timestamps(document)
-                    schedule = Schedule.unmarshal(self.serializer, document)
-                    schedules.append(schedule)
+                ) as cursor:
+                    async for document in cursor:
+                        document["id"] = document.pop("_id")
+                        unmarshal_timestamps(document)
+                        schedule = Schedule.unmarshal(self.serializer, document)
+                        schedules.append(schedule)
 
                 if schedules:
                     now = datetime.now(timezone.utc)
@@ -380,7 +433,11 @@ async def release_schedules(
         if requests:
             async for attempt in self._retry():
                 with attempt, self.client.start_session() as session:
-                    self._schedules.bulk_write(requests, ordered=False, session=session)
+                    await to_thread.run_sync(
+                        lambda: self._schedules.bulk_write(
+                            requests, ordered=False, session=session
+                        )
+                    )
 
         for schedule_id, next_fire_time in updated_schedules:
             event = ScheduleUpdated(
@@ -402,10 +459,12 @@ async def release_schedules(
     async def get_next_schedule_run_time(self) -> datetime | None:
         async for attempt in self._retry():
             with attempt:
-                document = self._schedules.find_one(
-                    {"next_fire_time": {"$ne": None}},
-                    projection=["next_fire_time", "next_fire_time_utcoffset"],
-                    sort=[("next_fire_time", ASCENDING)],
+                document = await to_thread.run_sync(
+                    lambda: self._schedules.find_one(
+                        {"next_fire_time": {"$ne": None}},
+                        projection=["next_fire_time", "next_fire_time_utcoffset"],
+                        sort=[("next_fire_time", ASCENDING)],
+                    )
                 )
 
                 if document:
@@ -419,7 +478,7 @@ async def add_job(self, job: Job) -> None:
         marshal_document(document)
         async for attempt in self._retry():
             with attempt:
-                self._jobs.insert_one(document)
+                await to_thread.run_sync(self._jobs.insert_one, document)
 
         event = JobAdded(
             job_id=job.id,
@@ -433,19 +492,21 @@ async def get_jobs(self, ids: Iterable[UUID] | None = None) -> list[Job]:
         async for attempt in self._retry():
             with attempt:
                 jobs: list[Job] = []
-                cursor = self._jobs.find(filters).sort("_id")
-                for document in cursor:
-                    document["id"] = document.pop("_id")
-                    unmarshal_timestamps(document)
-                    try:
-                        job = Job.unmarshal(self.serializer, document)
-                    except DeserializationError:
-                        self._logger.warning(
-                            "Failed to deserialize job %r", document["id"]
-                        )
-                        continue
+                async with await AsyncCursor.create(
+                    lambda: self._jobs.find(filters).sort("_id")
+                ) as cursor:
+                    async for document in cursor:
+                        document["id"] = document.pop("_id")
+                        unmarshal_timestamps(document)
+                        try:
+                            job = Job.unmarshal(self.serializer, document)
+                        except DeserializationError:
+                            self._logger.warning(
+                                "Failed to deserialize job %r", document["id"]
+                            )
+                            continue
 
-                    jobs.append(job)
+                        jobs.append(job)
 
         return jobs
 
@@ -453,25 +514,32 @@ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[J
         async for attempt in self._retry():
             with attempt, self.client.start_session() as session:
                 now = datetime.now(timezone.utc)
-                cursor = self._jobs.find(
-                    {
-                        "$or": [
-                            {"acquired_until": {"$exists": False}},
-                            {"acquired_until": {"$lt": now.timestamp()}},
-                        ]
-                    },
-                    sort=[("created_at", ASCENDING)],
-                    limit=limit,
-                    session=session,
-                )
-                documents = list(cursor)
+                async with await AsyncCursor.create(
+                    lambda: self._jobs.find(
+                        {
+                            "$or": [
+                                {"acquired_until": {"$exists": False}},
+                                {"acquired_until": {"$lt": now.timestamp()}},
+                            ]
+                        },
+                        sort=[("created_at", ASCENDING)],
+                        limit=limit,
+                        session=session,
+                    )
+                ) as cursor:
+                    documents = [doc async for doc in cursor]
 
                 # Retrieve the limits
                 task_ids: set[str] = {document["task_id"] for document in documents}
-                task_limits = self._tasks.find(
-                    {"_id": {"$in": list(task_ids)}, "max_running_jobs": {"$ne": None}},
-                    projection=["max_running_jobs", "running_jobs"],
-                    session=session,
+                task_limits = await to_thread.run_sync(
+                    lambda: self._tasks.find(
+                        {
+                            "_id": {"$in": list(task_ids)},
+                            "max_running_jobs": {"$ne": None},
+                        },
+                        projection=["max_running_jobs", "running_jobs"],
+                        session=session,
+                    )
                 )
                 job_slots_left = {
                     doc["_id"]: doc["max_running_jobs"] - doc["running_jobs"]
@@ -508,14 +576,18 @@ async def acquire_jobs(self, worker_id: str, limit: int | None = None) -> list[J
                         **marshal_timestamp(acquired_until, "acquired_until"),
                     }
                 }
-                await to_thread.run_sync(
-                    lambda: self._jobs.update_many(filters, update, session=session)
-                )
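
Not part of the patch itself: below is a minimal, self-contained sketch of the offloading
pattern the patch applies throughout ``MongoDBDataStore``. Blocking PyMongo calls are pushed
to a worker thread with ``anyio.to_thread.run_sync`` so they never run in the event loop
thread; cursor iteration gets the same treatment via the ``AsyncCursor`` wrapper added
above, which runs ``next()`` and ``close()`` in a worker thread. The connection URI,
database, and collection names below are made up for illustration.

import anyio
from anyio import to_thread
from pymongo import MongoClient

# Hypothetical connection details, for illustration only
client = MongoClient("mongodb://localhost:27017")
events = client["example_db"]["events"]


async def main() -> None:
    # Calls that only need positional arguments can be handed to run_sync directly
    total = await to_thread.run_sync(events.count_documents, {})

    # Calls that need keyword arguments go through a lambda, exactly as the data
    # store methods in the patch do
    document = await to_thread.run_sync(
        lambda: events.find_one({}, projection=["_id"], sort=[("_id", 1)])
    )
    print(total, document)


anyio.run(main)

A single ``run_sync`` works for one-shot calls such as ``find_one`` or ``insert_one``;
iterating a cursor needs a blocking ``next()`` per batch of documents, which is why the
patch wraps cursors in ``AsyncCursor`` instead of offloading the whole query in one call.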