Skip to content

Commit

Permalink
Change to @overload approach.
Browse files Browse the repository at this point in the history
This commit keeps the behavior of the subscribe methods as they are for the case where a subscription is added for an iterable of event types. However, for the case where a subscription is added for a single event type, it allows the callable to be typed to receive an instance of that event type, removing the need to type-narrow in that case.
  • Loading branch information
peterschutt committed Feb 29, 2024
1 parent 82efa5d commit f2b6f3a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
24 changes: 23 additions & 1 deletion src/apscheduler/_schedulers/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from inspect import isbuiltin, isclass, ismethod, ismodule
from logging import Logger, getLogger
from types import TracebackType
from typing import Any, Callable, Iterable, Mapping, cast
from typing import Any, Callable, Iterable, Mapping, cast, overload
from uuid import UUID, uuid4

import anyio
Expand Down Expand Up @@ -215,6 +215,28 @@ async def cleanup(self) -> None:
await self.data_store.cleanup()
self.logger.info("Cleaned up expired job results and finished schedules")

@overload
def subscribe(
self,
callback: Callable[[T_Event], Any],
event_types: type[T_Event],
*,
one_shot: bool = ...,
is_async: bool = ...,
) -> Subscription:
...

@overload
def subscribe(
self,
callback: Callable[[Event], Any],
event_types: Iterable[type[Event]] | None = None,
*,
one_shot: bool = False,
is_async: bool = True,
) -> Subscription:
...

def subscribe(
self,
callback: Callable[[T_Event], Any],
Expand Down
26 changes: 23 additions & 3 deletions src/apscheduler/_schedulers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from functools import partial
from logging import Logger
from types import TracebackType
from typing import Any, Callable, Iterable, Mapping
from typing import Any, Callable, Iterable, Mapping, overload
from uuid import UUID

from anyio.from_thread import BlockingPortal, start_blocking_portal

from .. import current_scheduler
from .._enums import CoalescePolicy, ConflictPolicy, RunState, SchedulerRole
from .._events import T_Event
from .._events import Event, T_Event
from .._structures import Job, JobResult, Schedule, Task
from .._utils import UnsetValue, unset
from ..abc import DataStore, EventBroker, JobExecutor, Subscription, Trigger
Expand Down Expand Up @@ -156,10 +156,30 @@ def cleanup(self) -> None:
self._ensure_services_ready()
return self._portal.call(self._async_scheduler.cleanup)

@overload
def subscribe(
self,
callback: Callable[[T_Event], Any],
event_types: Iterable[type[T_Event]] | None = None,
event_types: type[T_Event],
*,
one_shot: bool = ...,
) -> Subscription:
...

@overload
def subscribe(
self,
callback: Callable[[Event], Any],
event_types: Iterable[type[Event]] | None = None,
*,
one_shot: bool = False,
) -> Subscription:
...

def subscribe(
self,
callback: Callable[[T_Event], Any],
event_types: type[T_Event] | Iterable[type[T_Event]] | None = None,
*,
one_shot: bool = False,
) -> Subscription:
Expand Down
6 changes: 3 additions & 3 deletions src/apscheduler/eventbrokers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from anyio.abc import TaskGroup

from .. import _events
from .._events import Event, T_Event
from .._events import Event
from .._exceptions import DeserializationError
from .._retry import RetryMixin
from ..abc import EventBroker, Serializer, Subscription
Expand Down Expand Up @@ -47,8 +47,8 @@ async def start(self, exit_stack: AsyncExitStack, logger: Logger) -> None:

def subscribe(
self,
callback: Callable[[T_Event], Any],
event_types: Iterable[type[T_Event]] | None = None,
callback: Callable[[Event], Any],
event_types: Iterable[type[Event]] | None = None,
*,
is_async: bool = True,
one_shot: bool = False,
Expand Down

0 comments on commit f2b6f3a

Please sign in to comment.