Skip to content

Commit

Permalink
Update to 3.12
Browse files Browse the repository at this point in the history
  • Loading branch information
Sachaa-Thanasius committed Jul 22, 2024
1 parent 003380e commit 21cd02e
Show file tree
Hide file tree
Showing 45 changed files with 416 additions and 534 deletions.
2 changes: 1 addition & 1 deletion core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import tree as tree, utils as utils
from . import tree as tree
from .bot import Beira as Beira
from .checks import *
from .config import *
Expand Down
26 changes: 8 additions & 18 deletions core/bot.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
"""
bot.py: The main bot code.
"""

from __future__ import annotations
"""bot.py: The main bot code."""

import logging
import sys
import time
import traceback
from typing import TYPE_CHECKING, Any
from typing import Any
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

import aiohttp
import ao3
import async_lru
import asyncpg
import atlas_api
import discord
import fichub_api
Expand All @@ -26,13 +21,9 @@
from .checks import is_blocked
from .config import CONFIG
from .context import Context
from .utils import LoggingManager, Pool_alias


if TYPE_CHECKING:
from core.utils import LoggingManager
else:
LoggingManager = object

LOGGER = logging.getLogger(__name__)


Expand All @@ -58,7 +49,7 @@ class Beira(commands.Bot):
def __init__(
self,
*args: Any,
db_pool: asyncpg.Pool[asyncpg.Record],
db_pool: Pool_alias,
web_session: aiohttp.ClientSession,
initial_extensions: list[str] | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -189,8 +180,8 @@ def owner(self) -> discord.User:
async def _load_blocked_entities(self) -> None:
"""Load all blocked users and guilds from the bot database."""

user_query = """SELECT user_id FROM users WHERE is_blocked;"""
guild_query = """SELECT guild_id FROM guilds WHERE is_blocked;"""
user_query = "SELECT user_id FROM users WHERE is_blocked;"
guild_query = "SELECT guild_id FROM guilds WHERE is_blocked;"

async with self.db_pool.acquire() as conn, conn.transaction():
user_records = await conn.fetch(user_query)
Expand All @@ -202,7 +193,7 @@ async def _load_blocked_entities(self) -> None:
async def _load_guild_prefixes(self, guild_id: int | None = None) -> None:
"""Load all prefixes from the bot database."""

query = """SELECT guild_id, prefix FROM guild_prefixes"""
query = "SELECT guild_id, prefix FROM guild_prefixes"
try:
if guild_id:
query += " WHERE guild_id = $1"
Expand Down Expand Up @@ -248,8 +239,7 @@ async def _load_special_friends(self) -> None:

@async_lru.alru_cache()
async def get_user_timezone(self, user_id: int) -> str | None:
query = "SELECT timezone FROM users WHERE user_id = $1;"
record = await self.db_pool.fetchrow(query, user_id)
record = await self.db_pool.fetchrow("SELECT timezone FROM users WHERE user_id = $1;", user_id)
return record["timezone"] if record else None

async def get_user_tzinfo(self, user_id: int) -> ZoneInfo:
Expand Down
75 changes: 36 additions & 39 deletions core/checks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
"""
checks.py: Custom checks used by the bot.
"""

from __future__ import annotations
"""checks.py: Custom checks used by the bot."""

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
from typing import TYPE_CHECKING, Any, Protocol

import discord
from discord import app_commands
Expand All @@ -18,15 +14,11 @@
if TYPE_CHECKING:
from discord.ext.commands._types import Check # type: ignore [reportMissingTypeStubs]

from .context import Context, GuildContext

T = TypeVar("T")


class AppCheck(Protocol):
predicate: AppCheckFunc

def __call__(self, coro_or_commands: T) -> T: ...
def __call__[T](self, coro_or_commands: T) -> T: ...


__all__ = (
Expand All @@ -39,16 +31,16 @@ def __call__(self, coro_or_commands: T) -> T: ...
)


def is_owner_or_friend() -> Check[Any]:
"""A :func:`.check` that checks if the person invoking this command is the
owner of the bot or on a special friends list.
def is_owner_or_friend() -> "Check[Any]":
"""A `.check` that checks if the person invoking this command is the owner of the bot or on a special friends list.
This is partially powered by :meth:`.Bot.is_owner`.
This is partially powered by `.Bot.is_owner`.
This check raises a special exception, :exc:`.NotOwnerOrFriend` that is derived
from :exc:`commands.CheckFailure`.
This check raises a special exception, `.NotOwnerOrFriend` that is derived from `commands.CheckFailure`.
"""

from .context import Context

async def predicate(ctx: Context) -> bool:
if not (ctx.bot.is_special_friend(ctx.author) or await ctx.bot.is_owner(ctx.author)):
raise NotOwnerOrFriend
Expand All @@ -57,14 +49,15 @@ async def predicate(ctx: Context) -> bool:
return commands.check(predicate)


def is_admin() -> Check[Any]:
"""A :func:`.check` that checks if the person invoking this command is an
administrator of the guild in the current context.
def is_admin() -> "Check[Any]":
"""A `.check` that checks if the person invoking this command is an administrator of the guild in the current
context.
This check raises a special exception, :exc:`NotAdmin` that is derived
from :exc:`commands.CheckFailure`.
This check raises a special exception, `NotAdmin` that is derived from `commands.CheckFailure`.
"""

from .context import GuildContext

async def predicate(ctx: GuildContext) -> bool:
if not ctx.author.guild_permissions.administrator:
raise NotAdmin
Expand All @@ -73,14 +66,15 @@ async def predicate(ctx: GuildContext) -> bool:
return commands.check(predicate)


def in_bot_vc() -> Check[Any]:
"""A :func:`.check` that checks if the person invoking this command is in
the same voice channel as the bot within a guild.
def in_bot_vc() -> "Check[Any]":
"""A `.check` that checks if the person invoking this command is in the same voice channel as the bot within
a guild.
This check raises a special exception, :exc:`NotInBotVoiceChannel` that is derived
from :exc:`commands.CheckFailure`.
This check raises a special exception, `NotInBotVoiceChannel` that is derived from `commands.CheckFailure`.
"""

from .context import GuildContext

async def predicate(ctx: GuildContext) -> bool:
vc = ctx.voice_client

Expand All @@ -94,13 +88,14 @@ async def predicate(ctx: GuildContext) -> bool:
return commands.check(predicate)


def in_aci100_guild() -> Check[Any]:
"""A :func:`.check` that checks if the person invoking this command is in
the ACI100 guild.
def in_aci100_guild() -> "Check[Any]":
"""A `.check` that checks if the person invoking this command is in the ACI100 guild.
This check raises the exception :exc:`commands.CheckFailure`.
This check raises the exception `commands.CheckFailure`.
"""

from .context import GuildContext

async def predicate(ctx: GuildContext) -> bool:
if ctx.guild.id != 602735169090224139:
msg = "This command isn't active in this guild."
Expand All @@ -110,12 +105,14 @@ async def predicate(ctx: GuildContext) -> bool:
return commands.check(predicate)


def is_blocked() -> Check[Any]:
"""A :func:`.check` that checks if the command is being invoked from a blocked user or guild.
def is_blocked() -> "Check[Any]":
"""A `.check` that checks if the command is being invoked from a blocked user or guild.
This check raises the exception :exc:`commands.CheckFailure`.
This check raises the exception `commands.CheckFailure`.
"""

from .context import Context

async def predicate(ctx: Context) -> bool:
if not (await ctx.bot.is_owner(ctx.author)):
if ctx.author.id in ctx.bot.blocked_entities_cache["users"]:
Expand All @@ -128,22 +125,22 @@ async def predicate(ctx: Context) -> bool:


# TODO: Actually check if this works.
def check_any(*checks: AppCheck) -> Callable[[T], T]:
def check_any[T](*checks: AppCheck) -> Callable[[T], T]:
"""An attempt at making a `check_any` decorator for application commands that checks if any of the checks passed
will pass, i.e. using logical OR.

If all checks fail then :exc:`CheckAnyFailure` is raised to signal the failure.
It inherits from :exc:`app_commands.CheckFailure`.
If all checks fail then :exc:`CheckAnyFailure` is raised to signal the failure. It inherits from
`app_commands.CheckFailure`.

Parameters
----------
checks: `AppCheckProtocol`
An argument list of checks that have been decorated with :func:`app_commands.check` decorator.
An argument list of checks that have been decorated with `app_commands.check` decorator.

Raises
------
TypeError
A check passed has not been decorated with the :func:`app_commands.check` decorator.
A check passed has not been decorated with the `app_commands.check` decorator.
"""

unwrapped: list[AppCheckFunc] = []
Expand Down
15 changes: 2 additions & 13 deletions core/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
config.py: Imports configuration information, such as api keys and tokens, default prefixes, etc.
"""
"""config.py: Imports configuration information, such as api keys and tokens, default prefixes, etc."""

import pathlib
from typing import Any
Expand Down Expand Up @@ -82,13 +80,4 @@ def decode(data: bytes | str) -> Config:
return msgspec.toml.decode(data, type=Config)


def encode(msg: Config) -> bytes:
"""Encode a ``Config`` object to TOML."""

return msgspec.toml.encode(msg)


with pathlib.Path("config.toml").open(encoding="utf-8") as f:
data = f.read()

CONFIG = decode(data)
CONFIG = decode(pathlib.Path("config.toml").read_text(encoding="utf-8"))
8 changes: 3 additions & 5 deletions core/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""
context.py: For the custom context and interaction subclasses. Mainly used for type narrowing.
"""
"""context.py: For the custom context and interaction subclasses. Mainly used for type narrowing."""

from __future__ import annotations

from typing import TYPE_CHECKING, TypeAlias
from typing import TYPE_CHECKING

import aiohttp
import discord
Expand All @@ -20,7 +18,7 @@

__all__ = ("Context", "GuildContext", "Interaction")

Interaction: TypeAlias = discord.Interaction["Beira"]
type Interaction = discord.Interaction[Beira]


class Context(commands.Context["Beira"]):
Expand Down
4 changes: 1 addition & 3 deletions core/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
errors.py: Custom errors used by the bot.
"""
"""errors.py: Custom errors used by the bot."""

from discord import app_commands
from discord.ext import commands
Expand Down
30 changes: 16 additions & 14 deletions core/tree.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from __future__ import annotations

import asyncio
import logging
import traceback
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias
from typing import TYPE_CHECKING, Any

import discord
from discord import Client, Interaction
Expand All @@ -22,19 +20,21 @@

ClientT_co = TypeVar("ClientT_co", bound=Client, covariant=True)

P = ParamSpec("P")
T = TypeVar("T")
Coro: TypeAlias = Coroutine[Any, Any, T]
CoroFunc: TypeAlias = Callable[..., Coro[Any]]
GroupT = TypeVar("GroupT", bound=Group | commands.Cog)
AppHook: TypeAlias = Callable[[GroupT, Interaction[Any]], Coro[Any]] | Callable[[Interaction[Any]], Coro[Any]]

type Coro[T] = Coroutine[Any, Any, T]
type CoroFunc = Callable[..., Coro[Any]]
type AppHook[GroupT: (Group | commands.Cog)] = (
Callable[[GroupT, Interaction[Any]], Coro[Any]] | Callable[[Interaction[Any]], Coro[Any]]
)

LOGGER = logging.getLogger(__name__)

__all__ = ("before_app_invoke", "after_app_invoke", "HookableTree")


def before_app_invoke(coro: AppHook[GroupT]) -> Callable[[Command[GroupT, P, T]], Command[GroupT, P, T]]:
def before_app_invoke[GroupT: (Group | commands.Cog), **P, T](
coro: AppHook[GroupT],
) -> Callable[[Command[GroupT, P, T]], Command[GroupT, P, T]]:
"""A decorator that registers a coroutine as a pre-invoke hook.
This allows you to refer to one before invoke hook for several commands that
Expand Down Expand Up @@ -67,7 +67,9 @@ def decorator(inner: Command[GroupT, P, T]) -> Command[GroupT, P, T]:
return decorator


def after_app_invoke(coro: AppHook[GroupT]) -> Callable[[Command[GroupT, P, T]], Command[GroupT, P, T]]:
def after_app_invoke[GroupT: (Group | commands.Cog), **P, T](
coro: AppHook[GroupT],
) -> Callable[[Command[GroupT, P, T]], Command[GroupT, P, T]]:
"""A decorator that registers a coroutine as a post-invoke hook.
This allows you to refer to one after invoke hook for several commands that
Expand Down Expand Up @@ -132,7 +134,7 @@ async def on_error(self, interaction: Interaction[ClientT_co], error: AppCommand
LOGGER.error("Exception in command tree", exc_info=error, extra={"embed": embed})

async def _call(self, interaction: Interaction[ClientT_co]) -> None:
###### Copy the original logic but add hook checks/calls near the end.
# ---- Copy the original logic but add hook checks/calls near the end.

if not await self.interaction_check(interaction):
interaction.command_failed = True
Expand Down Expand Up @@ -172,7 +174,7 @@ async def _call(self, interaction: Interaction[ClientT_co]) -> None:

return

### Look for a pre-command hook.
# -- Look for a pre-command hook.
# Pre-command hooks are run before actual command-specific checks, unlike prefix commands.
# It doesn't really make sense, but the only solution seems to be monkey-patching
# Command._invoke_with_namespace, which doesn't seem feasible.
Expand All @@ -194,7 +196,7 @@ async def _call(self, interaction: Interaction[ClientT_co]) -> None:
if not interaction.command_failed:
self.client.dispatch("app_command_completion", interaction, command)
finally:
### Look for a post-command hook.
# -- Look for a post-command hook.
after_invoke = getattr(command, "_after_invoke", None)
if after_invoke:
instance = getattr(after_invoke, "__self__", None)
Expand Down
2 changes: 1 addition & 1 deletion core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .custom_logging import *
from .db import *
from .embeds import *
from .emojis import *
from .log import *
from .misc import *
from .pagination import *
Loading

0 comments on commit 21cd02e

Please sign in to comment.