Skip to content

Commit

Permalink
Add event listeners
Browse files Browse the repository at this point in the history
  • Loading branch information
hypergonial committed Jan 3, 2024
1 parent a20c6d3 commit 9b0707e
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 24 deletions.
45 changes: 40 additions & 5 deletions arc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from arc.context import Context
from arc.errors import NoResponseIssuedError
from arc.events import CommandErrorEvent
from arc.internal.sigparse import parse_event_signature
from arc.plugin import GatewayPluginBase, RESTPluginBase

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -98,28 +99,62 @@ async def _on_error(self, ctx: Context[te.Self], exception: Exception) -> None:

self.app.event_manager.dispatch(CommandErrorEvent(self, ctx, exception))

def subscribe(self, event_type: type[EventT], callback: EventCallbackT[EventT]) -> None:
"""Subscribe to an event.
Parameters
----------
event_type : type[EventT]
The event type to subscribe to.
`EventT` must be a subclass of `hikari.events.base_events.Event`.
callback : t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]]
The callback to call when the event is dispatched.
"""
self.app.event_manager.subscribe(event_type, callback) # pyright: ignore reportGeneralTypeIssues

def unsubscribe(self, event_type: type[EventT], callback: EventCallbackT[EventT]) -> None:
"""Unsubscribe from an event.
Parameters
----------
event_type : type[EventT]
The event type to unsubscribe from.
callback : t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]]
The callback to unsubscribe.
"""
self.app.event_manager.unsubscribe(event_type, callback) # pyright: ignore reportGeneralTypeIssues

def listen(self, *event_types: t.Type[EventT]) -> t.Callable[[EventCallbackT[EventT]], EventCallbackT[EventT]]:
"""Generate a decorator to subscribe a callback to an event type.
This is a second-order decorator.
Parameters
----------
*event_types : t.Type[EventT] | None
The event types to subscribe to. The implementation may allow this
to be undefined. If this is the case, the event type will be inferred
*event_types : type[EventT]
The event types to subscribe to. If not provided, the event type will be inferred
instead from the type hints on the function signature.
`EventT` must be a subclass of `hikari.events.base_events.Event`.
Returns
-------
t.Callable[[EventT], EventT]
t.Callable[t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]]], t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]]]
A decorator for a coroutine function that passes it to
`EventManager.subscribe` before returning the function
reference.
"""
return self.app.event_manager.listen(*event_types)

def decorator(func: EventCallbackT[EventT]) -> EventCallbackT[EventT]:
types = event_types or parse_event_signature(func)

for event_type in types:
self.subscribe(event_type, func)

return func

return decorator


class RESTClient(Client[hikari.RESTBotAware]):
Expand Down
6 changes: 3 additions & 3 deletions arc/command/slash.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from arc.abc.option import OptionWithChoices
from arc.context import AutocompleteData, AutodeferMode, Context
from arc.errors import AutocompleteError, CommandInvokeError
from arc.internal.sigparse import parse_function_signature
from arc.internal.sigparse import parse_command_signature
from arc.internal.types import ClientT, CommandCallbackT, HookT, PostHookT, ResponseBuilderT, SlashCommandLike
from arc.locale import CommandLocaleRequest, LocaleResponse

Expand Down Expand Up @@ -753,7 +753,7 @@ async def hi_slash(

def decorator(func: t.Callable[t.Concatenate[Context[ClientT], ...], t.Awaitable[None]]) -> SlashCommand[ClientT]:
guild_ids = [hikari.Snowflake(guild) for guild in guilds] if guilds else []
options = parse_function_signature(func)
options = parse_command_signature(func)

return SlashCommand(
callback=func,
Expand Down Expand Up @@ -819,7 +819,7 @@ async def hi_slashsub(
def decorator(
func: t.Callable[t.Concatenate[Context[ClientT], ...], t.Awaitable[None]],
) -> SlashSubCommand[ClientT]:
options = parse_function_signature(func)
options = parse_command_signature(func)

return SlashSubCommand(
callback=func,
Expand Down
30 changes: 26 additions & 4 deletions arc/internal/sigparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
if t.TYPE_CHECKING:
from arc.abc.option import CommandOptionBase
from arc.context import Context
from arc.internal.types import ClientT
from arc.internal.types import ClientT, EventT


__all__ = ("parse_function_signature",)
__all__ = ("parse_command_signature",)

TYPE_TO_OPTION_MAPPING: dict[t.Type[t.Any], t.Type[CommandOptionBase[t.Any, t.Any, t.Any]]] = {
bool: BoolOption,
Expand Down Expand Up @@ -249,8 +249,7 @@ def _parse_channel_union_type_hint(hint: t.Any) -> list[hikari.ChannelType]:
return _channels_to_channel_types(arg for arg in args if arg is not type(None))


# TODO Detect if param has a default value and also make it optional
def parse_function_signature( # noqa: C901
def parse_command_signature( # noqa: C901
func: t.Callable[t.Concatenate[Context[ClientT], ...], t.Awaitable[None]],
) -> dict[str, CommandOptionBase[t.Any, t.Any, t.Any]]:
"""Parse a command callback function's signature and return a list of options.
Expand Down Expand Up @@ -367,6 +366,29 @@ def parse_function_signature( # noqa: C901
return options


def parse_event_signature(func: t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]]) -> list[type[EventT]]:
"""Parse an event callback function's signature and return the event type, ignore other type hints."""
hints = t.get_type_hints(func)

# Remove the return type
hints.pop("return", None)

first = next(iter(hints.values()))

if _is_union(first):
events = [arg for arg in t.get_args(first) if issubclass(arg, hikari.Event)]
if not events:
raise TypeError("Expected event callback to have first argument that inherits from 'hikari.Event'")
return events # pyright: ignore reportGeneralTypeIssues

elif issubclass(first, hikari.Event):
return [first] # pyright: ignore reportGeneralTypeIssues

raise TypeError(
f"Expected event callback to have first argument that inherits from 'hikari.Event', got '{first!r}'"
)


# MIT License
#
# Copyright (c) 2023-present hypergonial
Expand Down
6 changes: 4 additions & 2 deletions arc/internal/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
P = t.ParamSpec("P")

# Type aliases
EventCallbackT: t.TypeAlias = "t.Callable[[EventT], t.Coroutine[t.Any, t.Any, None]]"
ErrorHandlerCallbackT: t.TypeAlias = "t.Callable[[Context[ClientT], Exception], t.Coroutine[t.Any, t.Any, None]]"
EventCallbackT: t.TypeAlias = "t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]]"
ErrorHandlerCallbackT: t.TypeAlias = (
"t.Callable[t.Concatenate[Context[ClientT], Exception, ...], t.Coroutine[t.Any, t.Any, None]]"
)
SlashCommandLike: t.TypeAlias = "SlashCommand[ClientT] | SlashGroup[ClientT]"
CommandCallbackT: t.TypeAlias = "t.Callable[t.Concatenate[Context[ClientT], ...], t.Awaitable[None]]"
MessageContextCallbackT: t.TypeAlias = (
Expand Down
78 changes: 73 additions & 5 deletions arc/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import typing as t

from arc.abc.plugin import PluginBase
from arc.internal.types import EventCallbackT, GatewayClientT, RESTClientT
from arc.internal.sigparse import parse_event_signature
from arc.internal.types import EventCallbackT, EventT, GatewayClientT, RESTClientT

if t.TYPE_CHECKING:
import hikari
Expand Down Expand Up @@ -71,30 +72,97 @@ async def ping(ctx: arc.GatewayContext) -> None:

def __init__(self, name: str) -> None:
super().__init__(name)
self._listeners: dict[t.Type[hikari.Event], set[EventCallbackT[t.Any]]] = {}
self._listeners: dict[type[hikari.Event], set[EventCallbackT[t.Any]]] = {}

@property
def is_rest(self) -> bool:
return False

@property
def listeners(self) -> t.Mapping[t.Type[hikari.Event], t.Collection[EventCallbackT[t.Any]]]:
def listeners(self) -> t.Mapping[type[hikari.Event], t.Collection[EventCallbackT[t.Any]]]:
return self._listeners

def subscribe(self, event: type[hikari.Event], callback: EventCallbackT[t.Any]) -> None:
"""Subscribe to an event.
Parameters
----------
event : type[hikari.Event]
The event to subscribe to.
callback : Callable[[EventT], Awaitable[None]]
The callback to call when the event is dispatched.
"""
if event not in self.listeners:
self._listeners[event] = set()

self._listeners[event].add(callback)

if self._client is not None:
self._client.subscribe(event, callback)

def unsubscribe(self, event: type[hikari.Event], callback: EventCallbackT[t.Any]) -> None:
"""Unsubscribe from an event.
Parameters
----------
event : type[hikari.Event]
The event to unsubscribe from.
callback : Callable[[EventT], Awaitable[None]]
The callback to unsubscribe.
"""
if event not in self.listeners:
return

self._listeners[event].remove(callback)

if self._client is not None:
self._client.unsubscribe(event, callback)

def listen(self, *event_types: type[EventT]) -> t.Callable[[EventCallbackT[EventT]], EventCallbackT[EventT]]:
"""Generate a decorator to subscribe a callback to an event type.
This is a second-order decorator.
Parameters
----------
*event_types : type[EventT]
The event types to subscribe to. If not provided, the event type will be inferred
instead from the type hints on the function signature.
`EventT` must be a subclass of `hikari.events.base_events.Event`.
Returns
-------
t.Callable[[EventT], EventT]
A decorator for a coroutine function that passes it to
`EventManager.subscribe` before returning the function
reference.
"""

def decorator(func: EventCallbackT[EventT]) -> EventCallbackT[EventT]:
types = event_types or parse_event_signature(func)

for event_type in types:
self.subscribe(event_type, func)

return func

return decorator

def _client_include_hook(self, client: GatewayClientT) -> None:
super()._client_include_hook(client)

for event, callbacks in self.listeners.items():
for callback in callbacks:
client.app.event_manager.subscribe(event, callback)
client.subscribe(event, callback)

def _client_remove_hook(self) -> None:
if self._client is None:
raise RuntimeError(f"Plugin '{self.name}' is not included in a client.")

for event, callbacks in self.listeners.items():
for callback in callbacks:
self.client.app.event_manager.unsubscribe(event, callback)
self.client.unsubscribe(event, callback)

super()._client_remove_hook()

Expand Down
10 changes: 5 additions & 5 deletions tests/test_sigparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

import arc
from arc.internal.sigparse import parse_function_signature
from arc.internal.sigparse import parse_command_signature


async def correct_command(
Expand All @@ -21,7 +21,7 @@ async def correct_command(


def test_correct_command() -> None:
options = parse_function_signature(correct_command)
options = parse_command_signature(correct_command)
assert len(options) == 9

assert isinstance(options["a"], arc.command.IntOption)
Expand Down Expand Up @@ -100,7 +100,7 @@ async def wrong_params_type(

def test_wrong_params_type() -> None:
with pytest.raises(TypeError):
parse_function_signature(wrong_params_type)
parse_command_signature(wrong_params_type)


class WrongType:
Expand All @@ -113,7 +113,7 @@ async def wrong_opt_type(ctx: arc.GatewayContext, a: arc.Option[WrongType, arc.I

def test_wrong_opt_type() -> None:
with pytest.raises(TypeError):
parse_function_signature(wrong_opt_type)
parse_command_signature(wrong_opt_type)


class MyType:
Expand All @@ -131,7 +131,7 @@ async def di_annotation(


def test_di_annotation() -> None:
options = parse_function_signature(di_annotation)
options = parse_command_signature(di_annotation)
assert len(options) == 2

assert isinstance(options["a"], arc.command.IntOption)
Expand Down

0 comments on commit 9b0707e

Please sign in to comment.