From b4db236cceb826ed32d64ff396800c258b8e15ef Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Sun, 18 Feb 2024 10:14:19 +1000 Subject: [PATCH] Change to @overload approach. This commit keeps the behavior of the subscribe methods as they are for the case where a subscription is added for an iterable of event types. However, for the case where a subscription is added for a single event type, it allows the callable to be typed to receive an instance of that event type, removing the need to type-narrow in that case. --- src/apscheduler/_schedulers/async_.py | 24 +++++++++++++++++++++++- src/apscheduler/_schedulers/sync.py | 26 +++++++++++++++++++++++--- src/apscheduler/eventbrokers/base.py | 6 +++--- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/apscheduler/_schedulers/async_.py b/src/apscheduler/_schedulers/async_.py index 82ced166..8260aa1f 100644 --- a/src/apscheduler/_schedulers/async_.py +++ b/src/apscheduler/_schedulers/async_.py @@ -11,7 +11,7 @@ from inspect import isbuiltin, isclass, ismethod, ismodule from logging import Logger, getLogger from types import TracebackType -from typing import Any, Callable, Iterable, Mapping, cast +from typing import Any, Callable, Iterable, Mapping, cast, overload from uuid import UUID, uuid4 import anyio @@ -215,6 +215,28 @@ async def cleanup(self) -> None: await self.data_store.cleanup() self.logger.info("Cleaned up expired job results and finished schedules") + @overload + def subscribe( + self, + callback: Callable[[T_Event], Any], + event_types: type[T_Event], + *, + one_shot: bool = ..., + is_async: bool = ..., + ) -> Subscription: + ... + + @overload + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + is_async: bool = True, + ) -> Subscription: + ... + def subscribe( self, callback: Callable[[T_Event], Any], diff --git a/src/apscheduler/_schedulers/sync.py b/src/apscheduler/_schedulers/sync.py index fe3fdafa..f605a3ac 100644 --- a/src/apscheduler/_schedulers/sync.py +++ b/src/apscheduler/_schedulers/sync.py @@ -10,14 +10,14 @@ from functools import partial from logging import Logger from types import TracebackType -from typing import Any, Callable, Iterable, Mapping +from typing import Any, Callable, Iterable, Mapping, overload from uuid import UUID from anyio.from_thread import BlockingPortal, start_blocking_portal from .. import current_scheduler from .._enums import CoalescePolicy, ConflictPolicy, RunState, SchedulerRole -from .._events import T_Event +from .._events import Event, T_Event from .._structures import Job, JobResult, Schedule, Task from .._utils import UnsetValue, unset from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger @@ -156,10 +156,30 @@ def cleanup(self) -> None: self._ensure_services_ready() return self._portal.call(self._async_scheduler.cleanup) + @overload def subscribe( self, callback: Callable[[T_Event], Any], - event_types: Iterable[type[T_Event]] | None = None, + event_types: type[T_Event], + *, + one_shot: bool = ..., + ) -> Subscription: + ... + + @overload + def subscribe( + self, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, + *, + one_shot: bool = False, + ) -> Subscription: + ... + + def subscribe( + self, + callback: Callable[[T_Event], Any], + event_types: type[T_Event] | Iterable[type[T_Event]] | None = None, *, one_shot: bool = False, ) -> Subscription: diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py index a9648e54..3c106ba9 100644 --- a/src/apscheduler/eventbrokers/base.py +++ b/src/apscheduler/eventbrokers/base.py @@ -11,7 +11,7 @@ from anyio.abc import TaskGroup from .. import _events -from .._events import Event, T_Event +from .._events import Event from .._exceptions import DeserializationError from .._retry import RetryMixin from ..abc import EventBroker, Serializer, Subscription @@ -47,8 +47,8 @@ async def start(self, exit_stack: AsyncExitStack, logger: Logger) -> None: def subscribe( self, - callback: Callable[[T_Event], Any], - event_types: Iterable[type[T_Event]] | None = None, + callback: Callable[[Event], Any], + event_types: Iterable[type[Event]] | None = None, *, is_async: bool = True, one_shot: bool = False,