From 21cd02ec57a2cdfb04a531704be6550da9fc1b24 Mon Sep 17 00:00:00 2001 From: Sachaa-Thanasius <111999343+Sachaa-Thanasius@users.noreply.github.com> Date: Sun, 21 Jul 2024 21:30:46 -0400 Subject: [PATCH] Update to 3.12 --- core/__init__.py | 2 +- core/bot.py | 26 ++---- core/checks.py | 75 ++++++++--------- core/config.py | 15 +--- core/context.py | 8 +- core/errors.py | 4 +- core/tree.py | 30 +++---- core/utils/__init__.py | 2 +- core/utils/db.py | 75 +++-------------- core/utils/embeds.py | 12 ++- core/utils/emojis.py | 3 - core/utils/{custom_logging.py => log.py} | 37 ++------- core/utils/misc.py | 8 +- core/utils/pagination.py | 9 +- exts/__init__.py | 2 +- exts/_dev/__init__.py | 6 +- exts/_dev/_dev.py | 100 +++++++++++++++++++---- exts/admin.py | 37 ++++----- exts/bot_stats.py | 76 ++++++++--------- exts/dice.py | 13 +-- exts/emoji_ops.py | 5 +- exts/fandom_wiki.py | 42 ++++++---- exts/ff_metadata/__init__.py | 10 +-- exts/ff_metadata/ff_metadata.py | 40 ++++----- exts/ff_metadata/utils.py | 2 - exts/help.py | 5 +- exts/lol.py | 11 +-- exts/misc.py | 46 ++++------- exts/music.py | 7 +- exts/notifications/__init__.py | 2 - exts/notifications/aci_notifications.py | 38 ++++----- exts/notifications/other_triggers.py | 8 +- exts/notifications/rss_notifications.py | 4 +- exts/patreon.py | 7 +- exts/presence.py | 2 - exts/snowball/__init__.py | 10 +-- exts/snowball/snow_text.py | 12 +++ exts/snowball/snowball.py | 31 +++---- exts/snowball/utils.py | 40 +++++---- exts/starkid.py | 5 +- exts/story_search.py | 42 +++++----- exts/timing.py | 9 +- exts/todo.py | 24 +++--- exts/webhook_logging.py | 2 - pyproject.toml | 6 +- 45 files changed, 416 insertions(+), 534 deletions(-) rename core/utils/{custom_logging.py => log.py} (82%) diff --git a/core/__init__.py b/core/__init__.py index dd763aa..17b83c3 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -1,4 +1,4 @@ -from . import tree as tree, utils as utils +from . import tree as tree from .bot import Beira as Beira from .checks import * from .config import * diff --git a/core/bot.py b/core/bot.py index ac8d581..fedbcf0 100644 --- a/core/bot.py +++ b/core/bot.py @@ -1,20 +1,15 @@ -""" -bot.py: The main bot code. -""" - -from __future__ import annotations +"""bot.py: The main bot code.""" import logging import sys import time import traceback -from typing import TYPE_CHECKING, Any +from typing import Any from zoneinfo import ZoneInfo, ZoneInfoNotFoundError import aiohttp import ao3 import async_lru -import asyncpg import atlas_api import discord import fichub_api @@ -26,13 +21,9 @@ from .checks import is_blocked from .config import CONFIG from .context import Context +from .utils import LoggingManager, Pool_alias -if TYPE_CHECKING: - from core.utils import LoggingManager -else: - LoggingManager = object - LOGGER = logging.getLogger(__name__) @@ -58,7 +49,7 @@ class Beira(commands.Bot): def __init__( self, *args: Any, - db_pool: asyncpg.Pool[asyncpg.Record], + db_pool: Pool_alias, web_session: aiohttp.ClientSession, initial_extensions: list[str] | None = None, **kwargs: Any, @@ -189,8 +180,8 @@ def owner(self) -> discord.User: async def _load_blocked_entities(self) -> None: """Load all blocked users and guilds from the bot database.""" - user_query = """SELECT user_id FROM users WHERE is_blocked;""" - guild_query = """SELECT guild_id FROM guilds WHERE is_blocked;""" + user_query = "SELECT user_id FROM users WHERE is_blocked;" + guild_query = "SELECT guild_id FROM guilds WHERE is_blocked;" async with self.db_pool.acquire() as conn, conn.transaction(): user_records = await conn.fetch(user_query) @@ -202,7 +193,7 @@ async def _load_blocked_entities(self) -> None: async def _load_guild_prefixes(self, guild_id: int | None = None) -> None: """Load all prefixes from the bot database.""" - query = """SELECT guild_id, prefix FROM guild_prefixes""" + query = "SELECT guild_id, prefix FROM guild_prefixes" try: if guild_id: query += " WHERE guild_id = $1" @@ -248,8 +239,7 @@ async def _load_special_friends(self) -> None: @async_lru.alru_cache() async def get_user_timezone(self, user_id: int) -> str | None: - query = "SELECT timezone FROM users WHERE user_id = $1;" - record = await self.db_pool.fetchrow(query, user_id) + record = await self.db_pool.fetchrow("SELECT timezone FROM users WHERE user_id = $1;", user_id) return record["timezone"] if record else None async def get_user_tzinfo(self, user_id: int) -> ZoneInfo: diff --git a/core/checks.py b/core/checks.py index 48d729b..7245a2a 100644 --- a/core/checks.py +++ b/core/checks.py @@ -1,11 +1,7 @@ -""" -checks.py: Custom checks used by the bot. -""" - -from __future__ import annotations +"""checks.py: Custom checks used by the bot.""" from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Protocol import discord from discord import app_commands @@ -18,15 +14,11 @@ if TYPE_CHECKING: from discord.ext.commands._types import Check # type: ignore [reportMissingTypeStubs] - from .context import Context, GuildContext - -T = TypeVar("T") - class AppCheck(Protocol): predicate: AppCheckFunc - def __call__(self, coro_or_commands: T) -> T: ... + def __call__[T](self, coro_or_commands: T) -> T: ... __all__ = ( @@ -39,16 +31,16 @@ def __call__(self, coro_or_commands: T) -> T: ... ) -def is_owner_or_friend() -> Check[Any]: - """A :func:`.check` that checks if the person invoking this command is the - owner of the bot or on a special friends list. +def is_owner_or_friend() -> "Check[Any]": + """A `.check` that checks if the person invoking this command is the owner of the bot or on a special friends list. - This is partially powered by :meth:`.Bot.is_owner`. + This is partially powered by `.Bot.is_owner`. - This check raises a special exception, :exc:`.NotOwnerOrFriend` that is derived - from :exc:`commands.CheckFailure`. + This check raises a special exception, `.NotOwnerOrFriend` that is derived from `commands.CheckFailure`. """ + from .context import Context + async def predicate(ctx: Context) -> bool: if not (ctx.bot.is_special_friend(ctx.author) or await ctx.bot.is_owner(ctx.author)): raise NotOwnerOrFriend @@ -57,14 +49,15 @@ async def predicate(ctx: Context) -> bool: return commands.check(predicate) -def is_admin() -> Check[Any]: - """A :func:`.check` that checks if the person invoking this command is an - administrator of the guild in the current context. +def is_admin() -> "Check[Any]": + """A `.check` that checks if the person invoking this command is an administrator of the guild in the current + context. - This check raises a special exception, :exc:`NotAdmin` that is derived - from :exc:`commands.CheckFailure`. + This check raises a special exception, `NotAdmin` that is derived from `commands.CheckFailure`. """ + from .context import GuildContext + async def predicate(ctx: GuildContext) -> bool: if not ctx.author.guild_permissions.administrator: raise NotAdmin @@ -73,14 +66,15 @@ async def predicate(ctx: GuildContext) -> bool: return commands.check(predicate) -def in_bot_vc() -> Check[Any]: - """A :func:`.check` that checks if the person invoking this command is in - the same voice channel as the bot within a guild. +def in_bot_vc() -> "Check[Any]": + """A `.check` that checks if the person invoking this command is in the same voice channel as the bot within + a guild. - This check raises a special exception, :exc:`NotInBotVoiceChannel` that is derived - from :exc:`commands.CheckFailure`. + This check raises a special exception, `NotInBotVoiceChannel` that is derived from `commands.CheckFailure`. """ + from .context import GuildContext + async def predicate(ctx: GuildContext) -> bool: vc = ctx.voice_client @@ -94,13 +88,14 @@ async def predicate(ctx: GuildContext) -> bool: return commands.check(predicate) -def in_aci100_guild() -> Check[Any]: - """A :func:`.check` that checks if the person invoking this command is in - the ACI100 guild. +def in_aci100_guild() -> "Check[Any]": + """A `.check` that checks if the person invoking this command is in the ACI100 guild. - This check raises the exception :exc:`commands.CheckFailure`. + This check raises the exception `commands.CheckFailure`. """ + from .context import GuildContext + async def predicate(ctx: GuildContext) -> bool: if ctx.guild.id != 602735169090224139: msg = "This command isn't active in this guild." @@ -110,12 +105,14 @@ async def predicate(ctx: GuildContext) -> bool: return commands.check(predicate) -def is_blocked() -> Check[Any]: - """A :func:`.check` that checks if the command is being invoked from a blocked user or guild. +def is_blocked() -> "Check[Any]": + """A `.check` that checks if the command is being invoked from a blocked user or guild. - This check raises the exception :exc:`commands.CheckFailure`. + This check raises the exception `commands.CheckFailure`. """ + from .context import Context + async def predicate(ctx: Context) -> bool: if not (await ctx.bot.is_owner(ctx.author)): if ctx.author.id in ctx.bot.blocked_entities_cache["users"]: @@ -128,22 +125,22 @@ async def predicate(ctx: Context) -> bool: # TODO: Actually check if this works. -def check_any(*checks: AppCheck) -> Callable[[T], T]: +def check_any[T](*checks: AppCheck) -> Callable[[T], T]: """An attempt at making a `check_any` decorator for application commands that checks if any of the checks passed will pass, i.e. using logical OR. - If all checks fail then :exc:`CheckAnyFailure` is raised to signal the failure. - It inherits from :exc:`app_commands.CheckFailure`. + If all checks fail then :exc:`CheckAnyFailure` is raised to signal the failure. It inherits from + `app_commands.CheckFailure`. Parameters ---------- checks: `AppCheckProtocol` - An argument list of checks that have been decorated with :func:`app_commands.check` decorator. + An argument list of checks that have been decorated with `app_commands.check` decorator. Raises ------ TypeError - A check passed has not been decorated with the :func:`app_commands.check` decorator. + A check passed has not been decorated with the `app_commands.check` decorator. """ unwrapped: list[AppCheckFunc] = [] diff --git a/core/config.py b/core/config.py index 6f0de31..2949648 100644 --- a/core/config.py +++ b/core/config.py @@ -1,6 +1,4 @@ -""" -config.py: Imports configuration information, such as api keys and tokens, default prefixes, etc. -""" +"""config.py: Imports configuration information, such as api keys and tokens, default prefixes, etc.""" import pathlib from typing import Any @@ -82,13 +80,4 @@ def decode(data: bytes | str) -> Config: return msgspec.toml.decode(data, type=Config) -def encode(msg: Config) -> bytes: - """Encode a ``Config`` object to TOML.""" - - return msgspec.toml.encode(msg) - - -with pathlib.Path("config.toml").open(encoding="utf-8") as f: - data = f.read() - -CONFIG = decode(data) +CONFIG = decode(pathlib.Path("config.toml").read_text(encoding="utf-8")) diff --git a/core/context.py b/core/context.py index b0330c4..bf6ef5d 100644 --- a/core/context.py +++ b/core/context.py @@ -1,10 +1,8 @@ -""" -context.py: For the custom context and interaction subclasses. Mainly used for type narrowing. -""" +"""context.py: For the custom context and interaction subclasses. Mainly used for type narrowing.""" from __future__ import annotations -from typing import TYPE_CHECKING, TypeAlias +from typing import TYPE_CHECKING import aiohttp import discord @@ -20,7 +18,7 @@ __all__ = ("Context", "GuildContext", "Interaction") -Interaction: TypeAlias = discord.Interaction["Beira"] +type Interaction = discord.Interaction[Beira] class Context(commands.Context["Beira"]): diff --git a/core/errors.py b/core/errors.py index 65ca7e6..5c04ee6 100644 --- a/core/errors.py +++ b/core/errors.py @@ -1,6 +1,4 @@ -""" -errors.py: Custom errors used by the bot. -""" +"""errors.py: Custom errors used by the bot.""" from discord import app_commands from discord.ext import commands diff --git a/core/tree.py b/core/tree.py index d9fe1d6..5ea38d8 100644 --- a/core/tree.py +++ b/core/tree.py @@ -1,10 +1,8 @@ -from __future__ import annotations - import asyncio import logging import traceback from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias +from typing import TYPE_CHECKING, Any import discord from discord import Client, Interaction @@ -22,19 +20,21 @@ ClientT_co = TypeVar("ClientT_co", bound=Client, covariant=True) -P = ParamSpec("P") -T = TypeVar("T") -Coro: TypeAlias = Coroutine[Any, Any, T] -CoroFunc: TypeAlias = Callable[..., Coro[Any]] -GroupT = TypeVar("GroupT", bound=Group | commands.Cog) -AppHook: TypeAlias = Callable[[GroupT, Interaction[Any]], Coro[Any]] | Callable[[Interaction[Any]], Coro[Any]] + +type Coro[T] = Coroutine[Any, Any, T] +type CoroFunc = Callable[..., Coro[Any]] +type AppHook[GroupT: (Group | commands.Cog)] = ( + Callable[[GroupT, Interaction[Any]], Coro[Any]] | Callable[[Interaction[Any]], Coro[Any]] +) LOGGER = logging.getLogger(__name__) __all__ = ("before_app_invoke", "after_app_invoke", "HookableTree") -def before_app_invoke(coro: AppHook[GroupT]) -> Callable[[Command[GroupT, P, T]], Command[GroupT, P, T]]: +def before_app_invoke[GroupT: (Group | commands.Cog), **P, T]( + coro: AppHook[GroupT], +) -> Callable[[Command[GroupT, P, T]], Command[GroupT, P, T]]: """A decorator that registers a coroutine as a pre-invoke hook. This allows you to refer to one before invoke hook for several commands that @@ -67,7 +67,9 @@ def decorator(inner: Command[GroupT, P, T]) -> Command[GroupT, P, T]: return decorator -def after_app_invoke(coro: AppHook[GroupT]) -> Callable[[Command[GroupT, P, T]], Command[GroupT, P, T]]: +def after_app_invoke[GroupT: (Group | commands.Cog), **P, T]( + coro: AppHook[GroupT], +) -> Callable[[Command[GroupT, P, T]], Command[GroupT, P, T]]: """A decorator that registers a coroutine as a post-invoke hook. This allows you to refer to one after invoke hook for several commands that @@ -132,7 +134,7 @@ async def on_error(self, interaction: Interaction[ClientT_co], error: AppCommand LOGGER.error("Exception in command tree", exc_info=error, extra={"embed": embed}) async def _call(self, interaction: Interaction[ClientT_co]) -> None: - ###### Copy the original logic but add hook checks/calls near the end. + # ---- Copy the original logic but add hook checks/calls near the end. if not await self.interaction_check(interaction): interaction.command_failed = True @@ -172,7 +174,7 @@ async def _call(self, interaction: Interaction[ClientT_co]) -> None: return - ### Look for a pre-command hook. + # -- Look for a pre-command hook. # Pre-command hooks are run before actual command-specific checks, unlike prefix commands. # It doesn't really make sense, but the only solution seems to be monkey-patching # Command._invoke_with_namespace, which doesn't seem feasible. @@ -194,7 +196,7 @@ async def _call(self, interaction: Interaction[ClientT_co]) -> None: if not interaction.command_failed: self.client.dispatch("app_command_completion", interaction, command) finally: - ### Look for a post-command hook. + # -- Look for a post-command hook. after_invoke = getattr(command, "_after_invoke", None) if after_invoke: instance = getattr(after_invoke, "__self__", None) diff --git a/core/utils/__init__.py b/core/utils/__init__.py index 307f8fe..8af78f3 100644 --- a/core/utils/__init__.py +++ b/core/utils/__init__.py @@ -1,6 +1,6 @@ -from .custom_logging import * from .db import * from .embeds import * from .emojis import * +from .log import * from .misc import * from .pagination import * diff --git a/core/utils/db.py b/core/utils/db.py index 41e2741..975f5af 100644 --- a/core/utils/db.py +++ b/core/utils/db.py @@ -1,29 +1,24 @@ -""" -db.py: Utility functions for interacting with the database. -""" +"""db.py: Utility functions for interacting with the database.""" -from __future__ import annotations +from typing import TYPE_CHECKING -from typing import TYPE_CHECKING, TypeAlias - -import discord import msgspec from asyncpg import Connection, Pool, Record from asyncpg.pool import PoolConnectionProxy -UserObject: TypeAlias = discord.abc.User | discord.Object | tuple[int, bool] -GuildObject: TypeAlias = discord.Guild | discord.Object | tuple[int, bool] - - -__all__ = ("Connection_alias", "Pool_alias", "conn_init", "upsert_users", "upsert_guilds") +__all__ = ( + "Connection_alias", + "Pool_alias", + "conn_init", +) if TYPE_CHECKING: - Connection_alias: TypeAlias = Connection[Record] | PoolConnectionProxy[Record] - Pool_alias: TypeAlias = Pool[Record] + type Connection_alias = Connection[Record] | PoolConnectionProxy[Record] + type Pool_alias = Pool[Record] else: - Connection_alias: TypeAlias = Connection | PoolConnectionProxy - Pool_alias: TypeAlias = Pool + type Connection_alias = Connection | PoolConnectionProxy + type Pool_alias = Pool async def conn_init(connection: Connection_alias) -> None: @@ -35,51 +30,3 @@ async def conn_init(connection: Connection_alias) -> None: encoder=msgspec.json.encode, decoder=msgspec.json.decode, ) - - -async def upsert_users(conn: Pool_alias | Connection_alias, *users: UserObject) -> None: - """Upsert a Discord user in the appropriate database table. - - Parameters - ---------- - conn: `Pool` | `Connection` - The connection pool used to interact to the database. - users: tuple[`discord.abc.User` | `discord.Object` | tuple] - One or more users, members, discord objects, or tuples of user ids and blocked statuses, to use for upsertion. - """ - - upsert_query = """ - INSERT INTO users (user_id, is_blocked) - VALUES ($1, $2) - ON CONFLICT(user_id) - DO UPDATE - SET is_blocked = EXCLUDED.is_blocked; - """ - - # Format the users as minimal tuples. - values = [(user.id, False) if not isinstance(user, tuple) else user for user in users] - await conn.executemany(upsert_query, values, timeout=60.0) - - -async def upsert_guilds(conn: Pool_alias | Connection_alias, *guilds: GuildObject) -> None: - """Upsert a Discord guild in the appropriate database table. - - Parameters - ---------- - conn: `Pool` | `Connection` - The connection pool used to interact to the database. - guilds: tuple[`discord.Guild` | `discord.Object` | tuple] - One or more guilds, discord objects, or tuples of guild ids, names, and blocked statuses, to use for upsertion. - """ - - upsert_query = """ - INSERT INTO guilds (guild_id, is_blocked) - VALUES ($1, $2) - ON CONFLICT (guild_id) - DO UPDATE - SET is_blocked = EXCLUDED.is_blocked; - """ - - # Format the guilds as minimal tuples. - values = [(guild.id, False) if not isinstance(guild, tuple) else guild for guild in guilds] - await conn.executemany(upsert_query, values, timeout=60.0) diff --git a/core/utils/embeds.py b/core/utils/embeds.py index 4666553..82abbaf 100644 --- a/core/utils/embeds.py +++ b/core/utils/embeds.py @@ -1,23 +1,21 @@ -""" -embeds.py: This class provides embeds for user-specific statistics separated into fields. -""" +"""embeds.py: This class provides embeds for user-specific statistics separated into fields.""" from __future__ import annotations import itertools import logging from collections.abc import Iterable, Sequence -from typing import Self, TypeAlias +from typing import Self import discord -AnyEmoji: TypeAlias = discord.Emoji | discord.PartialEmoji | str +LOGGER = logging.getLogger(__name__) +type AnyEmoji = discord.Emoji | discord.PartialEmoji | str -__all__ = ("StatsEmbed",) -LOGGER = logging.getLogger(__name__) +__all__ = ("StatsEmbed",) class StatsEmbed(discord.Embed): diff --git a/core/utils/emojis.py b/core/utils/emojis.py index 0da670b..9f272a9 100644 --- a/core/utils/emojis.py +++ b/core/utils/emojis.py @@ -1,6 +1,3 @@ -from __future__ import annotations - - __all__ = ( "EMOJI_STOCK", "EMOJI_URL", diff --git a/core/utils/custom_logging.py b/core/utils/log.py similarity index 82% rename from core/utils/custom_logging.py rename to core/utils/log.py index 3319c06..146296a 100644 --- a/core/utils/custom_logging.py +++ b/core/utils/log.py @@ -1,35 +1,26 @@ -""" -custom_logging.py: Based on the work of Umbra, this is Beira's logging system. +"""custom_logging.py: Based on the work of Umbra, this is Beira's logging system. References ---------- https://github.com/AbstractUmbra/Mipha/blob/main/bot.py#L91 """ -from __future__ import annotations - import asyncio import copy import logging from logging.handlers import RotatingFileHandler from pathlib import Path -from typing import TYPE_CHECKING, Any, Self +from typing import Self from discord.utils import _ColourFormatter as ColourFormatter, stream_supports_colour # type: ignore # Because color. -if TYPE_CHECKING: - from types import TracebackType -else: - TracebackType = object - - __all__ = ("LoggingManager",) class AsyncQueueHandler(logging.Handler): # Copied api and implementation of stdlib QueueHandler. - def __init__(self, queue: asyncio.Queue[Any]) -> None: + def __init__(self, queue: asyncio.Queue[logging.LogRecord]) -> None: logging.Handler.__init__(self) self.queue = queue @@ -68,9 +59,7 @@ def __init__(self) -> None: super().__init__(name="discord.state") def filter(self, record: logging.LogRecord) -> bool: - if record.levelname == "WARNING" and "referencing an unknown" in record.msg: - return False - return True + return not (record.levelname == "WARNING" and "referencing an unknown" in record.msg) # TODO: Personalize logging beyond Umbra's work. @@ -142,20 +131,10 @@ def __enter__(self) -> Self: return self - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - traceback: TracebackType | None, - ) -> None: - return self.__exit__(exc_type, exc_val, traceback) - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - traceback: TracebackType | None, - ) -> None: + async def __aexit__(self, *exc_info: object) -> None: + return self.__exit__(*exc_info) + + def __exit__(self, *exc_info: object) -> None: """Close and remove all logging handlers.""" handlers = self.log.handlers[:] diff --git a/core/utils/misc.py b/core/utils/misc.py index 0d8ba98..abe0ac2 100644 --- a/core/utils/misc.py +++ b/core/utils/misc.py @@ -1,8 +1,4 @@ -""" -misc.py: Miscellaneous utility functions that might come in handy. -""" - -from __future__ import annotations +"""misc.py: Miscellaneous utility functions that might come in handy.""" import logging import re @@ -32,7 +28,7 @@ def __enter__(self): self.total_time = time.perf_counter() return self - def __exit__(self, *exc: object) -> None: + def __exit__(self, *exc_info: object) -> None: self.total_time = time.perf_counter() - self.total_time if self.logger: self.logger.info("Time: %.3f seconds", self.total_time) diff --git a/core/utils/pagination.py b/core/utils/pagination.py index eef2926..c45c6a2 100644 --- a/core/utils/pagination.py +++ b/core/utils/pagination.py @@ -8,14 +8,11 @@ import asyncio from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, Generic, Self, TypeVar +from typing import Any, Self import discord -_LT = TypeVar("_LT") - - __all__ = ("QuitButton", "OwnedView", "PageSeekModal", "PaginatedEmbedView", "PaginatedSelectView") @@ -118,7 +115,7 @@ async def on_submit(self, interaction: discord.Interaction, /) -> None: self.stop() -class PaginatedEmbedView(ABC, Generic[_LT], OwnedView): +class PaginatedEmbedView[_LT](ABC, OwnedView): """A view that handles paginated embeds and page buttons. Parameters @@ -294,7 +291,7 @@ async def turn_to_last(self, interaction: discord.Interaction, _: discord.ui.But await self.update_page(interaction) -class PaginatedSelectView(ABC, Generic[_LT], OwnedView): +class PaginatedSelectView[_LT](ABC, OwnedView): """A view that handles paginated embeds and page buttons. Parameters diff --git a/exts/__init__.py b/exts/__init__.py index a125350..dd659ee 100644 --- a/exts/__init__.py +++ b/exts/__init__.py @@ -1,4 +1,4 @@ from pkgutil import iter_modules -EXTENSIONS = [module.name for module in iter_modules(__path__, f"{__package__}.")] +EXTENSIONS = [mod_info.name for mod_info in iter_modules(__path__, f"{__package__}.")] diff --git a/exts/_dev/__init__.py b/exts/_dev/__init__.py index 8cc50f8..6e65bc1 100644 --- a/exts/_dev/__init__.py +++ b/exts/_dev/__init__.py @@ -1,12 +1,10 @@ -from __future__ import annotations - -from core import Beira +import core from ._dev import DevCog from ._test import TestCog -async def setup(bot: Beira) -> None: +async def setup(bot: core.Beira) -> None: """Connects cog to bot.""" # Can't use the guilds kwarg, as it doesn't currently work for hybrids. It would look like this: diff --git a/exts/_dev/_dev.py b/exts/_dev/_dev.py index d087d91..02e5754 100644 --- a/exts/_dev/_dev.py +++ b/exts/_dev/_dev.py @@ -1,12 +1,12 @@ +"""_dev.py: A cog that implements commands for reloading and syncing extensions and other commands, at the owner's +behest. """ -_dev.py: A cog that implements commands for reloading and syncing extensions and other commands, at the owner's behest. -""" - -from __future__ import annotations +import contextlib import logging +from collections.abc import Generator from time import perf_counter -from typing import Literal +from typing import Any, Literal import discord from asyncpg.exceptions import PostgresConnectionError, PostgresError @@ -14,7 +14,6 @@ from discord.ext import commands import core -from core.utils import upsert_guilds, upsert_users from exts import EXTENSIONS @@ -142,11 +141,25 @@ async def block_add( # Regardless of block type, update the database, update the cache, and create an informational embed. if block_type == "users": - await upsert_users(ctx.db, *((user.id, True) for user in entities)) + stmt = """\ + INSERT INTO users (user_id, is_blocked) + VALUES ($1, $2) + ON CONFLICT(user_id) + DO UPDATE + SET is_blocked = EXCLUDED.is_blocked; + """ + await ctx.db.executemany(stmt, [(user.id, True) for user in entities]) self.bot.blocked_entities_cache["users"].update(user.id for user in entities) embed = discord.Embed(title="Users", description="\n".join(str(user) for user in entities)) else: - await upsert_guilds(ctx.db, *((guild.id, True) for guild in entities)) + stmt = """\ + INSERT INTO guilds (guild_id, is_blocked) + VALUES ($1, $2) + ON CONFLICT (guild_id) + DO UPDATE + SET is_blocked = EXCLUDED.is_blocked; + """ + await ctx.db.executemany(stmt, [(guild.id, True) for guild in entities]) self.bot.blocked_entities_cache["guilds"].update(guild.id for guild in entities) embed = discord.Embed(title="Guilds", description="\n".join(str(guild) for guild in entities)) @@ -176,11 +189,25 @@ async def block_remove( # Regardless of block type, update the database, update the cache, and create an informational embed. if block_type == "users": - await upsert_users(ctx.db, *((user.id, False) for user in entities)) + stmt = """\ + INSERT INTO users (user_id, is_blocked) + VALUES ($1, $2) + ON CONFLICT(user_id) + DO UPDATE + SET is_blocked = EXCLUDED.is_blocked; + """ + await ctx.db.executemany(stmt, [(user.id, False) for user in entities]) self.bot.blocked_entities_cache["users"].difference_update(user.id for user in entities) embed = discord.Embed(title="Users", description="\n".join(str(user) for user in entities)) else: - await upsert_guilds(ctx.db, *((guild.id, False) for guild in entities)) + stmt = """\ + INSERT INTO guilds (guild_id, is_blocked) + VALUES ($1, $2) + ON CONFLICT (guild_id) + DO UPDATE + SET is_blocked = EXCLUDED.is_blocked; + """ + await ctx.db.executemany(stmt, [(guild.id, False) for guild in entities]) self.bot.blocked_entities_cache["guilds"].difference_update(guild.id for guild in entities) embed = discord.Embed(title="Guilds", description="\n".join(str(guild) for guild in entities)) @@ -208,7 +235,14 @@ async def context_menu_block_add( interaction: core.Interaction, user: discord.User | discord.Member, ) -> None: - await upsert_users(interaction.client.db_pool, (user.id, True)) + stmt = """ + INSERT INTO users (user_id, is_blocked) + VALUES ($1, $2) + ON CONFLICT(user_id) + DO UPDATE + SET is_blocked = EXCLUDED.is_blocked; + """ + await self.bot.db_pool.execute(stmt, user.id, True) self.bot.blocked_entities_cache["users"].update((user.id,)) # Display the results. @@ -221,7 +255,14 @@ async def context_menu_block_remove( interaction: core.Interaction, user: discord.User | discord.Member, ) -> None: - await upsert_users(interaction.client.db_pool, (user.id, False)) + stmt = """ + INSERT INTO users (user_id, is_blocked) + VALUES ($1, $2) + ON CONFLICT(user_id) + DO UPDATE + SET is_blocked = EXCLUDED.is_blocked; + """ + await self.bot.db_pool.execute(stmt, user.id, False) self.bot.blocked_entities_cache["users"].difference_update((user.id,)) # Display the results. @@ -371,14 +412,14 @@ async def reload(self, ctx: core.Context, extension: str) -> None: failed: list[str] = [] start_time = perf_counter() - for extension in sorted(self.bot.extensions): + for ext in sorted(self.bot.extensions): try: - await self.bot.reload_extension(extension) + await self.bot.reload_extension(ext) except commands.ExtensionError as err: - failed.append(extension) - LOGGER.exception("Couldn't reload extension: %s", extension, exc_info=err) + failed.append(ext) + LOGGER.exception("Couldn't reload extension: %s", ext, exc_info=err) else: - reloaded.append(extension) + reloaded.append(ext) end_time = perf_counter() ratio_succeeded = f"{len(reloaded)}/{len(self.bot.extensions)}" @@ -530,3 +571,28 @@ async def sync_error(self, ctx: core.Context, error: commands.CommandError) -> N LOGGER.exception("Unknown error in sync command", exc_info=error) await ctx.reply(embed=embed) + + @commands.hybrid_command() + async def cmd_tree(self, ctx: core.Context) -> None: + indent_level = 0 + + @contextlib.contextmanager + def new_indent(num: int = 4) -> Generator[None, object, None]: + nonlocal indent_level + indent_level += num + try: + yield + finally: + indent_level -= num + + def walk_commands_with_indent(group: commands.GroupMixin[Any]) -> Generator[str, object, None]: + for cmd in group.commands: + indent = "" if (indent_level == 0) else (indent_level - 1) * "─" + yield f"└{indent}{cmd.qualified_name}" + + if isinstance(cmd, commands.GroupMixin): + with new_indent(): + yield from walk_commands_with_indent(cmd) + + result = "\n".join(["Beira", *walk_commands_with_indent(ctx.bot)]) + await ctx.send(f"```\n{result}\n```") diff --git a/exts/admin.py b/exts/admin.py index 2addc1f..778a774 100644 --- a/exts/admin.py +++ b/exts/admin.py @@ -1,9 +1,6 @@ +"""admin.py: A cog that implements commands for reloading and syncing extensions and other commands, at a guild owner +or bot owner's behest. """ -admin.py: A cog that implements commands for reloading and syncing extensions and other commands, at a guild owner or -bot owner's behest. -""" - -from __future__ import annotations import logging @@ -82,15 +79,15 @@ async def prefixes_add(self, ctx: core.GuildContext, *, new_prefix: str) -> None if new_prefix in local_prefixes: await ctx.send("You already registered this prefix.") else: - guild_query = """INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT DO NOTHING;""" - prefix_query = """ + guild_stmt = "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT DO NOTHING;" + prefix_stmt = """\ INSERT INTO guild_prefixes (guild_id, prefix) VALUES ($1, $2) ON CONFLICT (guild_id, prefix) DO NOTHING; """ async with self.bot.db_pool.acquire() as conn, conn.transaction(): - await conn.execute(guild_query, ctx.guild.id) - await conn.execute(prefix_query, ctx.guild.id, new_prefix) + await conn.execute(guild_stmt, ctx.guild.id) + await conn.execute(prefix_stmt, ctx.guild.id, new_prefix) # Update it in the cache. self.bot.prefix_cache.setdefault(ctx.guild.id, []).append(new_prefix) @@ -116,12 +113,9 @@ async def prefixes_remove(self, ctx: core.GuildContext, *, old_prefix: str) -> N if old_prefix not in local_prefixes: await ctx.send("This prefix was never registered in this guild or has already been unregistered.") else: - prefix_query = """DELETE FROM guild_prefixes WHERE guild_id = $1 AND prefix = $2;""" - - # Update it in the database. - await self.bot.db_pool.execute(prefix_query, ctx.guild.id, old_prefix) - - # Update it in the cache. + # Update it in the database and the cache. + prefix_stmt = "DELETE FROM guild_prefixes WHERE guild_id = $1 AND prefix = $2;" + await self.bot.db_pool.execute(prefix_stmt, ctx.guild.id, old_prefix) self.bot.prefix_cache.setdefault(ctx.guild.id, [old_prefix]).remove(old_prefix) await ctx.send(f"'{old_prefix}' has been unregistered as a prefix in this guild.") @@ -139,16 +133,13 @@ async def prefixes_reset(self, ctx: core.GuildContext) -> None: """ async with ctx.typing(): - prefix_query = """DELETE FROM guild_prefixes WHERE guild_id = $1;""" - - # Update it in the database. - await self.bot.db_pool.execute(prefix_query, ctx.guild.id) - # Update it in the cache. + # Update it in the database and the cache. + prefix_stmt = """DELETE FROM guild_prefixes WHERE guild_id = $1;""" + await self.bot.db_pool.execute(prefix_stmt, ctx.guild.id) self.bot.prefix_cache.setdefault(ctx.guild.id, []).clear() - await ctx.send( - "The prefix(es) for this guild have been reset. Now only accepting the default prefix: `$`.", - ) + content = "The prefix(es) for this guild have been reset. Now only accepting the default prefix: `$`." + await ctx.send(content) @prefixes_add.error @prefixes_remove.error diff --git a/exts/bot_stats.py b/exts/bot_stats.py index b4749ba..31443ff 100644 --- a/exts/bot_stats.py +++ b/exts/bot_stats.py @@ -1,25 +1,16 @@ -""" -bot_stats.py: A cog for tracking different bot metrics. -""" - -from __future__ import annotations +"""bot_stats.py: A cog for tracking different bot metrics.""" import logging from datetime import timedelta -from typing import TYPE_CHECKING, Literal +from typing import Literal +import asyncpg import discord from discord.app_commands import Choice from discord.ext import commands import core -from core.utils import StatsEmbed, upsert_guilds, upsert_users - - -if TYPE_CHECKING: - from asyncpg import Record -else: - Record = object +from core.utils import StatsEmbed LOGGER = logging.getLogger(__name__) @@ -68,6 +59,8 @@ async def track_command_use(self, ctx: core.Context) -> None: assert ctx.command is not None + db = self.bot.db_pool + # Make sure all possible involved users and guilds are in the database before using their ids as foreign keys. user_info = [ctx.author] guild_info = [ctx.guild] if ctx.guild else [] @@ -79,9 +72,11 @@ async def track_command_use(self, ctx: core.Context) -> None: guild_info.append(arg) if user_info: - await upsert_users(self.bot.db_pool, *user_info) + stmt = "INSERT INTO users (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING;" + await db.executemany(stmt, [user.id for user in user_info], timeout=60.0) if guild_info: - await upsert_guilds(self.bot.db_pool, *guild_info) + stmt = "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;" + await db.executemany(stmt, [guild.id for guild in guild_info], timeout=60.0) # Assemble the record to upsert. cmd = ( @@ -95,11 +90,11 @@ async def track_command_use(self, ctx: core.Context) -> None: ctx.command_failed, ) - query = """ + stmt = """\ INSERT into commands (guild_id, channel_id, user_id, date_time, prefix, command, app_command, failed) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8); """ - await self.bot.db_pool.execute(query, *cmd, timeout=60.0) + await db.execute(stmt, *cmd, timeout=60.0) @commands.Cog.listener("on_command_completion") async def track_command_completion(self, ctx: core.Context) -> None: @@ -136,7 +131,8 @@ async def track_command_error(self, ctx: core.Context, error: commands.CommandEr async def add_guild_to_db(self, guild: discord.Guild) -> None: """Upserts a guild - one that the bot just joined - to the database.""" - await upsert_guilds(self.bot.db_pool, guild) + stmt = "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;" + await self.bot.db_pool.execute(stmt, guild.id, timeout=60.0) @commands.hybrid_command(name="usage") async def check_usage(self, ctx: core.Context, *, search_factors: CommandStatsSearchFlags) -> None: @@ -166,8 +162,12 @@ async def check_usage(self, ctx: core.Context, *, search_factors: CommandStatsSe record_tuples = (((get_strat(record[0]) or record[0]), record[1]) for record in records) - ldbd_emojis = ["\N{FIRST PLACE MEDAL}", "\N{SECOND PLACE MEDAL}", "\N{THIRD PLACE MEDAL}"] - ldbd_emojis.extend("\N{SPORTS MEDAL}" for _ in range(6)) + ldbd_emojis = [ + "\N{FIRST PLACE MEDAL}", + "\N{SECOND PLACE MEDAL}", + "\N{THIRD PLACE MEDAL}", + *("\N{SPORTS MEDAL}" for _ in range(6)), + ] embed.add_leaderboard_fields(ldbd_content=record_tuples, ldbd_emojis=ldbd_emojis) else: @@ -181,7 +181,7 @@ async def get_usage( command: str | None = None, guild: discord.Guild | None = None, universal: bool = False, - ) -> list[Record]: + ) -> list[asyncpg.Record]: """Queries the database for command usage.""" query_args: list[object] = [] # Holds the query args as objects. @@ -189,24 +189,24 @@ async def get_usage( # Create the base queries. if guild: - query = """ - SELECT u.user_id, COUNT(*) - FROM commands cmds INNER JOIN users u on cmds.user_id = u.user_id - {where} - GROUP BY u.user_id - ORDER BY COUNT(*) DESC - LIMIT 10; - """ + query = """\ +SELECT u.user_id, COUNT(*) +FROM commands cmds INNER JOIN users u on cmds.user_id = u.user_id +{where} +GROUP BY u.user_id +ORDER BY COUNT(*) DESC +LIMIT 10; +""" else: - query = """ - SELECT g.guild_id, COUNT(*) - FROM commands cmds INNER JOIN guilds g on cmds.guild_id = g.guild_id - {where} - GROUP BY g.guild_id - ORDER BY COUNT(*) DESC - LIMIT 10; - """ + query = """\ +SELECT g.guild_id, COUNT(*) +FROM commands cmds INNER JOIN guilds g on cmds.guild_id = g.guild_id +{where} +GROUP BY g.guild_id +ORDER BY COUNT(*) DESC +LIMIT 10; +""" # Create the WHERE clause for the query. if guild and not universal: diff --git a/exts/dice.py b/exts/dice.py index 8b696fe..b5fb36a 100644 --- a/exts/dice.py +++ b/exts/dice.py @@ -1,10 +1,5 @@ -""" -dice.py: The extension that holds a die roll command and all the associated utility classes. - -TODO: Consider adding more elements from https://wiki.roll20.net/Dice_Reference. -""" - -from __future__ import annotations +"""dice.py: The extension that holds a die roll command and all the associated utility classes.""" +# TODO: Consider adding more elements from https://wiki.roll20.net/Dice_Reference. import logging import operator @@ -38,8 +33,6 @@ class Die(msgspec.Struct, frozen=True): The emoji representing the die in displays. color: `discord.Colour` The color representing the die in embed displays. - label: `str`, default=f"D{value}" - The label, or name, of the die. Defaults to ``D{value}``, as with most dice in casual discussion. """ value: int @@ -48,6 +41,8 @@ class Die(msgspec.Struct, frozen=True): @property def label(self) -> str: + """`str`: The label, or name, of the die. Defaults to ``D{value}``, as with most dice in casual discussion.""" + return f"D{self.value}" diff --git a/exts/emoji_ops.py b/exts/emoji_ops.py index 25e3944..6b6a750 100644 --- a/exts/emoji_ops.py +++ b/exts/emoji_ops.py @@ -1,11 +1,8 @@ -""" -emoji_ops.py: This cog is meant to provide functionality for stealing emojis. +"""emoji_ops.py: This cog is meant to provide functionality for stealing emojis. Credit to Froopy and Danny for inspiration from their bots. """ -from __future__ import annotations - import asyncio import logging import re diff --git a/exts/fandom_wiki.py b/exts/fandom_wiki.py index ce01922..8909227 100644 --- a/exts/fandom_wiki.py +++ b/exts/fandom_wiki.py @@ -1,10 +1,7 @@ -""" -fandom_wiki.py: A cog for searching a fandom's Fandom wiki page. Starting with characters from the ACI100 wiki +"""fandom_wiki.py: A cog for searching a fandom's Fandom wiki page. Starting with characters from the ACI100 wiki first. """ -from __future__ import annotations - import asyncio import logging import textwrap @@ -13,9 +10,10 @@ import aiohttp import discord -from discord.app_commands import Choice +import lxml.etree +import lxml.html +from discord import app_commands from discord.ext import commands -from lxml import etree, html import core from core.utils import EMOJI_URL, html_to_markdown @@ -71,7 +69,7 @@ async def load_wiki_all_pages(session: aiohttp.ClientSession, wiki_url: str) -> while True: async with session.get(next_path) as response: text = await response.text() - element = html.fromstring(text) + element = lxml.html.fromstring(text) pages_dict.update( { el.attrib["title"]: urljoin(wiki_url, el.attrib["href"]) @@ -86,7 +84,7 @@ async def load_wiki_all_pages(session: aiohttp.ClientSession, wiki_url: str) -> return pages_dict -def clean_fandom_page(element: etree._Element) -> etree._Element: # type: ignore [reportPrivateUsage] +def clean_fandom_page(element: lxml.etree._Element) -> lxml.etree._Element: # pyright: ignore [reportPrivateUsage] """Attempts to clean a Fandom wiki page. Removes everything from a Fandom wiki page that isn't the first few lines, if possible. @@ -97,7 +95,7 @@ def clean_fandom_page(element: etree._Element) -> etree._Element: # type: ignor # Clean the content. infoboxes = element.findall(".//aside[@class='portable-infobox']") for box in infoboxes: - box.getparent().remove(box) # type: ignore [reportOptionalMemberAccess] + box.getparent().remove(box) # pyright: ignore [reportOptionalMemberAccess] toc = element.find(".//div[@id='toc']") if toc is not None: @@ -108,7 +106,7 @@ def clean_fandom_page(element: etree._Element) -> etree._Element: # type: ignor else: if index > summary_end_index: summary_end_index = index - toc.getparent().remove(toc) # type: ignore [reportOptionalMemberAccess] + toc.getparent().remove(toc) # pyright: ignore [reportOptionalMemberAccess] subheading = element.find(".//h2") if subheading is not None: @@ -119,15 +117,15 @@ def clean_fandom_page(element: etree._Element) -> etree._Element: # type: ignor else: if index > summary_end_index: summary_end_index = index - subheading.getparent().remove(subheading) # type: ignore [reportOptionalMemberAccess] + subheading.getparent().remove(subheading) # pyright: ignore [reportOptionalMemberAccess] if summary_end_index != 0: for el in list(element[summary_end_index + 1 :]): - el.getparent().remove(el) # type: ignore [reportOptionalMemberAccess] + el.getparent().remove(el) # pyright: ignore [reportOptionalMemberAccess] for el in list(element): if el.text and el.text == "\n": - el.getparent().remove(el) # type: ignore [reportOptionalMemberAccess] + el.getparent().remove(el) # pyright: ignore [reportOptionalMemberAccess] return element @@ -139,7 +137,7 @@ async def process_fandom_page(session: aiohttp.ClientSession, url: str) -> tuple char_summary, char_thumbnail = None, None # Extract the main content. - element = html.fromstring(await response.text()) + element = lxml.html.fromstring(await response.text()) content = element.find(".//div[@class='mw-parser-output']") if content is not None: # Extract the image. @@ -235,14 +233,20 @@ async def wiki(self, ctx: core.Context, wiki: str, search_term: str) -> None: await ctx.send(embed=embed) @wiki.autocomplete("wiki") - async def wiki_autocomplete(self, _: core.Interaction, current: str) -> list[Choice[str]]: + async def wiki_autocomplete(self, _: core.Interaction, current: str) -> list[app_commands.Choice[str]]: """Autocomplete callback for the names of different wikis.""" options = self.all_wikis.keys() - return [Choice(name=name, value=name) for name in options if current.casefold() in name.casefold()][:25] + return [ + app_commands.Choice(name=name, value=name) for name in options if current.casefold() in name.casefold() + ][:25] @wiki.autocomplete("search_term") - async def wiki_search_term_autocomplete(self, interaction: core.Interaction, current: str) -> list[Choice[str]]: + async def wiki_search_term_autocomplete( + self, + interaction: core.Interaction, + current: str, + ) -> list[app_commands.Choice[str]]: """Autocomplete callback for the names of different wiki pages. Defaults to searching through the AoC wiki if the given wiki name is invalid. @@ -253,7 +257,9 @@ async def wiki_search_term_autocomplete(self, interaction: core.Interaction, cur wiki = "Harry Potter and the Ashes of Chaos" options = self.all_wikis[wiki] - return [Choice(name=name, value=name) for name in options if current.casefold() in name.casefold()][:25] + return [ + app_commands.Choice(name=name, value=name) for name in options if current.casefold() in name.casefold() + ][:25] async def search_wiki(self, wiki_name: str, wiki_query: str) -> discord.Embed: """Search a Fandom wiki for different pages. diff --git a/exts/ff_metadata/__init__.py b/exts/ff_metadata/__init__.py index 0e2a034..b12d56e 100644 --- a/exts/ff_metadata/__init__.py +++ b/exts/ff_metadata/__init__.py @@ -1,15 +1,9 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING +import core from .ff_metadata import FFMetadataCog -if TYPE_CHECKING: - from core import Beira - - -async def setup(bot: Beira) -> None: +async def setup(bot: core.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(FFMetadataCog(bot)) diff --git a/exts/ff_metadata/ff_metadata.py b/exts/ff_metadata/ff_metadata.py index a0810ef..e37bf2e 100644 --- a/exts/ff_metadata/ff_metadata.py +++ b/exts/ff_metadata/ff_metadata.py @@ -1,15 +1,10 @@ -""" -ff_metadata.py: A cog with triggers for retrieving story metadata. - -TODO: Account for orphaned fics, anonymous fics, really long embed descriptions, and series with more than 25 fics. -""" - -from __future__ import annotations +"""ff_metadata.py: A cog with triggers for retrieving story metadata.""" +# TODO: Account for orphaned fics, anonymous fics, really long embed descriptions, and series with more than 25 fics. import logging import re from collections.abc import AsyncGenerator -from typing import Literal, TypeAlias +from typing import Literal import ao3 import atlas_api @@ -27,7 +22,7 @@ ) -StoryDataType: TypeAlias = atlas_api.Story | fichub_api.Story | ao3.Work | ao3.Series +type StoryDataType = atlas_api.Story | fichub_api.Story | ao3.Work | ao3.Series LOGGER = logging.getLogger(__name__) @@ -54,8 +49,7 @@ def cog_emoji(self) -> discord.PartialEmoji: async def cog_load(self) -> None: # FIXME: Setup logging into AO3 via ao3.py. # Load a cache of channels to auto-respond in. - query = """SELECT guild_id, channel_id FROM fanfic_autoresponse_settings;""" - records = await self.bot.db_pool.fetch(query) + records = await self.bot.db_pool.fetch("SELECT guild_id, channel_id FROM fanfic_autoresponse_settings;") for record in records: self.allowed_channels_cache.setdefault(record["guild_id"], set()).add(record["channel_id"]) @@ -143,17 +137,17 @@ async def autoresponse_add( A list of channels to add, separated by spaces. """ - command = """ - INSERT INTO fanfic_autoresponse_settings (guild_id, channel_id) - VALUES ($1, $2) - ON CONFLICT (guild_id, channel_id) DO NOTHING; - """ - query = """SELECT channel_id FROM fanfic_autoresponse_settings WHERE guild_id = $1;""" - async with ctx.typing(): # Update the database. async with self.bot.db_pool.acquire() as conn: - await conn.executemany(command, [(ctx.guild.id, channel.id) for channel in channels]) + stmt = """\ + INSERT INTO fanfic_autoresponse_settings (guild_id, channel_id) + VALUES ($1, $2) + ON CONFLICT (guild_id, channel_id) DO NOTHING; + """ + await conn.executemany(stmt, [(ctx.guild.id, channel.id) for channel in channels]) + + query = "SELECT channel_id FROM fanfic_autoresponse_settings WHERE guild_id = $1;" records = await conn.fetch(query, ctx.guild.id) # Update the cache. @@ -183,13 +177,13 @@ async def autoresponse_remove( A list of channels to remove, separated by spaces. """ - command = """DELETE FROM fanfic_autoresponse_settings WHERE channel_id = $1;""" - query = """SELECT channel_id FROM fanfic_autoresponse_settings WHERE guild_id = $1;""" - async with ctx.typing(): # Update the database. async with self.bot.db_pool.acquire() as con: - await con.executemany(command, [(channel.id,) for channel in channels]) + stmt = "DELETE FROM fanfic_autoresponse_settings WHERE channel_id = $1;" + await con.executemany(stmt, [(channel.id,) for channel in channels]) + + query = "SELECT channel_id FROM fanfic_autoresponse_settings WHERE guild_id = $1;" records = await con.fetch(query, ctx.guild.id) # Update the cache. diff --git a/exts/ff_metadata/utils.py b/exts/ff_metadata/utils.py index 2097fd7..9f29598 100644 --- a/exts/ff_metadata/utils.py +++ b/exts/ff_metadata/utils.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import re import textwrap from typing import Any, NamedTuple diff --git a/exts/help.py b/exts/help.py index 362882c..d3167b9 100644 --- a/exts/help.py +++ b/exts/help.py @@ -1,13 +1,10 @@ -""" -help.py: A custom help command for Beira set through a cog. +"""help.py: A custom help command for Beira set through a cog. Notes ----- The guide this was based off of: https://gist.github.com/InterStella0/b78488fb28cadf279dfd3164b9f0cf96 """ -from __future__ import annotations - import logging import re from collections.abc import Mapping diff --git a/exts/lol.py b/exts/lol.py index 80eb222..4a8c00f 100644 --- a/exts/lol.py +++ b/exts/lol.py @@ -1,11 +1,8 @@ -""" -lol.py: A cog for checking user win rates and other stats in League of Legends. +"""lol.py: A cog for checking user win rates and other stats in League of Legends. Credit to Ralph for the idea and initial implementation. """ -from __future__ import annotations - import asyncio import itertools import logging @@ -15,9 +12,9 @@ import aiohttp import discord +import lxml.html from arsenic import browsers, errors, get_session, services # type: ignore # Third-party lib typing. from discord.ext import commands -from lxml import html import core from core.utils import StatsEmbed @@ -59,7 +56,7 @@ async def update_op_gg_profiles(urls: list[str]) -> None: class UpdateOPGGView(discord.ui.View): """A small view that adds an update button for OP.GG stats.""" - def __init__(self, author_id: int, cog: LoLCog, summoner_name_list: list[str], **kwargs: Any) -> None: + def __init__(self, author_id: int, cog: "LoLCog", summoner_name_list: list[str], **kwargs: Any) -> None: super().__init__(**kwargs) self.author_id = author_id self.cog = cog @@ -256,7 +253,7 @@ async def check_lol_stats(self, summoner_name: str) -> tuple[str, str, str]: content = await response.text() # Parse the summoner information for winrate and tier (referred to later as rank). - tree = html.fromstring(content) + tree = lxml.html.fromstring(content) winrate = str(tree.xpath("//div[@class='ratio']/string()")).removeprefix("Win Rate") rank = str(tree.xpath("//div[@class='tier']/string()")).capitalize() if not (winrate and rank): diff --git a/exts/misc.py b/exts/misc.py index e2f3da0..2f3e9ed 100644 --- a/exts/misc.py +++ b/exts/misc.py @@ -1,10 +1,4 @@ -""" -misc.py: A cog for testing slash and hybrid command functionality. - -Side note: This is the cog with the ``ping`` command. -""" - -from __future__ import annotations +"""misc.py: A cog for testing slash and hybrid command functionality.""" import asyncio import colorsys @@ -16,7 +10,6 @@ import time from io import BytesIO, StringIO -import aiohttp import discord import openpyxl import openpyxl.styles @@ -128,25 +121,6 @@ def process_color_data(role_data: list[tuple[str, discord.Colour]]) -> BytesIO: return BytesIO(tmp.read()) -async def create_inspiration(session: aiohttp.ClientSession) -> str: - """Makes a call to InspiroBot's API to generate an inspirational poster. - - Parameters - ---------- - session: `aiohttp.ClientSession` - The web session used to access the API. - - Returns - ------- - `str` - The url for the generated poster. - """ - - async with session.get(url=INSPIROBOT_API_URL, params={"generate": "true"}) as response: - response.raise_for_status() - return await response.text() - - class MiscCog(commands.Cog, name="Misc"): """A cog with some basic commands, originally used for testing slash and hybrid command functionality.""" @@ -252,7 +226,7 @@ async def ping_(self, ctx: core.Context) -> None: typing_ping = (time.perf_counter() - start_time) * 1000 start_time = time.perf_counter() - await self.bot.db_pool.fetch("""SELECT * FROM guilds;""") + await self.bot.db_pool.fetch("SELECT * FROM guilds;") db_ping = (time.perf_counter() - start_time) * 1000 start_time = time.perf_counter() @@ -294,6 +268,16 @@ async def meowify(self, ctx: core.Context, *, text: str) -> None: @commands.guild_only() @commands.hybrid_command() async def role_excel(self, ctx: core.GuildContext, by_color: bool = False) -> None: + """Get a spreadsheet with a guild's roles, optionally sorted by color. + + Parameters + ---------- + ctx: `core.GuildContext` + The invocation context, restricted to a guild. + by_color: `bool`, default=False + Whether the roles should be sorted by color. If False, sorts by name. Default is False. + """ + def color_key(item: tuple[str, discord.Colour]) -> tuple[int, int, int]: r, g, b = item[1].to_rgb() return color_step(r, g, b, 8) @@ -312,7 +296,11 @@ async def inspire_me(self, ctx: core.Context) -> None: """Generate a random inspirational poster with InspiroBot.""" async with ctx.typing(): - image_url = await create_inspiration(ctx.session) + # Make a call to InspiroBot's API to generate an inspirational poster. + async with ctx.session.get(url=INSPIROBOT_API_URL, params={"generate": "true"}) as response: + response.raise_for_status() + image_url = await response.text() + embed = ( discord.Embed(color=0xE04206) .set_image(url=image_url) diff --git a/exts/music.py b/exts/music.py index 78b0af7..bb586db 100644 --- a/exts/music.py +++ b/exts/music.py @@ -1,9 +1,6 @@ +"""music.py: This cog provides functionality for playing tracks in voice channels given search terms or urls, +implemented with Wavelink. """ -music.py: This cog provides functionality for playing tracks in voice channels given search terms or urls, implemented -with Wavelink. -""" - -from __future__ import annotations import datetime import functools diff --git a/exts/notifications/__init__.py b/exts/notifications/__init__.py index 8d495eb..934c5ed 100644 --- a/exts/notifications/__init__.py +++ b/exts/notifications/__init__.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import core from .aci_notifications import make_listeners as make_aci_listeners diff --git a/exts/notifications/aci_notifications.py b/exts/notifications/aci_notifications.py index 1831900..20565fd 100644 --- a/exts/notifications/aci_notifications.py +++ b/exts/notifications/aci_notifications.py @@ -1,14 +1,10 @@ -""" -custom_notifications.py: One or more listenerrs for sending custom notifications based on events. -""" - -from __future__ import annotations +"""custom_notifications.py: One or more listenerrs for sending custom notifications based on events.""" import functools import logging import re from collections.abc import Callable -from typing import Any, TypeAlias +from typing import Any import discord from discord import CategoryChannel, ForumChannel, StageChannel, TextChannel, VoiceChannel @@ -16,11 +12,11 @@ import core -ValidGuildChannel: TypeAlias = VoiceChannel | StageChannel | ForumChannel | TextChannel | CategoryChannel +type ValidGuildChannel = VoiceChannel | StageChannel | ForumChannel | TextChannel | CategoryChannel LOGGER = logging.getLogger(__name__) -# 799077440139034654 would be the actional channel should this go into "production" +# 799077440139034654 would be the actual channel should the delete hooks go into "production". ACI_DELETE_CHANNEL = 975459460560605204 # A list of ids for Tatsu leveled roles to keep track of. @@ -35,14 +31,13 @@ # The mod role(s) to ping when sending notifications. ACI_MOD_ROLE = 780904973004570654 -aci_guild_id = core.CONFIG.discord.important_guilds["prod"][0] +ACI_GUILD_ID = core.CONFIG.discord.important_guilds["prod"][0] LEAKY_INSTAGRAM_LINK_PATTERN = re.compile(r"(instagram\.com/.*?)&igsh.*==") async def on_server_boost_role_member_update( - bot: core.Beira, - role_log_wbhk: discord.Webhook, + log_webhook: discord.Webhook, before: discord.Member, after: discord.Member, ) -> None: @@ -54,18 +49,17 @@ async def on_server_boost_role_member_update( # Check if the update is in the right server, a member got new roles, and they got a new "Server Booster" role. if ( - before.guild.id == aci_guild_id + before.guild.id == ACI_GUILD_ID and len(new_roles := set(after.roles).difference(before.roles)) > 0 and after.guild.premium_subscriber_role in new_roles ): # Send a message notifying holders of some other role(s) about this new role acquisition. content = f"<@&{ACI_MOD_ROLE}>, {after.mention} just boosted the server!" - await role_log_wbhk.send(content) + await log_webhook.send(content) async def on_leveled_role_member_update( - bot: core.Beira, - role_log_wbhk: discord.Webhook, + log_webhook: discord.Webhook, before: discord.Member, after: discord.Member, ) -> None: @@ -77,7 +71,7 @@ async def on_leveled_role_member_update( # Check if the update is in the right server, a member got new roles, and they got a relevant leveled role. if ( - before.guild.id == aci_guild_id + before.guild.id == ACI_GUILD_ID and len(new_roles := set(after.roles).difference(before.roles)) > 0 and (new_leveled_roles := tuple(role for role in new_roles if (role.id in ACI_LEVELED_ROLES))) ): @@ -93,11 +87,11 @@ async def on_leveled_role_member_update( # Send a message notifying holders of some other role(s) about this new role acquisition. role_names = tuple(role.name for role in new_leveled_roles) content = f"<@&{ACI_MOD_ROLE}>, {after.mention} was given the `{role_names}` role(s)." - await role_log_wbhk.send(content) + await log_webhook.send(content) async def on_bad_twitter_link(bot: core.Beira, message: discord.Message) -> None: - if message.author == bot.user or (not message.guild or message.guild.id != aci_guild_id): + if message.author == bot.user or (not message.guild or message.guild.id != ACI_GUILD_ID): return if links := re.findall(r"(?:http(?:s)?://|(? None async def on_leaky_instagram_link(message: discord.Message) -> None: - if (not message.guild) or (message.guild.id != aci_guild_id): + if (not message.guild) or (message.guild.id != ACI_GUILD_ID): return if not LEAKY_INSTAGRAM_LINK_PATTERN.search(message.content): @@ -204,13 +198,13 @@ def make_listeners(bot: core.Beira) -> tuple[tuple[str, Callable[..., Any]], ... """Connects listeners to bot.""" # The webhook url that will be used to send ACI-related notifications. - aci_webhook_url: str = core.CONFIG.discord.webhooks[0] + aci_webhook_url = core.CONFIG.discord.webhooks[0] role_log_webhook = discord.Webhook.from_url(aci_webhook_url, session=bot.web_session) # Adjust the arguments for the listeners and provide corresponding event name. return ( - ("on_member_update", functools.partial(on_leveled_role_member_update, bot, role_log_webhook)), - ("on_member_update", functools.partial(on_server_boost_role_member_update, bot, role_log_webhook)), + ("on_member_update", functools.partial(on_leveled_role_member_update, role_log_webhook)), + ("on_member_update", functools.partial(on_server_boost_role_member_update, role_log_webhook)), ("on_message", on_leaky_instagram_link), # ("on_message", functools.partial(on_bad_twitter_link, bot)), # Twitter works. # noqa: ERA001 ) diff --git a/exts/notifications/other_triggers.py b/exts/notifications/other_triggers.py index 0bbdce5..e259d3e 100644 --- a/exts/notifications/other_triggers.py +++ b/exts/notifications/other_triggers.py @@ -1,6 +1,7 @@ import asyncio import functools import re +from collections.abc import Callable from typing import Any import aiohttp @@ -45,13 +46,10 @@ async def on_bad_9gag_link(bot: core.Beira, message: discord.Message) -> None: f"Reposted from {message.author.mention} ({message.author.name} - {message.author.id}):\n\n" f"{new_links}" ) - await message.reply( - content, - allowed_mentions=discord.AllowedMentions(users=False, replied_user=False), - ) + await message.reply(content, allowed_mentions=discord.AllowedMentions(users=False, replied_user=False)) -def make_listeners(bot: core.Beira) -> tuple[tuple[str, functools.partial[Any]], ...]: +def make_listeners(bot: core.Beira) -> tuple[tuple[str, Callable[..., Any]], ...]: """Connects listeners to bot.""" # Adjust the arguments for the listeners and provide corresponding event name. diff --git a/exts/notifications/rss_notifications.py b/exts/notifications/rss_notifications.py index e9b1029..72a6e0f 100644 --- a/exts/notifications/rss_notifications.py +++ b/exts/notifications/rss_notifications.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import asyncio from typing import Self @@ -55,7 +53,7 @@ def process_new_item(self, text: str) -> discord.Embed: async def notification_check_loop(self) -> None: """Continuously check urls for updates and send notifications to webhooks accordingly.""" - notif_tasks: list[asyncio.Task[str | None]] = [asyncio.create_task(self.check_url(rec) for rec in self.records)] + notif_tasks: list[asyncio.Task[str | None]] = [asyncio.create_task(self.check_url(rec)) for rec in self.records] results: list[str | None] = await asyncio.gather(*notif_tasks) to_update = ((result, rec) for result, rec in zip(results, self.records, strict=True) if result is not None) for result, rec in to_update: diff --git a/exts/patreon.py b/exts/patreon.py index c958257..c920f64 100644 --- a/exts/patreon.py +++ b/exts/patreon.py @@ -1,11 +1,8 @@ -""" -patreon.py: A cog for checking which Discord members are currently patrons of ACI100. +"""patreon.py: A cog for checking which Discord members are currently patrons of ACI100. Work in progress to make the view portion functional for M J Bradley. """ -from __future__ import annotations - import logging import textwrap import urllib.parse @@ -143,7 +140,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: async def _get_patreon_roles(self) -> None: await self.bot.wait_until_ready() - query = """SELECT * FROM patreon_creators WHERE creator_name = 'ACI100' ORDER BY tier_value;""" + query = "SELECT * FROM patreon_creators WHERE creator_name = 'ACI100' ORDER BY tier_value;" records: list[asyncpg.Record] = await self.bot.db_pool.fetch(query) self.patreon_tiers_info = [PatreonTierInfo.from_record(record) for record in records] diff --git a/exts/presence.py b/exts/presence.py index f7433e9..3f64a0a 100644 --- a/exts/presence.py +++ b/exts/presence.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import random import discord diff --git a/exts/snowball/__init__.py b/exts/snowball/__init__.py index 5536286..1f2b429 100644 --- a/exts/snowball/__init__.py +++ b/exts/snowball/__init__.py @@ -1,15 +1,9 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING +import core from .snowball import SnowballCog -if TYPE_CHECKING: - from core import Beira - - -async def setup(bot: Beira) -> None: +async def setup(bot: core.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(SnowballCog(bot)) diff --git a/exts/snowball/snow_text.py b/exts/snowball/snow_text.py index dcd0f49..df62523 100644 --- a/exts/snowball/snow_text.py +++ b/exts/snowball/snow_text.py @@ -1,3 +1,15 @@ +__all__ = ( + "COLLECT_SUCCEED_IMGS", + "COLLECT_FAIL_IMGS", + "HIT_NOTES", + "HIT_IMGS", + "MISS_NOTES", + "MISS_IMGS", + "SNOW_INSPO_URL", + "SNOW_INSPO_NOTE", + "SNOW_CODE_NOTE", +) + COLLECT_SUCCEED_IMGS = ( "https://c.tenor.com/NBqwJNBaSXUAAAAC/playing-with-snow-piu-piu.gif?width=400&height=225", "https://media.tenor.com/odNpnufgwkYAAAAC/anime-cute.gif", diff --git a/exts/snowball/snowball.py b/exts/snowball/snowball.py index 0119b0b..e00ed3a 100644 --- a/exts/snowball/snowball.py +++ b/exts/snowball/snowball.py @@ -1,15 +1,12 @@ -""" -snowball.py: A snowball cog that implements a version of Discord's 2021 Snowball Bot game. +"""snowball.py: A snowball cog that implements a version of Discord's 2021 Snowball Bot game. -References ----------- -Rules and code inspiration. -https://web.archive.org/web/20220103003050/https://support.discord.com/hc/en-us/articles/4414111886359-Snowsgiving-2021-Snowball-Bot-FAQ -https://github.com/0xMukesh/snowball-bot +Notes +----- +Rules and code inspiration: +- https://web.archive.org/web/20220103003050/https://support.discord.com/hc/en-us/articles/4414111886359-Snowsgiving-2021-Snowball-Bot-FAQ +- https://github.com/0xMukesh/snowball-bot """ -from __future__ import annotations - import logging import random from itertools import cycle, islice @@ -187,7 +184,7 @@ async def throw(self, ctx: core.GuildContext, *, target: discord.Member) -> None embed = discord.Embed(color=0x60FF60) ephemeral = False - query = "SELECT hits, misses, kos, stock FROM snowball_stats WHERE guild_id = $1 AND user_id = $2" + query = "SELECT hits, misses, kos, stock FROM snowball_stats WHERE guild_id = $1 AND user_id = $2;" record = await ctx.db.fetchrow(query, ctx.guild.id, ctx.author.id) # The user has to be in the database and have collected at least one snowball before they can throw one. @@ -260,7 +257,7 @@ async def transfer(self, ctx: core.GuildContext, amount: int, *, receiver: disco await ctx.send(embed=def_embed, ephemeral=True) return - query = "SELECT hits, misses, kos, stock FROM snowball_stats WHERE guild_id = $1 AND user_id = $2" + query = "SELECT hits, misses, kos, stock FROM snowball_stats WHERE guild_id = $1 AND user_id = $2;" async with ctx.db.acquire() as conn, conn.transaction(): giver_record = await conn.fetchrow(query, ctx.guild.id, ctx.author.id) receiver_record = await conn.fetchrow(query, ctx.guild.id, receiver.id) @@ -337,7 +334,7 @@ async def steal(self, ctx: core.GuildContext, amount: int, *, victim: discord.Me await ctx.send(embed=def_embed, ephemeral=True) return - query = "SELECT hits, misses, kos, stock FROM snowball_stats WHERE guild_id = $1 AND user_id = $2" + query = "SELECT hits, misses, kos, stock FROM snowball_stats WHERE guild_id = $1 AND user_id = $2;" async with ctx.db.acquire() as conn, conn.transaction(): thief_record = await conn.fetchrow(query, ctx.guild.id, ctx.author.id) victim_record = await conn.fetchrow(query, ctx.guild.id, victim.id) @@ -390,7 +387,7 @@ async def stats(self, ctx: core.GuildContext, *, target: discord.User = commands all their interactions within the guild in context. """ - query = """ + query = """\ SELECT guild_rank, hits, misses, kos, stock FROM( SELECT user_id, hits, kos, misses, stock, @@ -469,7 +466,7 @@ async def stats_global(self, ctx: core.Context, *, target: discord.User = comman async def leaderboard(self, ctx: core.GuildContext) -> None: """See who's dominating the Snowball Bot leaderboard in your server.""" - query = """ + query = """\ SELECT user_id, hits, kos, misses, stock, DENSE_RANK() over (ORDER BY hits DESC, kos, misses, stock DESC, user_id DESC) AS rank FROM snowball_stats @@ -501,8 +498,7 @@ async def leaderboard_global(self, ctx: core.Context) -> None: """See who's dominating the Global Snowball Bot leaderboard across all the servers.""" assert self.bot.user # Known to exist during runtime. - query = "SELECT * FROM global_rank_view LIMIT $1;" - global_ldbd = await ctx.db.fetch(query, LEADERBOARD_MAX) + global_ldbd = await ctx.db.fetch("SELECT * FROM global_rank_view LIMIT $1;", LEADERBOARD_MAX) embed = StatsEmbed( color=0x2F3136, @@ -523,8 +519,7 @@ async def leaderboard_guilds(self, ctx: core.Context) -> None: """See which guild is dominating the Snowball Bot leaderboard.""" assert self.bot.user # Known to exist during runtime. - query = "SELECT * FROM guilds_only_rank_view LIMIT $1;" - guilds_only_ldbd = await ctx.db.fetch(query, LEADERBOARD_MAX) + guilds_only_ldbd = await ctx.db.fetch("SELECT * FROM guilds_only_rank_view LIMIT $1;", LEADERBOARD_MAX) embed = StatsEmbed( color=0x2F3136, diff --git a/exts/snowball/utils.py b/exts/snowball/utils.py index a9ff608..4e735e0 100644 --- a/exts/snowball/utils.py +++ b/exts/snowball/utils.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import Self import asyncpg @@ -8,7 +6,7 @@ from discord.ext import commands import core -from core.utils.db import Connection_alias, Pool_alias, upsert_guilds, upsert_users +from core.utils.db import Connection_alias, Pool_alias __all__ = ( @@ -61,10 +59,16 @@ async def upsert_record( """Upserts a user's snowball stats based on the given stat parameters.""" # Upsert the relevant users and guilds to the database before adding a snowball record. - await upsert_users(conn, member) - await upsert_guilds(conn, member.guild) + user_stmt = "INSERT INTO users (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING;" + await conn.execute(user_stmt, member.id) + guild_stmt = "INSERT INTO guilds (guild_id) VALUES ($2) ON CONFLICT (guild_id) DO NOTHING;" + await conn.execute(guild_stmt, member.guild.id) + member_stmt = ( + "INSERT INTO members (guild_id, user_id) VALUES ($1, $2) ON CONFLICT (guild_id, user_id) DO NOTHING;" + ) + await conn.execute(member_stmt, member.id, member.guild.id) - snowball_upsert_query = """ + snowball_stmt = """\ INSERT INTO snowball_stats (user_id, guild_id, hits, misses, kos, stock) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (user_id, guild_id) DO UPDATE @@ -75,7 +79,7 @@ async def upsert_record( RETURNING *; """ args = member.id, member.guild.id, hits, misses, kos, max(stock, 0), stock - return cls.from_record(await conn.fetchrow(snowball_upsert_query, *args)) + return cls.from_record(await conn.fetchrow(snowball_stmt, *args)) class GuildSnowballSettings(msgspec.Struct): @@ -106,14 +110,13 @@ def from_record(cls: type[Self], record: asyncpg.Record) -> Self: async def from_database(cls: type[Self], conn: Pool_alias | Connection_alias, guild_id: int) -> Self: """Query a snowball settings database record for a guild.""" - query = """SELECT * FROM snowball_settings WHERE guild_id = $1;""" - record = await conn.fetchrow(query, guild_id) + record = await conn.fetchrow("SELECT * FROM snowball_settings WHERE guild_id = $1;", guild_id) return cls.from_record(record) if record else cls(guild_id) async def upsert_record(self, conn: Pool_alias | Connection_alias) -> None: """Upsert these snowball settings into the database.""" - query = """ + stmt = """\ INSERT INTO snowball_settings (guild_id, hit_odds, stock_cap, transfer_cap) VALUES ($1, $2, $3, $4) ON CONFLICT(guild_id) @@ -122,7 +125,7 @@ async def upsert_record(self, conn: Pool_alias | Connection_alias) -> None: stock_cap = EXCLUDED.stock_cap, transfer_cap = EXCLUDED.transfer_cap; """ - await conn.execute(query, self.guild_id, self.hit_odds, self.stock_cap, self.transfer_cap) + await conn.execute(stmt, self.guild_id, self.hit_odds, self.stock_cap, self.transfer_cap) class SnowballSettingsModal(discord.ui.Modal): @@ -187,26 +190,29 @@ async def on_submit(self, interaction: core.Interaction, /) -> None: # type: ig new_odds_val = self.default_settings.hit_odds try: temp = float(self.hit_odds_input.value) - if 0.0 <= temp <= 1.0: - new_odds_val = temp except ValueError: pass + else: + if 0.0 <= temp <= 1.0: + new_odds_val = temp new_stock_val = self.default_settings.stock_cap try: temp = int(self.stock_cap_input.value) - if temp >= 0: - new_stock_val = temp except ValueError: pass + else: + if temp >= 0: + new_stock_val = temp new_transfer_val = self.default_settings.transfer_cap try: temp = int(self.transfer_cap_input.value) - if temp >= 0: - new_transfer_val = temp except ValueError: pass + else: + if temp >= 0: + new_transfer_val = temp # Update the record in the database if there's been a change. self.new_settings = GuildSnowballSettings(guild_id, new_odds_val, new_stock_val, new_transfer_val) diff --git a/exts/starkid.py b/exts/starkid.py index 710b1bf..a2730d3 100644 --- a/exts/starkid.py +++ b/exts/starkid.py @@ -1,11 +1,8 @@ -""" -starkid.py: A cog for StarKid-related commands and functionality. +"""starkid.py: A cog for StarKid-related commands and functionality. Shoutout to Theo and Ali for inspiration, as well as the whole StarKid server. """ -from __future__ import annotations - import logging import discord diff --git a/exts/story_search.py b/exts/story_search.py index 93e846a..4278a6d 100644 --- a/exts/story_search.py +++ b/exts/story_search.py @@ -4,15 +4,13 @@ Currently supports most long-form ACI100 works and M J Bradley's A Cadmean Victory Remastered. """ -from __future__ import annotations - import importlib.resources -import importlib.resources.abc import logging import random import re import textwrap from bisect import bisect_left +from importlib.resources.abc import Traversable from typing import TYPE_CHECKING, ClassVar, Self import aiohttp @@ -40,6 +38,24 @@ def markdownify(html: str, **kwargs: object) -> str: ... assert AO3_EMOJI.id +class StoryInfo(msgspec.Struct): + """A class to hold all the information about each story.""" + + acronym: str + name: str + author: str + link: str + emoji_id: int | None = None + text: list[str] = msgspec.field(default_factory=list) + chapter_index: list[int] = msgspec.field(default_factory=list) + collection_index: list[int] = msgspec.field(default_factory=list) + + @classmethod + def from_record(cls, record: asyncpg.Record) -> Self: + attrs_ = ("story_acronym", "story_full_name", "author_name", "story_link", "emoji_id") + return cls(*(record[attr] for attr in attrs_)) + + @async_lru.alru_cache(ttl=300) async def get_ao3_story_html(session: aiohttp.ClientSession, url: str) -> lxml.html.HtmlElement | None: async with session.get(url) as response: @@ -110,24 +126,6 @@ def find_keywords_in_ao3_story( return StoryInfo("NONE", title, author, url, AO3_EMOJI.id), results -class StoryInfo(msgspec.Struct): - """A class to hold all the information about each story.""" - - acronym: str - name: str - author: str - link: str - emoji_id: int | None = None - text: list[str] = msgspec.field(default_factory=list) - chapter_index: list[int] = msgspec.field(default_factory=list) - collection_index: list[int] = msgspec.field(default_factory=list) - - @classmethod - def from_record(cls, record: asyncpg.Record) -> Self: - attrs_ = ("story_acronym", "story_full_name", "author_name", "story_link", "emoji_id") - return cls(*(record[attr] for attr in attrs_)) - - class AO3StoryHtmlData(msgspec.Struct): url: str title: str @@ -285,7 +283,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: LOGGER.exception("", exc_info=error) @classmethod - def load_story_text(cls, filepath: importlib.resources.abc.Traversable) -> None: + def load_story_text(cls, filepath: Traversable) -> None: """Load the story metadata and text.""" # Compile all necessary regex patterns. diff --git a/exts/timing.py b/exts/timing.py index 0ec8a91..91fac11 100644 --- a/exts/timing.py +++ b/exts/timing.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import datetime from typing import NotRequired, TypedDict from zoneinfo import ZoneInfo, ZoneInfoNotFoundError @@ -115,14 +113,14 @@ async def timezone_set(self, ctx: core.Context, tz: str) -> None: except ZoneInfoNotFoundError: await ctx.send("That's an invalid time zone.") else: - query = """\ + stmt = """\ INSERT INTO users (user_id, timezone) VALUES ($1, $2) ON CONFLICT (user_id) DO UPDATE SET timezone = EXCLUDED.timezone; """ - await ctx.db.execute(query, ctx.author.id, zone) + await ctx.db.execute(stmt, ctx.author.id, zone) self.bot.get_user_timezone.cache_invalidate(ctx.author.id) await ctx.send( f"Your timezone has been set to {tz} (CLDR name: {self.timezone_aliases[tz]}).", @@ -133,8 +131,7 @@ async def timezone_set(self, ctx: core.Context, tz: str) -> None: async def timezone_clear(self, ctx: core.Context) -> None: """Clear your timezone.""" - query = "UPDATE users SET timezone = NULL WHERE user_id = $1;" - await ctx.db.execute(query, ctx.author.id) + await ctx.db.execute("UPDATE users SET timezone = NULL WHERE user_id = $1;", ctx.author.id) self.bot.get_user_timezone.cache_invalidate(ctx.author.id) await ctx.send("Your timezone has been cleared.", ephemeral=True) diff --git a/exts/todo.py b/exts/todo.py index 58d9080..1e77e8e 100644 --- a/exts/todo.py +++ b/exts/todo.py @@ -2,8 +2,6 @@ todo.py: A module/cog for handling todo lists made in Discord and stored in a database. """ -from __future__ import annotations - import datetime import logging import textwrap @@ -63,9 +61,9 @@ async def change_completion(self, conn: Pool_alias | Connection_alias) -> Self: The connection/pool that will be used to make this database command. """ - command = "UPDATE todos SET todo_completed_at = $1 WHERE todo_id = $2 RETURNING *;" + stmt = "UPDATE todos SET todo_completed_at = $1 WHERE todo_id = $2 RETURNING *;" new_date = discord.utils.utcnow() if self.completed_at is None else None - record = await conn.fetchrow(command, new_date, self.todo_id) + record = await conn.fetchrow(stmt, new_date, self.todo_id) return self.from_record(record) if record else type(self).generate_deleted() async def update(self, conn: Pool_alias | Connection_alias, updated_content: str) -> Self: @@ -81,8 +79,8 @@ async def update(self, conn: Pool_alias | Connection_alias, updated_content: str The new to-do content. """ - command = "UPDATE todos SET todo_content = $1 WHERE todo_id = $2 RETURNING *;" - record = await conn.fetchrow(command, updated_content, self.todo_id) + stmt = "UPDATE todos SET todo_content = $1 WHERE todo_id = $2 RETURNING *;" + record = await conn.fetchrow(stmt, updated_content, self.todo_id) return self.from_record(record) if record else type(self).generate_deleted() async def delete(self, conn: Pool_alias | Connection_alias) -> None: @@ -94,8 +92,7 @@ async def delete(self, conn: Pool_alias | Connection_alias) -> None: The connection/pool that will be used to make this database command. """ - command = "DELETE FROM todos where todo_id = $1;" - await conn.execute(command, self.todo_id) + await conn.execute("DELETE FROM todos where todo_id = $1;", self.todo_id) def display_embed(self, *, to_be_deleted: bool = False) -> discord.Embed: """Generates a formatted embed from a to-do record. @@ -478,8 +475,8 @@ async def todo_add(self, ctx: core.Context, content: str) -> None: await ctx.send("Content is too long. Please keep to within 2000 characters.") return - command = "INSERT INTO todos (user_id, todo_content) VALUES ($1, $2);" - await self.bot.db_pool.execute(command, ctx.author.id, content) + stmt = "INSERT INTO todos (user_id, todo_content) VALUES ($1, $2);" + await self.bot.db_pool.execute(stmt, ctx.author.id, content) await ctx.send("Todo added!", ephemeral=True) @todo.command("delete") @@ -494,16 +491,15 @@ async def todo_delete(self, ctx: core.Context, todo_id: int) -> None: The id of the task to do. """ - command = "DELETE FROM todos where todo_id = $1 and user_id = $2;" - await self.bot.db_pool.execute(command, todo_id, ctx.author.id) + stmt = "DELETE FROM todos where todo_id = $1 and user_id = $2;" + await self.bot.db_pool.execute(stmt, todo_id, ctx.author.id) await ctx.send(f"To-do item #{todo_id} has been removed.", ephemeral=True) @todo.command("clear") async def todo_clear(self, ctx: core.Context) -> None: """Clear all of your to-do items.""" - command = "DELETE FROM todos where user_id = $1;" - await self.bot.db_pool.execute(command, ctx.author.id) + await self.bot.db_pool.execute("DELETE FROM todos where user_id = $1;", ctx.author.id) await ctx.send("All of your todo items have been cleared.", ephemeral=True) @todo.command("show") diff --git a/exts/webhook_logging.py b/exts/webhook_logging.py index 3cb79e5..45d10c3 100644 --- a/exts/webhook_logging.py +++ b/exts/webhook_logging.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import discord from discord.ext import commands, tasks diff --git a/pyproject.toml b/pyproject.toml index 3c49df8..33e3346 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,9 +13,9 @@ authors = [ Homepage = "https://github.com/Sachaa-Thanasius/Beira" [tool.ruff] -include = ["main.py", "core/*", "exts/*", "**/pyproject.toml"] +include = ["main.py", "core/*", "exts/*", "**/pyproject.toml", "misc/**/*.py"] line-length = 120 -target-version = "py311" +target-version = "py312" [tool.ruff.lint] select = [ @@ -95,7 +95,7 @@ combine-as-imports = true [tool.pyright] include = ["main.py", "core", "exts"] -pythonVersion = "3.11" +pythonVersion = "3.12" typeCheckingMode = "strict" # reportImportCycles = "warning"