Skip to content

Commit

Permalink
Add second-order decos
Browse files Browse the repository at this point in the history
  • Loading branch information
hypergonial committed Jan 3, 2024
1 parent 2efbb26 commit a20c6d3
Show file tree
Hide file tree
Showing 9 changed files with 394 additions and 92 deletions.
256 changes: 206 additions & 50 deletions arc/abc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,29 +221,14 @@ def _remove_command(self, command: CommandBase[te.Self, t.Any]) -> None:

def _add_slash_command(self, command: SlashCommandLike[te.Self]) -> None:
"""Add a slash command to this client."""
if self.slash_commands.get(command.name) is not None:
logger.warning(
f"Shadowing already registered slash command '{command.name}'. Did you define multiple commands/groups with the same name?"
)

self._slash_commands[command.name] = command

def _add_message_command(self, command: MessageCommand[te.Self]) -> None:
"""Add a message command to this client."""
if self._message_commands.get(command.name) is not None:
logger.warning(
f"Shadowing already registered message command '{command.name}'. Did you define multiple commands with the same name?"
)

self._message_commands[command.name] = command

def _add_user_command(self, command: UserCommand[te.Self]) -> None:
"""Add a user command to this client."""
if self._user_commands.get(command.name) is not None:
logger.warning(
f"Shadowing already registered user command '{command.name}'. Did you define multiple commands with the same name?"
)

self._user_commands[command.name] = command

async def _on_startup(self) -> None:
Expand Down Expand Up @@ -366,8 +351,18 @@ async def on_autocomplete_interaction(

return await command._on_autocomplete(interaction)

@t.overload
def include(self) -> t.Callable[[CommandBase[te.Self, BuilderT]], CommandBase[te.Self, BuilderT]]:
...

@t.overload
def include(self, command: CommandBase[te.Self, BuilderT]) -> CommandBase[te.Self, BuilderT]:
"""First-order decorator to add a command to this client.
...

def include(
self, command: CommandBase[te.Self, BuilderT] | None = None
) -> CommandBase[te.Self, BuilderT] | t.Callable[[CommandBase[te.Self, BuilderT]], CommandBase[te.Self, BuilderT]]:
"""Decorator to add a command to this client.
!!! note
This should be the **last** (topmost) decorator on a command.
Expand All @@ -391,17 +386,27 @@ async def cmd(ctx: arc.GatewayContext) -> None:
...
```
"""
if command.plugin is not None:
raise RuntimeError(
f"Command '{command.name}' is already registered with plugin '{command.plugin.name}'."
f"\nYou should use '{type(self).__name__}.add_plugin()' to add the entire plugin to the client."
)

if existing := self.commands[command.command_type].get(command.name):
existing._client_remove_hook(self)
def decorator(command: CommandBase[te.Self, BuilderT]) -> CommandBase[te.Self, BuilderT]:
if command.plugin is not None:
raise RuntimeError(
f"Command '{command.name}' is already registered with plugin '{command.plugin.name}'."
f"\nYou should use '{type(self).__name__}.add_plugin()' to add the entire plugin to the client."
)

command._client_include_hook(self)
return command
if existing := self.commands[command.command_type].get(command.name):
logger.warning(
f"Shadowing already registered command '{command.name}'. Did you define multiple commands with the same name?"
)
existing._client_remove_hook(self)

command._client_include_hook(self)
return command

if command is not None:
return decorator(command)

return decorator

def include_slash_group(
self,
Expand Down Expand Up @@ -521,7 +526,15 @@ def remove_plugin(self, plugin: str | PluginBase[te.Self]) -> te.Self:
pg._client_remove_hook()
return self

@t.overload
def add_hook(self, hook: HookT[te.Self]) -> te.Self:
...

@t.overload
def add_hook(self) -> t.Callable[[HookT[te.Self]], HookT[te.Self]]:
...

def add_hook(self, hook: HookT[te.Self] | None = None) -> te.Self | t.Callable[[HookT[te.Self]], HookT[te.Self]]:
"""Add a pre-execution hook to this client.
This hook will be executed before **every command** that is added to this client.
Expand All @@ -539,10 +552,27 @@ def add_hook(self, hook: HookT[te.Self]) -> te.Self:
--------
- [`Client.add_post_hook`][arc.client.Client.add_post_hook]
"""
self._hooks.append(hook)
return self
if hook is not None:
self._hooks.append(hook)
return self

def decorator(hook: HookT[te.Self]) -> HookT[te.Self]:
self._hooks.append(hook)
return hook

return decorator

@t.overload
def add_post_hook(self, hook: PostHookT[te.Self]) -> te.Self:
...

@t.overload
def add_post_hook(self) -> t.Callable[[PostHookT[te.Self]], PostHookT[te.Self]]:
...

def add_post_hook(
self, hook: PostHookT[te.Self] | None = None
) -> te.Self | t.Callable[[PostHookT[te.Self]], PostHookT[te.Self]]:
"""Add a post-execution hook to this client.
This hook will be executed after **every command** that is added to this client.
Expand All @@ -563,10 +593,27 @@ def add_post_hook(self, hook: PostHookT[te.Self]) -> te.Self:
--------
- [`Client.add_hook`][arc.client.Client.add_hook]
"""
self._post_hooks.append(hook)
return self
if hook is not None:
self._post_hooks.append(hook)
return self

def decorator(hook: PostHookT[te.Self]) -> PostHookT[te.Self]:
self._post_hooks.append(hook)
return hook

return decorator

@t.overload
def set_error_handler(self, handler: ErrorHandlerCallbackT[te.Self]) -> None:
...

@t.overload
def set_error_handler(self) -> t.Callable[[ErrorHandlerCallbackT[te.Self]], None]:
...

def set_error_handler(
self, handler: ErrorHandlerCallbackT[te.Self] | None = None
) -> None | t.Callable[[ErrorHandlerCallbackT[te.Self]], None]:
"""Decorator to set the error handler for this client.
This will be called when a command callback raises an exception
Expand All @@ -591,16 +638,33 @@ async def error_handler_func(ctx: arc.GatewayContext, exception: Exception) -> N
client.set_error_handler(error_handler_func)
```
"""
self._error_handler = handler
if handler is not None:
self._error_handler = handler
return

def decorator(handler: ErrorHandlerCallbackT[te.Self]) -> None:
self._error_handler = handler

return decorator

def set_startup_hook(self, handler: LifeCycleHookT[te.Self]) -> None:
@t.overload
def set_startup_hook(self, hook: LifeCycleHookT[te.Self]) -> None:
...

@t.overload
def set_startup_hook(self) -> t.Callable[[LifeCycleHookT[te.Self]], None]:
...

def set_startup_hook(
self, hook: LifeCycleHookT[te.Self] | None = None
) -> None | t.Callable[[LifeCycleHookT[te.Self]], None]:
"""Decorator to set the startup hook for this client.
This will be called when the client starts up.
Parameters
----------
handler : LifeCycleHookT[te.Self]
hook : LifeCycleHookT[te.Self]
The startup hook to set.
Usage
Expand All @@ -617,16 +681,33 @@ async def startup_hook(client: arc.GatewayClient) -> None:
client.set_startup_hook(startup_hook)
```
"""
self._startup_hook = handler
if hook is not None:
self._startup_hook = hook
return

def decorator(handler: LifeCycleHookT[te.Self]) -> None:
self._startup_hook = handler

def set_shutdown_hook(self, handler: LifeCycleHookT[te.Self]) -> None:
return decorator

@t.overload
def set_shutdown_hook(self, hook: LifeCycleHookT[te.Self]) -> None:
...

@t.overload
def set_shutdown_hook(self) -> t.Callable[[LifeCycleHookT[te.Self]], None]:
...

def set_shutdown_hook(
self, hook: LifeCycleHookT[te.Self] | None = None
) -> None | t.Callable[[LifeCycleHookT[te.Self]], None]:
"""Decorator to set the shutdown hook for this client.
This will be called when the client shuts down.
Parameters
----------
handler : LifeCycleHookT[te.Self]
hook : LifeCycleHookT[te.Self]
The shutdown hook to set.
Usage
Expand All @@ -643,9 +724,26 @@ async def shutdown_hook(client: arc.GatewayClient) -> None:
client.set_shutdown_hook(shutdown_hook)
```
"""
self._shutdown_hook = handler
if hook is not None:
self._shutdown_hook = hook
return

def decorator(handler: LifeCycleHookT[te.Self]) -> None:
self._shutdown_hook = handler

return decorator

@t.overload
def set_command_locale_provider(self, provider: CommandLocaleRequestT) -> None:
...

@t.overload
def set_command_locale_provider(self) -> t.Callable[[CommandLocaleRequestT], None]:
...

def set_command_locale_provider(
self, provider: CommandLocaleRequestT | None = None
) -> None | t.Callable[[CommandLocaleRequestT], None]:
"""Decorator to set the command locale provider for this client.
This will be called for each command for each locale.
Expand All @@ -669,9 +767,26 @@ def command_locale_provider(request: arc.CommandLocaleRequest) -> arc.LocaleResp
client.set_command_locale_provider(command_locale_provider)
```
"""
self._command_locale_provider = provider
if provider is not None:
self._command_locale_provider = provider
return

def decorator(provider: CommandLocaleRequestT) -> None:
self._command_locale_provider = provider

return decorator

@t.overload
def set_option_locale_provider(self, provider: OptionLocaleRequestT) -> None:
...

@t.overload
def set_option_locale_provider(self) -> t.Callable[[OptionLocaleRequestT], None]:
...

def set_option_locale_provider(
self, provider: OptionLocaleRequestT | None = None
) -> None | t.Callable[[OptionLocaleRequestT], None]:
"""Decorator to set the option locale provider for this client.
This will be called for each option of each command for each locale.
Expand All @@ -695,9 +810,26 @@ def option_locale_provider(request: arc.OptionLocaleRequest) -> arc.LocaleRespon
client.set_option_locale_provider(option_locale_provider)
```
"""
self._option_locale_provider = provider
if provider is not None:
self._option_locale_provider = provider
return

def decorator(provider: OptionLocaleRequestT) -> None:
self._option_locale_provider = provider

return decorator

@t.overload
def set_custom_locale_provider(self, provider: CustomLocaleRequestT) -> None:
...

@t.overload
def set_custom_locale_provider(self) -> t.Callable[[CustomLocaleRequestT], None]:
...

def set_custom_locale_provider(
self, provider: CustomLocaleRequestT | None = None
) -> None | t.Callable[[CustomLocaleRequestT], None]:
"""Decorator to set the custom locale provider for this client.
This will be called for each custom locale request performed via [`Context.loc()`][arc.context.base.Context.loc].
Expand All @@ -721,7 +853,14 @@ def custom_locale_provider(request: arc.CustomLocaleRequest) -> str:
client.set_custom_locale_provider(custom_locale_provider)
```
"""
self._custom_locale_provider = provider
if provider is not None:
self._custom_locale_provider = provider
return

def decorator(provider: CustomLocaleRequestT) -> None:
self._custom_locale_provider = provider

return decorator

def load_extension(self, path: str) -> te.Self:
"""Load a python module with path `path` as an extension.
Expand Down Expand Up @@ -950,8 +1089,18 @@ def get_type_dependency(self, type_: t.Type[T]) -> hikari.UndefinedOr[T]:
"""
return self._injector.get_type_dependency(type_, default=hikari.UNDEFINED)

@t.overload
def inject_dependencies(self, func: t.Callable[P, T]) -> t.Callable[P, T]:
"""First order decorator to inject dependencies into the decorated function.
...

@t.overload
def inject_dependencies(self) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]:
...

def inject_dependencies(
self, func: t.Callable[P, T] | None = None
) -> t.Callable[P, T] | t.Callable[[t.Callable[P, T]], t.Callable[P, T]]:
"""Decorator to inject dependencies into the decorated function.
!!! note
Command callbacks are automatically injected with dependencies,
Expand All @@ -978,20 +1127,27 @@ def my_func(dep: MyDependency = arc.inject()) -> None:
- [`Client.set_type_dependency`][arc.client.Client.set_type_dependency]
A method to set dependencies for the client.
"""
if inspect.iscoroutinefunction(func):

@functools.wraps(func)
async def decorator_async(*args: P.args, **kwargs: P.kwargs) -> T:
return await self.injector.call_with_async_di(func, *args, **kwargs)
def decorator(func: t.Callable[P, T]) -> t.Callable[P, T]:
if inspect.iscoroutinefunction(func):

return decorator_async # pyright: ignore reportGeneralTypeIssues
else:
@functools.wraps(func)
async def decorator_async(*args: P.args, **kwargs: P.kwargs) -> T:
return await self.injector.call_with_async_di(func, *args, **kwargs)

return decorator_async # pyright: ignore reportGeneralTypeIssues
else:

@functools.wraps(func)
def decorator_inner(*args: P.args, **kwargs: P.kwargs) -> T:
return self.injector.call_with_di(func, *args, **kwargs)

return decorator_inner

@functools.wraps(func)
def decorator(*args: P.args, **kwargs: P.kwargs) -> T:
return self.injector.call_with_di(func, *args, **kwargs)
if func is not None:
return decorator(func)

return decorator
return decorator

async def resync_commands(self) -> None:
"""Synchronize the commands registered in this client with Discord.
Expand Down
Loading

0 comments on commit a20c6d3

Please sign in to comment.