diff --git a/arc/__init__.py b/arc/__init__.py index d975dd6..adf24c3 100644 --- a/arc/__init__.py +++ b/arc/__init__.py @@ -13,7 +13,7 @@ from arc import abc, command -from .abc import Option +from .abc import HookResult, Option, with_hook, with_post_hook from .client import Client, GatewayClient, GatewayContext, GatewayPlugin, RESTClient, RESTContext, RESTPlugin from .command import ( AttachmentParams, @@ -42,6 +42,7 @@ from .extension import loader, unloader from .internal.about import __author__, __author_email__, __license__, __maintainer__, __url__, __version__ from .plugin import GatewayPluginBase, PluginBase, RESTPluginBase +from .utils import bot_has_permissions, dm_only, guild_only, has_permissions, owner_only __all__ = ( "__version__", @@ -94,8 +95,16 @@ "RESTContext", "RESTPlugin", "GatewayPlugin", + "HookResult", "abc", "command", + "with_hook", + "with_post_hook", + "bot_has_permissions", + "dm_only", + "guild_only", + "has_permissions", + "owner_only", ) # MIT License diff --git a/arc/abc/__init__.py b/arc/abc/__init__.py index fb6d67a..eace92c 100644 --- a/arc/abc/__init__.py +++ b/arc/abc/__init__.py @@ -1,6 +1,7 @@ from .client import Client from .command import CallableCommandBase, CallableCommandProto, CommandBase, CommandProto from .error_handler import HasErrorHandler +from .hookable import Hookable, HookResult, with_hook, with_post_hook from .option import CommandOptionBase, Option, OptionBase, OptionParams, OptionWithChoices, OptionWithChoicesParams from .plugin import PluginBase @@ -18,4 +19,8 @@ "OptionWithChoicesParams", "Client", "PluginBase", + "Hookable", + "HookResult", + "with_hook", + "with_post_hook", ) diff --git a/arc/abc/client.py b/arc/abc/client.py index 7dfcbbe..4384257 100644 --- a/arc/abc/client.py +++ b/arc/abc/client.py @@ -22,7 +22,7 @@ from arc.context import AutodeferMode, Context from arc.errors import ExtensionLoadError, ExtensionUnloadError from arc.internal.sync import _sync_commands -from arc.internal.types import AppT, BuilderT, ResponseBuilderT +from arc.internal.types import AppT, BuilderT, HookT, PostHookT, ResponseBuilderT if t.TYPE_CHECKING: import typing_extensions as te @@ -64,6 +64,9 @@ class Client(t.Generic[AppT], abc.ABC): "_autosync", "_plugins", "_loaded_extensions", + "_hooks", + "_post_hooks", + "_owner_ids", ) def __init__( @@ -79,6 +82,9 @@ def __init__( self._plugins: dict[str, PluginBase[te.Self]] = {} self._loaded_extensions: list[str] = [] self._autosync = autosync + self._hooks: list[HookT[te.Self]] = [] + self._post_hooks: list[PostHookT[te.Self]] = [] + self._owner_ids: list[hikari.Snowflake] = [] @property @abc.abstractmethod @@ -143,6 +149,21 @@ def plugins(self) -> t.Mapping[str, PluginBase[te.Self]]: """The plugins added to this client.""" return self._plugins + @property + def hooks(self) -> t.MutableSequence[HookT[te.Self]]: + """The pre-execution hooks for this client.""" + return self._hooks + + @property + def post_hooks(self) -> t.MutableSequence[PostHookT[te.Self]]: + """The post-execution hooks for this client.""" + return self._post_hooks + + @property + def owner_ids(self) -> t.Sequence[hikari.Snowflake]: + """The IDs of the owners of this application.""" + return self._owner_ids + def _add_command(self, command: CommandBase[te.Self, t.Any]) -> None: """Add a command to this client. Called by include hooks.""" if isinstance(command, (SlashCommand, SlashGroup)): @@ -193,6 +214,14 @@ async def _on_startup(self) -> None: Fetches application, syncs commands, calls user-defined startup. """ self._application = await self.app.rest.fetch_application() + + owner_ids = [self._application.owner.id] + + if self._application.team is not None: + owner_ids.extend(member for member in self._application.team.members) + + self._owner_ids = owner_ids + logger.debug(f"Fetched application: '{self.application}'") if self._autosync: await _sync_commands(self) @@ -230,7 +259,9 @@ async def on_error(self, context: Context[te.Self], exception: Exception) -> Non print(f"Unhandled error in command '{context.command.name}' callback: {exception}", file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) with suppress(Exception): - await context.respond("❌ Something went wrong. Please contact the bot developer.") + # Try to respond to make autodefer less jarring when a command fails. + if not context._issued_response and context.is_valid: + await context.respond("❌ Something went wrong. Please contact the bot developer.") async def on_command_interaction(self, interaction: hikari.CommandInteraction) -> ResponseBuilderT | None: """Should be called when a command interaction is sent by Discord. @@ -381,17 +412,23 @@ async def cmd(ctx: arc.GatewayContext) -> None: group._client_include_hook(self) return group - def add_plugin(self, plugin: PluginBase[te.Self]) -> None: + def add_plugin(self, plugin: PluginBase[te.Self]) -> te.Self: """Add a plugin to this client. Parameters ---------- plugin : Plugin[te.Self] The plugin to add. + + Returns + ------- + te.Self + The client for chaining calls. """ plugin._client_include_hook(self) + return self - def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> None: + def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> te.Self: """Remove a plugin from this client. Parameters @@ -403,11 +440,17 @@ def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> None: ------ ValueError If there is no plugin with the given name. + + Returns + ------- + te.Self + The client for chaining calls. """ if isinstance(plugin, PluginBase): if plugin not in self.plugins.values(): raise ValueError(f"Plugin '{plugin.name}' is not registered with this client.") - return plugin._client_remove_hook() + plugin._client_remove_hook() + return self pg = self.plugins.get(plugin) @@ -415,6 +458,44 @@ def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> None: raise ValueError(f"Plugin '{plugin}' is not registered with this client.") pg._client_remove_hook() + return self + + def add_hook(self, hook: HookT[te.Self]) -> te.Self: + """Add a pre-execution hook to this client. + This hook will be executed before every command callback added to this client. + + Parameters + ---------- + hook : HookT[te.Self] + The hook to add. + + Returns + ------- + te.Self + The client for chaining calls. + """ + self._hooks.append(hook) + return self + + def add_post_hook(self, hook: PostHookT[te.Self]) -> te.Self: + """Add a post-execution hook to this client. + This hook will be executed after every command callback added to this client. + + !!! warning + Post-execution hooks will be called even if the command callback raises an exception. + + Parameters + ---------- + hook : PostHookT[te.Self] + The hook to add. + + Returns + ------- + te.Self + The client for chaining calls. + """ + self._post_hooks.append(hook) + return self def load_extension(self, path: str) -> te.Self: """Load a python module with path `path` as an extension. @@ -568,7 +649,7 @@ def unload_extension(self, path: str) -> te.Self: return self - def set_type_dependency(self, type_: t.Type[T], instance: T) -> None: + def set_type_dependency(self, type_: t.Type[T], instance: T) -> te.Self: """Set a type dependency for this client. This can then be injected into all arc callbacks. Parameters @@ -578,6 +659,11 @@ def set_type_dependency(self, type_: t.Type[T], instance: T) -> None: instance : T The instance of the dependency. + Returns + ------- + te.Self + The client for chaining calls. + Usage ----- @@ -603,6 +689,7 @@ async def cmd(ctx: arc.GatewayContext, dep: MyDependency = arc.inject()) -> None A decorator to inject dependencies into arbitrary functions. """ self._injector.set_type_dependency(type_, instance) + return self def get_type_dependency(self, type_: t.Type[T]) -> hikari.UndefinedOr[T]: """Get a type dependency for this client. diff --git a/arc/abc/command.py b/arc/abc/command.py index aec2929..d33e605 100644 --- a/arc/abc/command.py +++ b/arc/abc/command.py @@ -2,14 +2,25 @@ import abc import asyncio +import inspect import typing as t import attr import hikari from arc.abc.error_handler import HasErrorHandler +from arc.abc.hookable import Hookable, HookResult +from arc.abc.option import OptionBase from arc.context import AutodeferMode -from arc.internal.types import BuilderT, ClientT, CommandCallbackT, ResponseBuilderT +from arc.internal.types import ( + BuilderT, + ClientT, + CommandCallbackT, + ErrorHandlerCallbackT, + HookT, + PostHookT, + ResponseBuilderT, +) if t.TYPE_CHECKING: from arc.abc.plugin import PluginBase @@ -35,7 +46,7 @@ def qualified_name(self) -> t.Sequence[str]: """The fully qualified name of this command.""" -class CallableCommandProto(t.Protocol, t.Generic[ClientT]): +class CallableCommandProto(t.Protocol[ClientT]): """A protocol for any command-like object that can be called directly. This includes commands and subcommands.""" name: str @@ -56,7 +67,7 @@ def qualified_name(self) -> t.Sequence[str]: @abc.abstractmethod async def __call__(self, ctx: Context[ClientT], *args: t.Any, **kwargs: t.Any) -> None: - """Invoke this command with the given context. + """Call the callback of the command with the given context and arguments. Parameters ---------- @@ -98,9 +109,17 @@ async def invoke( async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None: ... + def _resolve_hooks(self) -> t.Sequence[HookT[ClientT]]: + """Resolve all pre-execution hooks that apply to this object.""" + ... + + def _resolve_post_hooks(self) -> t.Sequence[PostHookT[ClientT]]: + """Resolve all post-execution hooks that apply to this object.""" + ... + @attr.define(slots=True, kw_only=True) -class CommandBase(HasErrorHandler[ClientT], t.Generic[ClientT, BuilderT]): +class CommandBase(HasErrorHandler[ClientT], Hookable[ClientT], t.Generic[ClientT, BuilderT]): """An abstract base class for all application commands.""" name: str @@ -134,6 +153,27 @@ class CommandBase(HasErrorHandler[ClientT], t.Generic[ClientT, BuilderT]): _instances: dict[hikari.Snowflake | None, hikari.PartialCommand] = attr.field(factory=dict) """A mapping of guild IDs to command instances. None corresponds to the global instance, if any.""" + _error_handler: ErrorHandlerCallbackT[ClientT] | None = attr.field(init=False, default=None) + + _hooks: list[HookT[ClientT]] = attr.field(init=False, factory=list) + + _post_hooks: list[PostHookT[ClientT]] = attr.field(init=False, factory=list) + + @property + def error_handler(self) -> ErrorHandlerCallbackT[ClientT] | None: + """The error handler for this command.""" + return self._error_handler + + @property + def hooks(self) -> t.MutableSequence[HookT[ClientT]]: + """The pre-execution hooks for this command.""" + return self._hooks + + @property + def post_hooks(self) -> t.MutableSequence[PostHookT[ClientT]]: + """The post-execution hooks for this command.""" + return self._post_hooks + @property @abc.abstractmethod def command_type(self) -> hikari.CommandType: @@ -175,6 +215,14 @@ async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None else: await self.client._on_error(ctx, exc) + def _resolve_hooks(self) -> list[HookT[ClientT]]: + plugin_hooks = self.plugin._resolve_hooks() if self.plugin else [] + return self.client._hooks + plugin_hooks + self._hooks + + def _resolve_post_hooks(self) -> list[PostHookT[ClientT]]: + plugin_hooks = self.plugin._resolve_post_hooks() if self.plugin else [] + return self.client._post_hooks + plugin_hooks + self._post_hooks + async def publish(self, guild: hikari.SnowflakeishOr[hikari.PartialGuild] | None = None) -> hikari.PartialCommand: """Publish this command to the given guild, or globally if no guild is provided. @@ -270,15 +318,60 @@ def _plugin_include_hook(self, plugin: PluginBase[ClientT]) -> None: self._plugin = plugin self._plugin._add_command(self) + async def _handle_pre_hooks(self, command: CallableCommandProto[ClientT], ctx: Context[ClientT]) -> bool: + """Handle all pre-execution hooks for a command. + + Returns + ------- + bool + Whether the command should be aborted. + """ + aborted = False + try: + hooks = command._resolve_hooks() + for hook in hooks: + if inspect.iscoroutinefunction(hook): + res = await hook(ctx) + else: + res = hook(ctx) + + res = t.cast(HookResult | None, res) + + if res and res._abort: + aborted = True + except Exception as e: + aborted = True + await command._handle_exception(ctx, e) + + return aborted + + async def _handle_post_hooks(self, command: CallableCommandProto[ClientT], ctx: Context[ClientT]) -> None: + """Handle all post-execution hooks for a command.""" + try: + post_hooks = command._resolve_post_hooks() + for hook in post_hooks: + if inspect.iscoroutinefunction(hook): + await hook(ctx) + else: + hook(ctx) + except Exception as e: + await command._handle_exception(ctx, e) + async def _handle_callback( self, command: CallableCommandProto[ClientT], ctx: Context[ClientT], *args: t.Any, **kwargs: t.Any ) -> None: + """Handle the callback of a command. Invoke all hooks and the callback, and handle any exceptions.""" + # If hook aborted, stop invocation + if await self._handle_pre_hooks(command, ctx): + return + try: await self.client.injector.call_with_async_di(command.callback, ctx, *args, **kwargs) except Exception as e: + ctx._has_command_failed = True await command._handle_exception(ctx, e) - - # TODO - hooks, max_concurrency, cooldowns + finally: + await self._handle_post_hooks(command, ctx) @attr.define(slots=True, kw_only=True) @@ -302,3 +395,34 @@ async def invoke( self._invoke_task = asyncio.create_task(self._handle_callback(self, ctx, *args, **kwargs)) if self.client.is_rest: return ctx._resp_builder + + +ParentT = t.TypeVar("ParentT") + + +class SubCommandBase(OptionBase[ClientT], HasErrorHandler[ClientT], Hookable[ClientT], t.Generic[ClientT, ParentT]): + """An abstract base class for all slash subcommands and subgroups.""" + + _error_handler: ErrorHandlerCallbackT[ClientT] | None = attr.field(default=None, init=False) + + _hooks: list[HookT[ClientT]] = attr.field(factory=list, init=False) + + _post_hooks: list[PostHookT[ClientT]] = attr.field(factory=list, init=False) + + parent: ParentT | None = attr.field(default=None, init=False) + """The parent of this subcommand or subgroup.""" + + @property + def error_handler(self) -> ErrorHandlerCallbackT[ClientT] | None: + """The error handler for this object.""" + return self._error_handler + + @property + def hooks(self) -> t.MutableSequence[HookT[ClientT]]: + """The pre-execution hooks for this object.""" + return self._hooks + + @property + def post_hooks(self) -> t.MutableSequence[PostHookT[ClientT]]: + """The post-execution hooks for this object.""" + return self._post_hooks diff --git a/arc/abc/error_handler.py b/arc/abc/error_handler.py index 3bb76d9..3894cdd 100644 --- a/arc/abc/error_handler.py +++ b/arc/abc/error_handler.py @@ -3,21 +3,19 @@ import abc import typing as t -import attr - from arc.internal.types import ClientT, ErrorHandlerCallbackT if t.TYPE_CHECKING: from ..context import Context -@attr.define(slots=False) class HasErrorHandler(abc.ABC, t.Generic[ClientT]): - _error_handler: ErrorHandlerCallbackT[ClientT] | None = attr.field(default=None, init=False) + """An interface for objects that can have an error handler set on them.""" @property - def error_handler(self) -> t.Optional[ErrorHandlerCallbackT[ClientT]]: - return self._error_handler + @abc.abstractmethod + def error_handler(self) -> ErrorHandlerCallbackT[ClientT] | None: + """The error handler for this object.""" def set_error_handler(self, callback: ErrorHandlerCallbackT[ClientT]) -> ErrorHandlerCallbackT[ClientT]: """Decorator to set an error handler for this object. This can be added to commands, groups, or plugins. @@ -42,4 +40,4 @@ async def foo_error_handler(ctx: arc.GatewayContext, exc: Exception) -> None: @abc.abstractmethod async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None: - ... + """Handle an exception or propagate it to the next error handler if it cannot be handled.""" diff --git a/arc/abc/hookable.py b/arc/abc/hookable.py new file mode 100644 index 0000000..0346b9a --- /dev/null +++ b/arc/abc/hookable.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import abc +import typing as t + +from arc.internal.types import ClientT, HookableT, HookT, PostHookT + +if t.TYPE_CHECKING: + import typing_extensions as te + + +class HookResult: + """The result of a hook. + + Parameters + ---------- + abort : bool + Whether to abort the execution of the command. + If True, the command execution will be silently aborted. + If this is undesired, you should raise an exception instead. + """ + + def __init__(self, abort: bool = False) -> None: + self._abort = abort + + +class Hookable(abc.ABC, t.Generic[ClientT]): + """An interface for objects that can have hooks set on them.""" + + @property + @abc.abstractmethod + def hooks(self) -> t.MutableSequence[HookT[ClientT]]: + """The pre-execution hooks for this object.""" + + @property + @abc.abstractmethod + def post_hooks(self) -> t.MutableSequence[PostHookT[ClientT]]: + """The post-execution hooks for this object.""" + + def _resolve_hooks(self) -> list[HookT[ClientT]]: + """Resolve all pre-execution hooks that apply to this object.""" + ... + + def _resolve_post_hooks(self) -> list[PostHookT[ClientT]]: + """Resolve all post-execution hooks that apply to this object.""" + ... + + def add_hook(self, hook: HookT[ClientT]) -> te.Self: + """Add a new pre-execution hook to this object. + + Parameters + ---------- + hook : HookT[ClientT] + The hook to add. + + Returns + ------- + te.Self + This object for chaining. + """ + self.hooks.append(hook) + return self + + def add_post_hook(self, hook: PostHookT[ClientT]) -> te.Self: + """Add a new post-execution hook to this object. + + Parameters + ---------- + hook : PostHookT[ClientT] + The post-execution hook to add. + + Returns + ------- + te.Self + This object for chaining. + """ + self.post_hooks.append(hook) + return self + + +def with_hook(hook: HookT[ClientT]) -> t.Callable[[HookableT], HookableT]: + """Add a new pre-execution hook to a hookable object. It will run before the command callback. + + Any function that takes a [Context][`arc.context.base.Context`] as its sole parameter + and returns either a [`HookResult`][arc.abc.hookable.HookResult] or + `None` can be used as a hook. + + Usage + ----- + ```py + @client.include + @arc.with_hook(arc.guild_only) # Add a pre-execution hook to a command + @arc.slash_command("foo", "Foo command description") + async def foo(ctx: arc.GatewayContext) -> None: + ... + ``` + """ + + def decorator(hookable: HookableT) -> HookableT: + hookable.hooks.append(hook) + return hookable + + return decorator + + +def with_post_hook(hook: PostHookT[ClientT]) -> t.Callable[[HookableT], HookableT]: + """Add a new post-execution hook to a hookable object. It will run after the command callback. + + Any function that takes a [Context][`arc.context.base.Context`] as its sole parameter + and returns either a [`HookResult`][arc.abc.hookable.HookResult] or + `None` can be used as a hook. + + Post-execution hooks are not executed if a pre-execution hook aborts the execution of the command. + + !!! warning + Post-execution hooks **are** called even if the command callback raises an exception. + You can see if the command callback failed by checking [`Context.has_command_failed`][arc.context.base.Context.has_command_failed]. + + Usage + ----- + ```py + @client.include + @arc.with_post_hook(arc.guild_only) # Add a post-execution hook to a command + @arc.slash_command("foo", "Foo command description") + async def foo(ctx: arc.GatewayContext) -> None: + ... + ``` + """ + + def decorator(hookable: HookableT) -> HookableT: + hookable.post_hooks.append(hook) + return hookable + + return decorator diff --git a/arc/abc/plugin.py b/arc/abc/plugin.py index 1480231..1ea691e 100644 --- a/arc/abc/plugin.py +++ b/arc/abc/plugin.py @@ -9,9 +9,10 @@ import hikari from arc.abc.error_handler import HasErrorHandler +from arc.abc.hookable import Hookable from arc.command import MessageCommand, SlashCommand, SlashGroup, UserCommand from arc.context import AutodeferMode, Context -from arc.internal.types import BuilderT, ClientT, SlashCommandLike +from arc.internal.types import BuilderT, ClientT, ErrorHandlerCallbackT, HookT, PostHookT, SlashCommandLike if t.TYPE_CHECKING: from arc.abc.command import CommandBase @@ -23,7 +24,7 @@ T = t.TypeVar("T") -class PluginBase(HasErrorHandler[ClientT], t.Generic[ClientT]): +class PluginBase(HasErrorHandler[ClientT], Hookable[ClientT]): """An abstract base class for plugins. Parameters @@ -35,13 +36,30 @@ class PluginBase(HasErrorHandler[ClientT], t.Generic[ClientT]): def __init__( self, name: str, *, default_enabled_guilds: hikari.UndefinedOr[t.Sequence[hikari.Snowflake]] = hikari.UNDEFINED ) -> None: - super().__init__() self._client: ClientT | None = None self._name = name self._slash_commands: dict[str, SlashCommandLike[ClientT]] = {} self._user_commands: dict[str, UserCommand[ClientT]] = {} self._message_commands: dict[str, MessageCommand[ClientT]] = {} self._default_enabled_guilds = default_enabled_guilds + self._error_handler: ErrorHandlerCallbackT[ClientT] | None = None + self._hooks: list[HookT[ClientT]] = [] + self._post_hooks: list[PostHookT[ClientT]] = [] + + @property + def error_handler(self) -> ErrorHandlerCallbackT[ClientT] | None: + """The error handler for this plugin.""" + return self._error_handler + + @property + def hooks(self) -> t.MutableSequence[HookT[ClientT]]: + """The pre-execution hooks for this plugin.""" + return self._hooks + + @property + def post_hooks(self) -> t.MutableSequence[PostHookT[ClientT]]: + """The post-execution hooks for this plugin.""" + return self._post_hooks @property @abc.abstractmethod @@ -67,6 +85,21 @@ def default_enabled_guilds(self) -> hikari.UndefinedOr[t.Sequence[hikari.Snowfla """The default guilds to enable commands in.""" return self._default_enabled_guilds + async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None: + try: + if self.error_handler is not None: + await self.error_handler(ctx, exc) + else: + raise exc + except Exception as exc: + await self.client._on_error(ctx, exc) + + def _resolve_hooks(self) -> list[HookT[ClientT]]: + return self._hooks + + def _resolve_post_hooks(self) -> list[PostHookT[ClientT]]: + return self._post_hooks + def _client_include_hook(self, client: ClientT) -> None: if client._plugins.get(self.name) is not None: raise RuntimeError(f"Plugin '{self.name}' is already included in client.") @@ -120,15 +153,6 @@ def include(self, command: CommandBase[ClientT, BuilderT]) -> CommandBase[Client command._plugin_include_hook(self) return command - async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None: - try: - if self.error_handler is not None: - await self.error_handler(ctx, exc) - else: - raise exc - except Exception as exc: - await self.client._on_error(ctx, exc) - def include_slash_group( self, name: str, diff --git a/arc/command/slash.py b/arc/command/slash.py index 341ca85..8cc16f3 100644 --- a/arc/command/slash.py +++ b/arc/command/slash.py @@ -6,13 +6,12 @@ import attr import hikari -from arc.abc.command import CallableCommandBase, CommandBase -from arc.abc.error_handler import HasErrorHandler -from arc.abc.option import OptionBase, OptionWithChoices +from arc.abc.command import CallableCommandBase, CommandBase, SubCommandBase +from arc.abc.option import OptionWithChoices from arc.context import AutocompleteData, AutodeferMode, Context from arc.errors import AutocompleteError, CommandInvokeError from arc.internal.sigparse import parse_function_signature -from arc.internal.types import ClientT, CommandCallbackT, ResponseBuilderT, SlashCommandLike +from arc.internal.types import ClientT, CommandCallbackT, HookT, PostHookT, ResponseBuilderT, SlashCommandLike if t.TYPE_CHECKING: from asyncio.futures import Future @@ -397,19 +396,16 @@ def include_subgroup( autodefer=AutodeferMode(autodefer) if autodefer else hikari.UNDEFINED, name_localizations=name_localizations or {}, description_localizations=description_localizations or {}, - parent=self, ) + group.parent = self self.children[name] = group return group @attr.define(slots=True, kw_only=True) -class SlashSubGroup(OptionBase[ClientT], HasErrorHandler[ClientT]): +class SlashSubGroup(SubCommandBase[ClientT, SlashGroup[ClientT]]): """A subgroup of a slash command group.""" - parent: SlashGroup[ClientT] | None = None - """The parent group of this subgroup.""" - children: dict[str, SlashSubCommand[ClientT]] = attr.field(factory=dict) """Subcommands that belong to this subgroup.""" @@ -454,6 +450,14 @@ def _to_dict(self) -> dict[str, t.Any]: "options": [subcommand.to_command_option() for subcommand in self.children.values()], } + def _resolve_hooks(self) -> list[HookT[ClientT]]: + assert self.parent is not None + return self.parent._resolve_hooks() + self._hooks + + def _resolve_post_hooks(self) -> list[PostHookT[ClientT]]: + assert self.parent is not None + return self.parent._resolve_post_hooks() + self._post_hooks + async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None: try: if self.error_handler: @@ -472,12 +476,9 @@ def include(self, command: SlashSubCommand[ClientT]) -> SlashSubCommand[ClientT] @attr.define(slots=True, kw_only=True) -class SlashSubCommand(OptionBase[ClientT], HasErrorHandler[ClientT]): +class SlashSubCommand(SubCommandBase[ClientT, SlashGroup[ClientT] | SlashSubGroup[ClientT]]): """A subcommand of a slash command group.""" - parent: SlashGroup[ClientT] | SlashSubGroup[ClientT] | None = None - """The parent group of this subcommand.""" - callback: CommandCallbackT[ClientT] """The callback that will be invoked when this subcommand is invoked.""" @@ -491,6 +492,14 @@ class SlashSubCommand(OptionBase[ClientT], HasErrorHandler[ClientT]): _invoke_task: asyncio.Task[t.Any] | None = attr.field(default=None, init=False) + def _resolve_hooks(self) -> list[HookT[ClientT]]: + assert self.parent is not None + return self.parent._resolve_hooks() + self._hooks + + def _resolve_post_hooks(self) -> list[PostHookT[ClientT]]: + assert self.parent is not None + return self.parent._resolve_post_hooks() + self._post_hooks + async def _handle_exception(self, ctx: Context[ClientT], exc: Exception) -> None: try: if self.error_handler: diff --git a/arc/context/base.py b/arc/context/base.py index 017dc4d..3a5363f 100644 --- a/arc/context/base.py +++ b/arc/context/base.py @@ -242,6 +242,7 @@ class Context(t.Generic[ClientT]): "_autodefer_task", "_created_at", "_autodefer_task", + "_has_command_failed", ) def __init__( @@ -256,6 +257,7 @@ def __init__( self._response_lock: asyncio.Lock = asyncio.Lock() self._created_at = datetime.datetime.now() self._autodefer_task: asyncio.Task[None] | None = None + self._has_command_failed: bool = False @property def interaction(self) -> hikari.CommandInteraction: @@ -336,6 +338,11 @@ def is_valid(self) -> bool: else: return datetime.datetime.now() - self._created_at <= datetime.timedelta(seconds=3) + @property + def has_command_failed(self) -> bool: + """Returns if the command callback failed to execute or not.""" + return self._has_command_failed + def _start_autodefer(self, autodefer_mode: AutodeferMode) -> None: """Start the autodefer task.""" if self._autodefer_task is not None: diff --git a/arc/errors.py b/arc/errors.py index 49e9f1f..58b8f55 100644 --- a/arc/errors.py +++ b/arc/errors.py @@ -1,3 +1,7 @@ +import typing as t + +import hikari + __all__ = ( "ArcError", "AutocompleteError", @@ -37,6 +41,58 @@ class InteractionResponseError(ArcError): """Base exception for all interaction response errors.""" +class HookAbortError(ArcError): + """Raised when a built-in hook aborts the execution of a command.""" + + +class GuildOnlyError(HookAbortError): + """Raised when a command is invoked outside of a guild and a + [`guild_only`][arc.utils.hooks.guild_only] hook is present. + """ + + +class NotOwnerError(HookAbortError): + """Raised when a command is invoked by a non-owner and a + [`owner_only`][arc.utils.hooks.owner_only] hook is present. + """ + + +class DMOnlyError(HookAbortError): + """Raised when a command is invoked outside of a DM and a + [`dm_only`][arc.utils.hooks.dm_only] hook is present. + """ + + +class InvokerMissingPermissionsError(HookAbortError): + """Raised when a command is invoked by a user without the + required permissions set by a [`has_permissions`][arc.utils.hooks.has_permissions] hook. + + Attributes + ---------- + missing_permissions : hikari.Permissions + The permissions that the invoker is missing. + """ + + def __init__(self, missing_permissions: hikari.Permissions, *args: t.Any) -> None: + self.missing_permissions = missing_permissions + super().__init__(*args) + + +class BotMissingPermissionsError(HookAbortError): + """Raised when a command is invoked and the bot is missing the + required permissions set by a [`bot_has_permissions`][arc.utils.hooks.bot_has_permissions] hook. + + Attributes + ---------- + missing_permissions : hikari.Permissions + The permissions that the bot is missing. + """ + + def __init__(self, missing_permissions: hikari.Permissions, *args: t.Any) -> None: + self.missing_permissions = missing_permissions + super().__init__(*args) + + class NoResponseIssuedError(InteractionResponseError): """Raised when no response was issued by a command. Interactions must be responded to or deferred within 3 seconds to avoid this error. diff --git a/arc/internal/about.py b/arc/internal/about.py index 48c6f84..2123ce5 100644 --- a/arc/internal/about.py +++ b/arc/internal/about.py @@ -5,7 +5,7 @@ __maintainer__: t.Final[str] = "hypergonial" __license__: t.Final[str] = "MIT" __url__: t.Final[str] = "https://github.com/hypergonial/hikari-arc" -__version__: t.Final[str] = "0.2.1" +__version__: t.Final[str] = "0.3.0" # MIT License # diff --git a/arc/internal/types.py b/arc/internal/types.py index f3394fc..00fa454 100644 --- a/arc/internal/types.py +++ b/arc/internal/types.py @@ -5,7 +5,7 @@ if t.TYPE_CHECKING: import hikari - from arc.abc import Client, OptionParams + from arc.abc import Client, Hookable, HookResult, OptionParams from arc.client import GatewayClient, RESTClient from arc.command import SlashCommand, SlashGroup from arc.context import AutocompleteData, Context @@ -20,6 +20,7 @@ EventT = t.TypeVar("EventT", bound="hikari.Event") BuilderT = t.TypeVar("BuilderT", bound="hikari.api.SlashCommandBuilder | hikari.api.ContextMenuCommandBuilder") ParamsT = t.TypeVar("ParamsT", bound="OptionParams[t.Any]") +HookableT = t.TypeVar("HookableT", bound="Hookable[t.Any]") # Type aliases EventCallbackT: t.TypeAlias = "t.Callable[[EventT], t.Coroutine[t.Any, t.Any, None]]" @@ -34,3 +35,5 @@ ResponseBuilderT: t.TypeAlias = ( "hikari.api.InteractionMessageBuilder | hikari.api.InteractionDeferredBuilder | hikari.api.InteractionModalBuilder" ) +HookT: t.TypeAlias = "t.Callable[[Context[ClientT]], t.Awaitable[HookResult]] | t.Callable[[Context[ClientT]], HookResult] | t.Callable[[Context[ClientT]], None] | t.Callable[[Context[ClientT]], t.Awaitable[None]]" +PostHookT: t.TypeAlias = "t.Callable[[Context[ClientT]], None] | t.Callable[[Context[ClientT]], t.Awaitable[None]]" diff --git a/arc/utils/__init__.py b/arc/utils/__init__.py new file mode 100644 index 0000000..2745c69 --- /dev/null +++ b/arc/utils/__init__.py @@ -0,0 +1,3 @@ +from .hooks import bot_has_permissions, dm_only, guild_only, has_permissions, owner_only + +__all__ = ("guild_only", "owner_only", "dm_only", "has_permissions", "bot_has_permissions") diff --git a/arc/utils/hooks.py b/arc/utils/hooks.py new file mode 100644 index 0000000..6cd2c43 --- /dev/null +++ b/arc/utils/hooks.py @@ -0,0 +1,74 @@ +import typing as t + +import hikari + +from arc.abc.hookable import HookResult +from arc.context import Context +from arc.errors import ( + BotMissingPermissionsError, + DMOnlyError, + GuildOnlyError, + InvokerMissingPermissionsError, + NotOwnerError, +) + + +def guild_only(ctx: Context[t.Any]) -> HookResult: + """A pre-execution hook that aborts the execution of a command if it is invoked outside of a guild.""" + if ctx.guild_id is None: + raise GuildOnlyError("This command can only be used in a guild.") + return HookResult() + + +def dm_only(ctx: Context[t.Any]) -> HookResult: + """A pre-execution hook that aborts the execution of a command if it is invoked outside of a DM.""" + if ctx.guild_id is not None: + raise DMOnlyError("This command can only be used in a DM.") + return HookResult() + + +def owner_only(ctx: Context[t.Any]) -> HookResult: + """A pre-execution hook that aborts the execution of a command if it is invoked by a non-owner.""" + if ctx.author.id not in ctx.client.owner_ids: + raise NotOwnerError("This command can only be used by the application owners.") + return HookResult() + + +def _has_permissions(ctx: Context[t.Any], perms: hikari.Permissions) -> HookResult: + """Check if the invoker has the specified permissions.""" + if ctx.member is None: + raise GuildOnlyError("This command can only be used in a guild.") + + missing_perms = ~ctx.member.permissions & perms + + if missing_perms is not hikari.Permissions.NONE: + raise InvokerMissingPermissionsError( + missing_perms, f"Invoker is missing '{missing_perms}' permissions to run this command." + ) + + return HookResult() + + +def has_permissions(perms: hikari.Permissions) -> t.Callable[[Context[t.Any]], HookResult]: + """A pre-execution hook that aborts the execution of a command if the invoker is missing the specified permissions.""" + return lambda ctx: _has_permissions(ctx, perms) + + +def _bot_has_permissions(ctx: Context[t.Any], perms: hikari.Permissions) -> HookResult: + """Check if the bot has the specified permissions.""" + if ctx.app_permissions is None: + raise GuildOnlyError("This command can only be used in a guild.") + + missing_perms = ~ctx.app_permissions & perms + + if missing_perms is not hikari.Permissions.NONE: + raise BotMissingPermissionsError( + missing_perms, f"Bot is missing '{missing_perms}' permissions to run this command." + ) + + return HookResult() + + +def bot_has_permissions(perms: hikari.Permissions) -> t.Callable[[Context[t.Any]], HookResult]: + """A pre-execution hook that aborts the execution of a command if the bot is missing the specified permissions.""" + return lambda ctx: _bot_has_permissions(ctx, perms) diff --git a/docs/api_reference/abc/hooks.md b/docs/api_reference/abc/hooks.md new file mode 100644 index 0000000..62b234c --- /dev/null +++ b/docs/api_reference/abc/hooks.md @@ -0,0 +1,8 @@ +--- +title: Hook ABCs +description: Abstract Base Classes API reference +--- + +# Hook ABCs + +::: arc.abc.hookable diff --git a/docs/api_reference/utils/hooks.md b/docs/api_reference/utils/hooks.md new file mode 100644 index 0000000..7ce6144 --- /dev/null +++ b/docs/api_reference/utils/hooks.md @@ -0,0 +1,13 @@ +--- +title: Hooks +description: Hooks API reference +--- + +# Hooks + +This module contains all the built-in hooks contained in `arc`. + +!!! note + Any function that takes a [Context][arc.context.base.Context] as it's sole parameter and returns either `None` or a [`HookResult`][arc.abc.hookable.HookResult] is a valid hook. + +::: arc.utils.hooks diff --git a/docs/api_reference/utils/index.md b/docs/api_reference/utils/index.md new file mode 100644 index 0000000..ff99be5 --- /dev/null +++ b/docs/api_reference/utils/index.md @@ -0,0 +1,8 @@ +--- +title: Utils +description: Utils API reference +--- + +# Utils + +Here you can find all the utilities `arc` exports. diff --git a/mkdocs.yml b/mkdocs.yml index bdc56bb..9f0bb33 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -77,12 +77,16 @@ nav: - api_reference/errors.md - api_reference/events.md - api_reference/plugin.md + - Utils: + - api_reference/utils/index.md + - api_reference/utils/hooks.md - ABC: - api_reference/abc/index.md - api_reference/abc/client.md - api_reference/abc/command.md - api_reference/abc/option.md - api_reference/abc/plugin.md + - api_reference/abc/hooks.md - api_reference/abc/error_handler.md - Changelog: changelog.md