Skip to content

Commit

Permalink
Add hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
hypergonial committed Jan 1, 2024
1 parent 06aa77a commit ec95ab2
Show file tree
Hide file tree
Showing 18 changed files with 613 additions and 47 deletions.
11 changes: 10 additions & 1 deletion arc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from arc import abc, command

from .abc import Option
from .abc import HookResult, Option, with_hook, with_post_hook
from .client import Client, GatewayClient, GatewayContext, GatewayPlugin, RESTClient, RESTContext, RESTPlugin
from .command import (
AttachmentParams,
Expand Down Expand Up @@ -42,6 +42,7 @@
from .extension import loader, unloader
from .internal.about import __author__, __author_email__, __license__, __maintainer__, __url__, __version__
from .plugin import GatewayPluginBase, PluginBase, RESTPluginBase
from .utils import bot_has_permissions, dm_only, guild_only, has_permissions, owner_only

__all__ = (
"__version__",
Expand Down Expand Up @@ -94,8 +95,16 @@
"RESTContext",
"RESTPlugin",
"GatewayPlugin",
"HookResult",
"abc",
"command",
"with_hook",
"with_post_hook",
"bot_has_permissions",
"dm_only",
"guild_only",
"has_permissions",
"owner_only",
)

# MIT License
Expand Down
5 changes: 5 additions & 0 deletions arc/abc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .client import Client
from .command import CallableCommandBase, CallableCommandProto, CommandBase, CommandProto
from .error_handler import HasErrorHandler
from .hookable import Hookable, HookResult, with_hook, with_post_hook
from .option import CommandOptionBase, Option, OptionBase, OptionParams, OptionWithChoices, OptionWithChoicesParams
from .plugin import PluginBase

Expand All @@ -18,4 +19,8 @@
"OptionWithChoicesParams",
"Client",
"PluginBase",
"Hookable",
"HookResult",
"with_hook",
"with_post_hook",
)
99 changes: 93 additions & 6 deletions arc/abc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from arc.context import AutodeferMode, Context
from arc.errors import ExtensionLoadError, ExtensionUnloadError
from arc.internal.sync import _sync_commands
from arc.internal.types import AppT, BuilderT, ResponseBuilderT
from arc.internal.types import AppT, BuilderT, HookT, PostHookT, ResponseBuilderT

if t.TYPE_CHECKING:
import typing_extensions as te
Expand Down Expand Up @@ -64,6 +64,9 @@ class Client(t.Generic[AppT], abc.ABC):
"_autosync",
"_plugins",
"_loaded_extensions",
"_hooks",
"_post_hooks",
"_owner_ids",
)

def __init__(
Expand All @@ -79,6 +82,9 @@ def __init__(
self._plugins: dict[str, PluginBase[te.Self]] = {}
self._loaded_extensions: list[str] = []
self._autosync = autosync
self._hooks: list[HookT[te.Self]] = []
self._post_hooks: list[PostHookT[te.Self]] = []
self._owner_ids: list[hikari.Snowflake] = []

@property
@abc.abstractmethod
Expand Down Expand Up @@ -143,6 +149,21 @@ def plugins(self) -> t.Mapping[str, PluginBase[te.Self]]:
"""The plugins added to this client."""
return self._plugins

@property
def hooks(self) -> t.MutableSequence[HookT[te.Self]]:
"""The pre-execution hooks for this client."""
return self._hooks

@property
def post_hooks(self) -> t.MutableSequence[PostHookT[te.Self]]:
"""The post-execution hooks for this client."""
return self._post_hooks

@property
def owner_ids(self) -> t.Sequence[hikari.Snowflake]:
"""The IDs of the owners of this application."""
return self._owner_ids

def _add_command(self, command: CommandBase[te.Self, t.Any]) -> None:
"""Add a command to this client. Called by include hooks."""
if isinstance(command, (SlashCommand, SlashGroup)):
Expand Down Expand Up @@ -193,6 +214,14 @@ async def _on_startup(self) -> None:
Fetches application, syncs commands, calls user-defined startup.
"""
self._application = await self.app.rest.fetch_application()

owner_ids = [self._application.owner.id]

if self._application.team is not None:
owner_ids.extend(member for member in self._application.team.members)

self._owner_ids = owner_ids

logger.debug(f"Fetched application: '{self.application}'")
if self._autosync:
await _sync_commands(self)
Expand Down Expand Up @@ -230,7 +259,9 @@ async def on_error(self, context: Context[te.Self], exception: Exception) -> Non
print(f"Unhandled error in command '{context.command.name}' callback: {exception}", file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
with suppress(Exception):
await context.respond("❌ Something went wrong. Please contact the bot developer.")
# Try to respond to make autodefer less jarring when a command fails.
if not context._issued_response and context.is_valid:
await context.respond("❌ Something went wrong. Please contact the bot developer.")

async def on_command_interaction(self, interaction: hikari.CommandInteraction) -> ResponseBuilderT | None:
"""Should be called when a command interaction is sent by Discord.
Expand Down Expand Up @@ -381,17 +412,23 @@ async def cmd(ctx: arc.GatewayContext) -> None:
group._client_include_hook(self)
return group

def add_plugin(self, plugin: PluginBase[te.Self]) -> None:
def add_plugin(self, plugin: PluginBase[te.Self]) -> te.Self:
"""Add a plugin to this client.
Parameters
----------
plugin : Plugin[te.Self]
The plugin to add.
Returns
-------
te.Self
The client for chaining calls.
"""
plugin._client_include_hook(self)
return self

def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> None:
def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> te.Self:
"""Remove a plugin from this client.
Parameters
Expand All @@ -403,18 +440,62 @@ def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> None:
------
ValueError
If there is no plugin with the given name.
Returns
-------
te.Self
The client for chaining calls.
"""
if isinstance(plugin, PluginBase):
if plugin not in self.plugins.values():
raise ValueError(f"Plugin '{plugin.name}' is not registered with this client.")
return plugin._client_remove_hook()
plugin._client_remove_hook()
return self

pg = self.plugins.get(plugin)

if pg is None:
raise ValueError(f"Plugin '{plugin}' is not registered with this client.")

pg._client_remove_hook()
return self

def add_hook(self, hook: HookT[te.Self]) -> te.Self:
"""Add a pre-execution hook to this client.
This hook will be executed before every command callback added to this client.
Parameters
----------
hook : HookT[te.Self]
The hook to add.
Returns
-------
te.Self
The client for chaining calls.
"""
self._hooks.append(hook)
return self

def add_post_hook(self, hook: PostHookT[te.Self]) -> te.Self:
"""Add a post-execution hook to this client.
This hook will be executed after every command callback added to this client.
!!! warning
Post-execution hooks will be called even if the command callback raises an exception.
Parameters
----------
hook : PostHookT[te.Self]
The hook to add.
Returns
-------
te.Self
The client for chaining calls.
"""
self._post_hooks.append(hook)
return self

def load_extension(self, path: str) -> te.Self:
"""Load a python module with path `path` as an extension.
Expand Down Expand Up @@ -568,7 +649,7 @@ def unload_extension(self, path: str) -> te.Self:

return self

def set_type_dependency(self, type_: t.Type[T], instance: T) -> None:
def set_type_dependency(self, type_: t.Type[T], instance: T) -> te.Self:
"""Set a type dependency for this client. This can then be injected into all arc callbacks.
Parameters
Expand All @@ -578,6 +659,11 @@ def set_type_dependency(self, type_: t.Type[T], instance: T) -> None:
instance : T
The instance of the dependency.
Returns
-------
te.Self
The client for chaining calls.
Usage
-----
Expand All @@ -603,6 +689,7 @@ async def cmd(ctx: arc.GatewayContext, dep: MyDependency = arc.inject()) -> None
A decorator to inject dependencies into arbitrary functions.
"""
self._injector.set_type_dependency(type_, instance)
return self

def get_type_dependency(self, type_: t.Type[T]) -> hikari.UndefinedOr[T]:
"""Get a type dependency for this client.
Expand Down
Loading

0 comments on commit ec95ab2

Please sign in to comment.