diff --git a/pyproject.toml b/pyproject.toml index 357479e..3f78a4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "Beira" -version = "2024.06.22" +version = "2024.07.22" description = "An personal Discord bot made in Python." readme = "README.md" license = { file = "LICENSE" } @@ -13,7 +13,7 @@ authors = [ Homepage = "https://github.com/Sachaa-Thanasius/Beira" [tool.ruff] -include = ["src/beira/**/*.py", "misc/**/*.py"] +include = ["src/**/*.py", "misc/**/*.py"] line-length = 120 target-version = "py312" @@ -56,10 +56,11 @@ ignore = [ "SIM105", # Suppressable exception. contextlib.suppress is a stylistic choice with overhead. "ANN101", # Type of Self for self is usually implicit. "ANN102", # Type of type[Self] for cls is usually implicit. - "ANN204", # Special method return types are usually implicit or known by type checkers. + "ANN204", # Return types for magic methods are usually inferred or known. "ANN401", # Any is necessary sometimes. "PT001", # pytest recommends against empty parentheses on pytest.fixture. "UP038", # isinstance performs better with tuples than unions. + "RUF001", # Allow ambiguous characters. # == Recommended ignores by ruff when using ruff format. "E111", "E114", @@ -72,11 +73,8 @@ ignore = [ "ISC001", "ISC002", # == Project-specific ignores. - "S311", # No need for cryptographically secure number generation in this use case; it's just dice rolls. - "PLR0911", # Unlimited returns. - "PLR0912", # Unlimited branches. - "PLR0913", # Unlimited arguments. - "PYI036", # Bug with annotations for __(a)exit__ if a placeholder for a needed type exists in the else clause of an `if TYPE_CHECKING` block. + "S311", # No need for cryptographically secure number generation in this use case; it's just dice rolls. + # "PLR", # Allow complexity. ] unfixable = [ "ERA", # Don't want erroneous deletion of comments. @@ -94,11 +92,14 @@ lines-after-imports = 2 combine-as-imports = true [tool.pyright] -include = ["src/beira"] +include = ["src"] pythonVersion = "3.12" typeCheckingMode = "strict" -# reportImportCycles = "warning" +reportCallInDefaultInitializer = "warning" +reportImportCycles = "warning" reportPropertyTypeMismatch = "warning" +reportShadowedImports = "error" +reportUninitializedInstanceVariable = "warning" reportUnnecessaryTypeIgnoreComment = "warning" -enableExperimentalFeatures = true + diff --git a/src/beira/bot.py b/src/beira/bot.py index 23d192c..c24dc59 100644 --- a/src/beira/bot.py +++ b/src/beira/bot.py @@ -90,14 +90,13 @@ class Beira(commands.Bot): Arbitrary keyword arguments, primarily for `commands.Bot`. See that class for more information. """ - logging_manager: LoggingManager - def __init__( self, *args: Any, config: Config, db_pool: Pool_alias, web_session: aiohttp.ClientSession, + logging_manager: LoggingManager, initial_extensions: list[str] | None = None, **kwargs: Any, ) -> None: @@ -105,6 +104,7 @@ def __init__( self.config = config self.db_pool = db_pool self.web_session = web_session + self.logging_manager = logging_manager self.initial_extensions: list[str] = initial_extensions or [] # Various webfiction-related clients. @@ -160,7 +160,11 @@ async def get_context(self, origin: discord.Message | discord.Interaction, /) -> @overload async def get_context[ContextT: commands.Context[Any]]( - self, origin: discord.Message | discord.Interaction, /, *, cls: type[ContextT] + self, + origin: discord.Message | discord.Interaction, + /, + *, + cls: type[ContextT], ) -> ContextT: ... async def get_context[ContextT: commands.Context[Any]]( @@ -353,10 +357,10 @@ async def main() -> None: config=config, db_pool=pool, web_session=web_session, + logging_manager=logging_manager, intents=intents, tree_cls=HookableTree, ) as bot: - bot.logging_manager = logging_manager await bot.start(config.discord.token) # Needed for graceful exit? diff --git a/src/beira/checks.py b/src/beira/checks.py index 5082c1f..feb4729 100644 --- a/src/beira/checks.py +++ b/src/beira/checks.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: - from discord.ext.commands._types import Check # type: ignore [reportMissingTypeStubs] + from discord.ext.commands._types import Check # pyright: ignore [reportMissingTypeStubs] class AppCheck(Protocol): diff --git a/src/beira/config.py b/src/beira/config.py index 62d5b31..55f4d7f 100644 --- a/src/beira/config.py +++ b/src/beira/config.py @@ -1,7 +1,6 @@ """For loading configuration information, such as api keys and tokens, default prefixes, etc.""" import pathlib -from typing import Any import msgspec @@ -9,36 +8,26 @@ __all__ = ("Config", "load_config") -class Base(msgspec.Struct): - """A base class to hold some common functions.""" - - def to_dict(self) -> dict[str, Any]: - return msgspec.structs.asdict(self) - - def to_tuple(self) -> tuple[Any, ...]: - return msgspec.structs.astuple(self) - - -class UserPassConfig(Base): +class UserPassConfig(msgspec.Struct): user: str password: str -class KeyConfig(Base): +class KeyConfig(msgspec.Struct): key: str -class SpotifyConfig(Base): +class SpotifyConfig(msgspec.Struct): client_id: str client_secret: str -class LavalinkConfig(Base): +class LavalinkConfig(msgspec.Struct): uri: str password: str -class PatreonConfig(Base): +class PatreonConfig(msgspec.Struct): client_id: str client_secret: str creator_access_token: str @@ -46,11 +35,11 @@ class PatreonConfig(Base): patreon_guild_id: int -class DatabaseConfig(Base): +class DatabaseConfig(msgspec.Struct): pg_url: str -class DiscordConfig(Base): +class DiscordConfig(msgspec.Struct): token: str default_prefix: str logging_webhook: str @@ -59,7 +48,7 @@ class DiscordConfig(Base): webhooks: list[str] = msgspec.field(default_factory=list) -class Config(Base): +class Config(msgspec.Struct): discord: DiscordConfig database: DatabaseConfig patreon: PatreonConfig @@ -72,10 +61,12 @@ class Config(Base): def decode(data: bytes | str) -> Config: - """Decode a TOMl file with the `Config` schema.""" + """Decode a TOML file with the Config schema.""" return msgspec.toml.decode(data, type=Config) def load_config() -> Config: + """Load the contents of a "config.toml" file into a Config struct.""" + return decode(pathlib.Path("config.toml").read_text(encoding="utf-8")) diff --git a/src/beira/errors.py b/src/beira/errors.py index 09b4fcb..38561f8 100644 --- a/src/beira/errors.py +++ b/src/beira/errors.py @@ -1,10 +1,15 @@ """Custom errors used by the bot.""" +from collections.abc import Callable, Coroutine +from typing import Any + +import discord from discord import app_commands from discord.ext import commands -AppCheckFunc = app_commands.commands.Check +# Copied from discord.app_commands.commands.Check. +type AppCheckFunc = Callable[[discord.Interaction[Any]], bool | Coroutine[Any, Any, bool]] __all__ = ( @@ -21,14 +26,14 @@ class CannotTargetSelf(commands.BadArgument): """Exception raised when the member provided as a target was also the command invoker. - This inherits from :exc:`commands.BadArgument`. + This inherits from commands.BadArgument. """ class NotOwnerOrFriend(commands.CheckFailure): """Exception raised when the message author is not the owner of the bot or on the special friends list. - This inherits from :exc:`CheckFailure`. + This inherits from CheckFailure. """ def __init__(self, message: str | None = None) -> None: @@ -38,7 +43,7 @@ def __init__(self, message: str | None = None) -> None: class NotAdmin(commands.CheckFailure): """Exception raised when the message author is not an administrator of the guild in the current context. - This inherits from :exc:`commands.CheckFailure`. + This inherits from commands.CheckFailure. """ def __init__(self, message: str | None = None) -> None: @@ -48,7 +53,7 @@ def __init__(self, message: str | None = None) -> None: class NotInBotVoiceChannel(commands.CheckFailure): """Exception raised when the message author is not in the same voice channel as the bot in a context's guild. - This inherits from :exc:`commands.CheckFailure`. + This inherits from commands.CheckFailure. """ def __init__(self, message: str | None = None) -> None: @@ -58,7 +63,7 @@ def __init__(self, message: str | None = None) -> None: class UserIsBlocked(commands.CheckFailure): """Exception raised when the message author is blocked from using the bot. - This inherits from :exc:`commands.CheckFailure`. + This inherits from commands.CheckFailure. """ def __init__(self, message: str | None = None) -> None: @@ -68,7 +73,7 @@ def __init__(self, message: str | None = None) -> None: class GuildIsBlocked(commands.CheckFailure): """Exception raised when the message guild is blocked from using the bot. - This inherits from :exc:`commands.CheckFailure`. + This inherits from commands.CheckFailure. """ def __init__(self, message: str | None = None) -> None: @@ -76,15 +81,15 @@ def __init__(self, message: str | None = None) -> None: class CheckAnyFailure(app_commands.CheckFailure): - """Exception raised when all predicates in :func:`check_any` fail. + """Exception raised when all predicates in `check_any` fail. - This inherits from :exc:`app_commands.CheckFailure`. + This inherits from app_commands.CheckFailure. Attributes - ------------ - errors: list[`app_commands.CheckFailure`] + ---------- + errors: list[app_commands.CheckFailure] A list of errors that were caught during execution. - checks: List[Callable[[`discord.Interaction`], `bool`]] + checks: List[Callable[[discord.Interaction], bool]] A list of check predicates that failed. """ diff --git a/src/beira/exts/dice.py b/src/beira/exts/dice.py index 4342016..1aec66e 100644 --- a/src/beira/exts/dice.py +++ b/src/beira/exts/dice.py @@ -120,7 +120,7 @@ def roll_custom_dice_expression(expression: str) -> tuple[str, int]: components.append(sum(rolls)) else: return f"(Invalid expression; expected number or dice expression, not {part!r})", 0 - else: # noqa: PLR5501 + else: if part == "-": operations.append(operator.sub) elif part == "+": diff --git a/src/beira/exts/patreon.py b/src/beira/exts/patreon.py index c2dc876..4c3c27c 100644 --- a/src/beira/exts/patreon.py +++ b/src/beira/exts/patreon.py @@ -102,6 +102,7 @@ def __init__(self, bot: beira.Beira) -> None: self.bot = bot self.access_token = bot.config.patreon.creator_access_token self.patrons_on_discord: dict[str, list[discord.Member]] = {} + self.patreon_tiers_info: list[PatreonTierInfo] = [] @property def cog_emoji(self) -> discord.PartialEmoji: diff --git a/src/beira/exts/timing.py b/src/beira/exts/timing.py index 7dbc6d0..9a45a22 100644 --- a/src/beira/exts/timing.py +++ b/src/beira/exts/timing.py @@ -81,6 +81,7 @@ async def parse_bcp47_timezones(session: aiohttp.ClientSession) -> dict[str, str class TimingCog(commands.Cog, name="Timing"): def __init__(self, bot: beira.Beira) -> None: self.bot = bot + self.timezone_aliases: dict[str, str] = {} async def cog_load(self) -> None: self.timezone_aliases: dict[str, str] = await parse_bcp47_timezones(self.bot.web_session) @@ -155,7 +156,9 @@ async def timezone_info(self, ctx: beira.Context, tz: str) -> None: @timezone_set.autocomplete("tz") @timezone_info.autocomplete("tz") async def timezone_autocomplete( - self, itx: beira.Interaction, current: str + self, + itx: beira.Interaction, + current: str, ) -> list[discord.app_commands.Choice[str]]: if not current: return [ diff --git a/src/beira/exts/triggers/misc_triggers.py b/src/beira/exts/triggers/misc_triggers.py index b3aaea6..e905461 100644 --- a/src/beira/exts/triggers/misc_triggers.py +++ b/src/beira/exts/triggers/misc_triggers.py @@ -13,8 +13,6 @@ import beira -LOGGER = logging.getLogger(__name__) - type ValidGuildChannel = ( discord.VoiceChannel | discord.StageChannel | discord.ForumChannel | discord.TextChannel | discord.CategoryChannel ) @@ -49,6 +47,8 @@ LOSSY_TWITTER_LINK_PATTERN = re.compile(r"(?:http(?:s)?://|(? None: diff --git a/src/beira/exts/triggers/rss_notifications.py b/src/beira/exts/triggers/rss_notifications.py index 1209d44..15bd44e 100644 --- a/src/beira/exts/triggers/rss_notifications.py +++ b/src/beira/exts/triggers/rss_notifications.py @@ -36,6 +36,7 @@ class RSSNotificationsCog(commands.Cog): def __init__(self, bot: beira.Beira) -> None: self.bot = bot + self.records: list[NotificationRecord] = [] # self.notification_check_loop.start() async def cog_unload(self) -> None: diff --git a/src/beira/tree.py b/src/beira/tree.py index 5ea38d8..ceda873 100644 --- a/src/beira/tree.py +++ b/src/beira/tree.py @@ -14,6 +14,7 @@ from discord.types.interactions import ApplicationCommandInteractionData from typing_extensions import TypeVar + # Copied from discord._types.ClientT. ClientT_co = TypeVar("ClientT_co", bound=Client, covariant=True, default=Client) else: from typing import TypeVar @@ -61,7 +62,7 @@ def before_app_invoke[GroupT: (Group | commands.Cog), **P, T]( raise TypeError(msg) def decorator(inner: Command[GroupT, P, T]) -> Command[GroupT, P, T]: - inner._before_invoke = coro # type: ignore # Runtime attribute assignment. + inner._before_invoke = coro # pyright: ignore # Runtime attribute assignment. return inner return decorator @@ -96,7 +97,7 @@ def after_app_invoke[GroupT: (Group | commands.Cog), **P, T]( raise TypeError(msg) def decorator(inner: Command[GroupT, P, T]) -> Command[GroupT, P, T]: - inner._after_invoke = coro # type: ignore # Runtime attribute assignment. + inner._after_invoke = coro # pyright: ignore # Runtime attribute assignment. return inner return decorator @@ -135,12 +136,11 @@ async def on_error(self, interaction: Interaction[ClientT_co], error: AppCommand async def _call(self, interaction: Interaction[ClientT_co]) -> None: # ---- Copy the original logic but add hook checks/calls near the end. - if not await self.interaction_check(interaction): interaction.command_failed = True return - data: ApplicationCommandInteractionData = interaction.data # type: ignore + data: ApplicationCommandInteractionData = interaction.data # pyright: ignore [reportAssignmentType] type_ = data.get("type", 1) if type_ != 1: # Context menu command... @@ -150,14 +150,14 @@ async def _call(self, interaction: Interaction[ClientT_co]) -> None: command, options = self._get_app_command_options(data) # Pre-fill the cached slot to prevent re-computation - interaction._cs_command = command # type: ignore # Protected + interaction._cs_command = command # pyright: ignore [reportPrivateUsage] # At this point options refers to the arguments of the command # and command refers to the class type we care about namespace = Namespace(interaction, data.get("resolved", {}), options) # Same pre-fill as above - interaction._cs_namespace = namespace # type: ignore # Protected + interaction._cs_namespace = namespace # pyright: ignore [reportPrivateUsage] # Auto complete handles the namespace differently... so at this point this is where we decide where that is. if interaction.type is discord.enums.InteractionType.autocomplete: @@ -167,10 +167,10 @@ async def _call(self, interaction: Interaction[ClientT_co]) -> None: raise AppCommandError(msg) try: - await command._invoke_autocomplete(interaction, focused, namespace) # type: ignore # Protected - except Exception: # noqa: S110, BLE001 + await command._invoke_autocomplete(interaction, focused, namespace) # pyright: ignore [reportPrivateUsage] + except Exception: # Suppress exception since it can't be handled anyway. - pass + LOGGER.exception("Ignoring exception in autocomplete for %r", command.qualified_name) return @@ -178,29 +178,25 @@ async def _call(self, interaction: Interaction[ClientT_co]) -> None: # Pre-command hooks are run before actual command-specific checks, unlike prefix commands. # It doesn't really make sense, but the only solution seems to be monkey-patching # Command._invoke_with_namespace, which doesn't seem feasible. - before_invoke = getattr(command, "_before_invoke", None) - if before_invoke: - instance = getattr(before_invoke, "__self__", None) - if instance: + if before_invoke := getattr(command, "_before_invoke", None): + if instance := getattr(before_invoke, "__self__", None): await before_invoke(instance, interaction) else: await before_invoke(interaction) try: - await command._invoke_with_namespace(interaction, namespace) # type: ignore # Protected + await command._invoke_with_namespace(interaction, namespace) # pyright: ignore [reportPrivateUsage] except AppCommandError as e: interaction.command_failed = True - await command._invoke_error_handlers(interaction, e) # type: ignore # Protected + await command._invoke_error_handlers(interaction, e) # pyright: ignore [reportPrivateUsage] await self.on_error(interaction, e) else: if not interaction.command_failed: self.client.dispatch("app_command_completion", interaction, command) finally: # -- Look for a post-command hook. - after_invoke = getattr(command, "_after_invoke", None) - if after_invoke: - instance = getattr(after_invoke, "__self__", None) - if instance: + if after_invoke := getattr(command, "_after_invoke", None): + if instance := getattr(after_invoke, "__self__", None): await after_invoke(instance, interaction) else: await after_invoke(interaction) diff --git a/src/beira/utils/db.py b/src/beira/utils/db.py index c46710f..d9fe7c6 100644 --- a/src/beira/utils/db.py +++ b/src/beira/utils/db.py @@ -7,11 +7,7 @@ from asyncpg.pool import PoolConnectionProxy -__all__ = ( - "Connection_alias", - "Pool_alias", - "conn_init", -) +__all__ = ("Connection_alias", "Pool_alias", "conn_init") if TYPE_CHECKING: type Connection_alias = Connection[Record] | PoolConnectionProxy[Record] diff --git a/src/beira/utils/embeds.py b/src/beira/utils/embeds.py index d43a26b..8a5c8b1 100644 --- a/src/beira/utils/embeds.py +++ b/src/beira/utils/embeds.py @@ -1,15 +1,12 @@ """Embed-related helpers, e.g. a class for displaying user-specific statistics separated into fields.""" import itertools -import logging from collections.abc import Iterable, Sequence from typing import Self import discord -LOGGER = logging.getLogger(__name__) - type _AnyEmoji = discord.Emoji | discord.PartialEmoji | str diff --git a/src/beira/utils/extras/__init__.py b/src/beira/utils/extras/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/beira/utils/extras/formats.py b/src/beira/utils/extras/formats.py deleted file mode 100644 index f51f558..0000000 --- a/src/beira/utils/extras/formats.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence - - -class plural: - def __init__(self, value: int) -> None: - self.value: int = value - - def __format__(self, format_spec: str) -> str: - v = self.value - singular, _, plural = format_spec.partition("|") - plural = plural or f"{singular}s" - return f"{v} {plural if (abs(v) != 1) else singular}" - - -def human_join(seq: Sequence[str], delim: str = ", ", final: str = "or") -> str: - size = len(seq) - if size == 0: - return "" - if size == 1: - return seq[0] - if size == 2: - return f"{seq[0]} {final} {seq[1]}" - - return delim.join(seq[:-1]) + f" {final} {seq[-1]}" diff --git a/src/beira/utils/extras/time.py b/src/beira/utils/extras/time.py deleted file mode 100644 index 0ad56e5..0000000 --- a/src/beira/utils/extras/time.py +++ /dev/null @@ -1,420 +0,0 @@ -from __future__ import annotations - -import datetime -import re -from typing import Any, Self -from zoneinfo import ZoneInfo - -import discord -import parsedatetime as pdt -from dateutil.relativedelta import relativedelta -from discord.ext import commands - -import beira - -from .formats import human_join, plural - - -# Monkey patch mins and secs into the units -units = pdt.pdtLocales["en_US"].units -units["minutes"].append("mins") -units["seconds"].append("secs") - - -UTC = ZoneInfo("UTC") - - -class ShortTime: - COMPILED = re.compile( - """ - (?:(?P[0-9])(?:years?|y))? # e.g. 2y - (?:(?P[0-9]{1,2})(?:months?|mon?))? # e.g. 2months - (?:(?P[0-9]{1,4})(?:weeks?|w))? # e.g. 10w - (?:(?P[0-9]{1,5})(?:days?|d))? # e.g. 14d - (?:(?P[0-9]{1,5})(?:hours?|hr?s?))? # e.g. 12h - (?:(?P[0-9]{1,5})(?:minutes?|m(?:ins?)?))? # e.g. 10m - (?:(?P[0-9]{1,5})(?:seconds?|s(?:ecs?)?))? # e.g. 15s - """, - re.VERBOSE, - ) - - DISCORD_FMT = re.compile(r"[0-9]+)(?:\:?[RFfDdTt])?>") - - def __init__( - self, - argument: str, - *, - now: datetime.datetime | None = None, - tzinfo: datetime.tzinfo = UTC, - ): - match = self.COMPILED.fullmatch(argument) - if match is None or not match.group(0): - match = self.DISCORD_FMT.fullmatch(argument) - if match is not None: - self.dt = datetime.datetime.fromtimestamp(int(match.group("ts")), tz=UTC) - - if tzinfo not in {datetime.UTC, UTC}: - self.dt = self.dt.astimezone(tzinfo) - return - - msg = "invalid time provided" - raise commands.BadArgument(msg) - - data = {k: int(v) for k, v in match.groupdict(default=0).items()} - now = now or datetime.datetime.now(UTC) - self.dt = now + relativedelta(**data) # type: ignore # None of the regex groups currently fill the date fields. - if tzinfo not in {datetime.UTC, UTC}: - self.dt = self.dt.astimezone(tzinfo) - - @classmethod - async def convert(cls, ctx: beira.Context, argument: str) -> Self: - tzinfo = await ctx.bot.get_user_tzinfo(ctx.author.id) - return cls(argument, now=ctx.message.created_at, tzinfo=tzinfo) - - -class RelativeDelta(discord.app_commands.Transformer, commands.Converter[relativedelta]): - @classmethod - def __do_conversion(cls, argument: str) -> relativedelta: - match = ShortTime.COMPILED.fullmatch(argument) - if match is None or not match.group(0): - msg = "invalid time provided" - raise ValueError(msg) - - data = {k: int(v) for k, v in match.groupdict(default=0).items()} - return relativedelta(**data) # type: ignore # None of the regex groups currently fill the date fields. - - async def convert(self, ctx: beira.Context, argument: str) -> relativedelta: # type: ignore # Custom context. - try: - return self.__do_conversion(argument) - except ValueError as e: - raise commands.BadArgument(str(e)) from None - - async def transform(self, interaction: discord.Interaction, value: str) -> relativedelta: - try: - return self.__do_conversion(value) - except ValueError as e: - raise discord.app_commands.AppCommandError(str(e)) from None - - -class HumanTime: - calendar = pdt.Calendar(version=pdt.VERSION_CONTEXT_STYLE) - - def __init__( - self, - argument: str, - *, - now: datetime.datetime | None = None, - tzinfo: datetime.tzinfo = UTC, - ): - now = now or datetime.datetime.now(tzinfo) - dt, status = self.calendar.parseDT(argument, sourceTime=now, tzinfo=None) - - assert isinstance(status, pdt.pdtContext) - - if not status.hasDateOrTime: - msg = 'invalid time provided, try e.g. "tomorrow" or "3 days"' - raise commands.BadArgument(msg) - - if not status.hasTime: - # replace it with the current time - dt = dt.replace(hour=now.hour, minute=now.minute, second=now.second, microsecond=now.microsecond) - - self.dt: datetime.datetime = dt.replace(tzinfo=tzinfo) - if now.tzinfo is None: - now = now.replace(tzinfo=UTC) - self._past: bool = self.dt < now - - @classmethod - async def convert(cls, ctx: beira.Context, argument: str) -> Self: - tzinfo = await ctx.bot.get_user_tzinfo(ctx.author.id) - return cls(argument, now=ctx.message.created_at, tzinfo=tzinfo) - - -class Time(HumanTime): - def __init__( - self, - argument: str, - *, - now: datetime.datetime | None = None, - tzinfo: datetime.tzinfo = UTC, - ): - try: - o = ShortTime(argument, now=now, tzinfo=tzinfo) - except Exception: # noqa: BLE001 - super().__init__(argument, now=now, tzinfo=tzinfo) - else: - self.dt = o.dt - self._past = False - - -class FutureTime(Time): - def __init__( - self, - argument: str, - *, - now: datetime.datetime | None = None, - tzinfo: datetime.tzinfo = UTC, - ): - super().__init__(argument, now=now, tzinfo=tzinfo) - - if self._past: - msg = "this time is in the past" - raise commands.BadArgument(msg) - - -class BadTimeTransform(discord.app_commands.AppCommandError): - pass - - -class TimeTransformer(discord.app_commands.Transformer): - async def transform(self, interaction: discord.Interaction[beira.Beira], value: str) -> datetime.datetime: - tzinfo = await interaction.client.get_user_tzinfo(interaction.user.id) - - now = interaction.created_at.astimezone(tzinfo) - try: - short = ShortTime(value, now=now, tzinfo=tzinfo) - except commands.BadArgument: - try: - human = FutureTime(value, now=now, tzinfo=tzinfo) - except commands.BadArgument as e: - raise BadTimeTransform(str(e)) from None - else: - return human.dt - else: - return short.dt - - -class FriendlyTimeResult: - __slots__ = ("dt", "arg") - - def __init__(self, dt: datetime.datetime): - self.dt: datetime.datetime = dt - self.arg: str = "" - - async def ensure_constraints( - self, - ctx: beira.Context, - uft: UserFriendlyTime, - now: datetime.datetime, - remaining: str, - ) -> None: - if self.dt < now: - msg = "This time is in the past." - raise commands.BadArgument(msg) - - if not remaining: - if uft.default is None: - msg = "Missing argument after the time." - raise commands.BadArgument(msg) - remaining = uft.default - - if uft.converter is not None: - self.arg = await uft.converter.convert(ctx, remaining) - else: - self.arg = remaining - - -class UserFriendlyTime(commands.Converter[FriendlyTimeResult]): - """That way quotes aren't absolutely necessary.""" - - def __init__( - self, - converter: type[commands.Converter[str]] | commands.Converter[str] | None = None, - *, - default: Any = None, - ): - if issubclass(converter, commands.Converter): # type: ignore [reportUnnecessaryIsInstance] - converter = converter() - - if converter is not None and not isinstance(converter, commands.Converter): # type: ignore [reportUnnecessaryIsInstance] - msg = "commands.Converter subclass necessary." - raise TypeError(msg) - - self.converter: commands.Converter[str] | None = converter - self.default: Any = default - - async def convert(self, ctx: beira.Context, argument: str) -> FriendlyTimeResult: # type: ignore # Custom context. # noqa: PLR0915 - calendar = HumanTime.calendar - regex = ShortTime.COMPILED - now = ctx.message.created_at - - tzinfo = await ctx.bot.get_user_tzinfo(ctx.author.id) - - assert isinstance(tzinfo, datetime.tzinfo) - - match = regex.match(argument) - if match is not None and match.group(0): - data = {k: int(v) for k, v in match.groupdict(default=0).items()} - remaining = argument[match.end() :].strip() - dt = now + relativedelta(**data) # type: ignore # None of the regex groups currently fill the date fields. - result = FriendlyTimeResult(dt.astimezone(tzinfo)) - await result.ensure_constraints(ctx, self, now, remaining) - return result - - if match is None or not match.group(0): - match = ShortTime.DISCORD_FMT.match(argument) - if match is not None: - result = FriendlyTimeResult( - datetime.datetime.fromtimestamp(int(match.group("ts")), tz=UTC).astimezone(tzinfo) - ) - remaining = argument[match.end() :].strip() - await result.ensure_constraints(ctx, self, now, remaining) - return result - - # apparently nlp does not like "from now" - # it likes "from x" in other cases though so let me handle the 'now' case - if argument.endswith("from now"): - argument = argument[:-8].strip() - - if argument[0:2] == "me" and argument[0:6] in ("me to ", "me in ", "me at "): - argument = argument[6:] - - # Have to adjust the timezone so pdt knows how to handle things like "tomorrow at 6pm" in an aware way - now = now.astimezone(tzinfo) - elements = calendar.nlp(argument, sourceTime=now) - if elements is None or len(elements) == 0: - msg = 'Invalid time provided, try e.g. "tomorrow" or "3 days".' - raise commands.BadArgument(msg) - - # handle the following cases: - # "date time" foo - # date time foo - # foo date time - - # first the first two cases: - dt, status, begin, end, _dt_string = elements[0] - assert isinstance(status, pdt.pdtContext) - - if not status.hasDateOrTime: - msg = 'Invalid time provided, try e.g. "tomorrow" or "3 days".' - raise commands.BadArgument(msg) - - if begin not in (0, 1) and end != len(argument): - msg = ( - "Time is either in an inappropriate location, which " - "must be either at the end or beginning of your input, " - "or I just flat out did not understand what you meant. Sorry." - ) - raise commands.BadArgument(msg) - - dt = dt.replace(tzinfo=tzinfo) - if not status.hasTime: - # replace it with the current time - dt = dt.replace(hour=now.hour, minute=now.minute, second=now.second, microsecond=now.microsecond) - - if status.hasTime and not status.hasDate and dt < now: - # if it's in the past, and it has a time but no date, - # assume it's for the next occurrence of that time - dt = dt + datetime.timedelta(days=1) - - # if midnight is provided, just default to next day - if status.accuracy == pdt.pdtContext.ACU_HALFDAY: - dt = dt + datetime.timedelta(days=1) - - result = FriendlyTimeResult(dt) - remaining = "" - - if begin in (0, 1): - if begin == 1: - # check if it's quoted: - if argument[0] != '"': - msg = "Expected quote before time input..." - raise commands.BadArgument(msg) - - if not (end < len(argument) and argument[end] == '"'): - msg = "If the time is quoted, you must unquote it." - raise commands.BadArgument(msg) - - remaining = argument[end + 1 :].lstrip(" ,.!") - else: - remaining = argument[end:].lstrip(" ,.!") - elif len(argument) == end: - remaining = argument[:begin].strip() - - await result.ensure_constraints(ctx, self, now, remaining) - return result - - -def human_timedelta( - dt: datetime.datetime, - *, - source: datetime.datetime | None = None, - accuracy: int | None = 3, - brief: bool = False, - suffix: bool = True, -) -> str: - now = source or datetime.datetime.now(UTC) - if dt.tzinfo is None: - dt = dt.replace(tzinfo=UTC) - - if now.tzinfo is None: - now = now.replace(tzinfo=UTC) - - # Microsecond free zone - now = now.replace(microsecond=0) - dt = dt.replace(microsecond=0) - - # Make sure they're both in the timezone - now = now.astimezone(UTC) - dt = dt.astimezone(UTC) - - # This implementation uses relativedelta instead of the much more obvious - # divmod approach with seconds because the seconds approach is not entirely - # accurate once you go over 1 week in terms of accuracy since you have to - # hardcode a month as 30 or 31 days. - # A query like "11 months" can be interpreted as "!1 months and 6 days" - if dt > now: - delta = relativedelta(dt, now) - output_suffix = "" - else: - delta = relativedelta(now, dt) - output_suffix = " ago" if suffix else "" - - attrs = [ - ("year", "y"), - ("month", "mo"), - ("day", "d"), - ("hour", "h"), - ("minute", "m"), - ("second", "s"), - ] - - output: list[str] = [] - for attr, brief_attr in attrs: - elem = getattr(delta, attr + "s") - if not elem: - continue - - if attr == "day": - weeks = delta.weeks - if weeks: - elem -= weeks * 7 - if not brief: - output.append(format(plural(weeks), "week")) - else: - output.append(f"{weeks}w") - - if elem <= 0: - continue - - if brief: - output.append(f"{elem}{brief_attr}") - else: - output.append(format(plural(elem), attr)) - - if accuracy is not None: - output = output[:accuracy] - - if len(output) == 0: - return "now" - - if not brief: - return human_join(output, final="and") + output_suffix - - return " ".join(output) + output_suffix - - -def format_relative(dt: datetime.datetime) -> str: - if dt.tzinfo is None: - dt = dt.replace(tzinfo=UTC) - return discord.utils.format_dt(dt, "R") diff --git a/src/beira/utils/log.py b/src/beira/utils/log.py index c38ce80..f2392fb 100644 --- a/src/beira/utils/log.py +++ b/src/beira/utils/log.py @@ -30,28 +30,28 @@ class LoggingManager: Parameters ---------- - stream: `bool`, default=True + stream: bool, default=True Whether the logs should be output to a stream. Defaults to True. Attributes ---------- - log: `logging.Logger` + log: logging.Logger The primary bot handler. - max_bytes: `int` + max_bytes: int The maximum size of each log file. - logging_path: `Path` + logging_path: Path A path to the directory for all log files. - stream: `bool` + stream: bool A boolean indicating whether the logs should be output to a stream. - log_queue: `asyncio.Queue[logging.LogRecord]` + log_queue: asyncio.Queue[logging.LogRecord] An asyncio queue with logs to send to a logging webhook. """ def __init__(self, *, stream: bool = True) -> None: self.log = logging.getLogger() self.max_bytes = 32 * 1024 * 1024 # 32MiB - self.logging_path = Path("./logs/") - self.logging_path.mkdir(exist_ok=True) + self.logs_path = Path("./logs/") + self.logs_path.mkdir(exist_ok=True) self.stream = stream self.log_queue: asyncio.Queue[logging.LogRecord] = asyncio.Queue() @@ -61,15 +61,15 @@ async def __aenter__(self) -> Self: def __enter__(self) -> Self: """Set and customize loggers.""" - logging.getLogger("wavelink").setLevel(logging.INFO) logging.getLogger("discord").setLevel(logging.INFO) logging.getLogger("discord.http").setLevel(logging.INFO) logging.getLogger("discord.state").addFilter(RemoveNoise()) + logging.getLogger("wavelink").setLevel(logging.INFO) self.log.setLevel(logging.INFO) # Add a file handler. handler = RotatingFileHandler( - filename=self.logging_path / "Beira.log", + filename=self.logs_path / "Beira.log", encoding="utf-8", mode="w", maxBytes=self.max_bytes, diff --git a/src/beira/utils/misc.py b/src/beira/utils/misc.py index 234e1f0..824bc2e 100644 --- a/src/beira/utils/misc.py +++ b/src/beira/utils/misc.py @@ -18,12 +18,13 @@ class catchtime: Parameters ---------- - logger: `logging.Logger`, optional - The logging channel to send the time to, if relevant. Optional. + logger: logging.Logger, optional + The logging channel to send the time to, if provided. Optional. """ def __init__(self, logger: logging.Logger | None = None): self.logger = logger + self.elapsed = 0.0 def __enter__(self): self.elapsed = time.perf_counter() diff --git a/src/beira/utils/pagination.py b/src/beira/utils/pagination.py index 1bc005d..cc0cb07 100644 --- a/src/beira/utils/pagination.py +++ b/src/beira/utils/pagination.py @@ -52,9 +52,9 @@ class OwnedView(discord.ui.View): Parameters ---------- - author: `int` + author: int The Discord ID of the user that triggered this view. No one else can use it. - timeout: `float` | None, optional + timeout: float | None, optional Timeout in seconds from last interaction with the UI before no longer accepting input. If ``None`` then there is no timeout. """ @@ -77,11 +77,11 @@ class PageSeekModal(discord.ui.Modal, title="Page Jump"): Attributes ---------- - input_page_num: `TextInput` + input_page_num: TextInput A UI text input element to allow users to enter a page number. - parent: `PaginatedEmbedView` + parent: PaginatedEmbedView The paginated view that this modal was called from. - interaction: `discord.Interaction` + interaction: discord.Interaction The interaction of the user with the modal. Only populates on submission. """ @@ -115,27 +115,27 @@ class PaginatedEmbedView[_LT](abc.ABC, OwnedView): Parameters ---------- - author_id: `int` + author_id: int The Discord ID of the user that triggered this view. No one else can use it. pages_content: list[Any] The content for every possible page. - per: `int` + per: int The number of entries to be displayed per page. - timeout: `float`, optional + timeout: float, optional Timeout in seconds from last interaction with the UI before no longer accepting input. If ``None`` then there is no timeout. Attributes ---------- - message: `discord.Message` - The message to which the view is attached to, allowing interaction without a `discord.Interaction`. - per_page: `int` + message: discord.Message + The message to which the view is attached to, allowing interaction without a discord.Interaction. + per_page: int The number of entries to be displayed per page. pages: list[Any] A list of content for pages, split according to how much content is wanted per page. - page_index: `int` + page_index: int The index for the current page. - page_modal_strings: tuple[`str`, ...], default=() + page_modal_strings: tuple[str, ...], default=() Tuple of strings to modify the page seek modal with if necessary. Empty by default. total_pages """ @@ -155,7 +155,7 @@ def __init__(self, author_id: int, pages_content: list[_LT], per: int = 1, *, ti @property def total_pages(self) -> int: - """``int`: The total number of pages.""" + """int: The total number of pages.""" return len(self.pages) @@ -291,21 +291,21 @@ class PaginatedSelectView[_LT](abc.ABC, OwnedView): Parameters ---------- - author_id: `int` + author_id: int The Discord ID of the user that triggered this view. No one else can use it. pages_content: Sequence[Any] The content for every possible page. - timeout: `float` | None, optional + timeout: float | None, optional Timeout in seconds from last interaction with the UI before no longer accepting input. If ``None`` then there is no timeout. Attributes ---------- - message: `discord.Message` - The message to which the view is attached to, allowing interaction without a `discord.Interaction`. + message: discord.Message + The message to which the view is attached to, allowing interaction without a discord.Interaction. pages: list[Any] A list of content for pages. - page_index: `int` + page_index: int The index for the current page. total_pages """ @@ -324,7 +324,7 @@ def __init__(self, author_id: int, pages_content: Sequence[_LT], *, timeout: flo @property def total_pages(self) -> int: - """``int`: The total number of pages.""" + """int: The total number of pages.""" return len(self.pages) diff --git a/src/beira/utils/extras/scheduler.py b/src/beira/utils/scheduler.py similarity index 86% rename from src/beira/utils/extras/scheduler.py rename to src/beira/utils/scheduler.py index eba12aa..1971e6e 100644 --- a/src/beira/utils/extras/scheduler.py +++ b/src/beira/utils/scheduler.py @@ -10,19 +10,19 @@ # endregion import asyncio +import random +import time +from collections.abc import Callable from datetime import datetime, timedelta -from itertools import count -from types import TracebackType from typing import Protocol, Self -from uuid import uuid4 from warnings import warn from zoneinfo import ZoneInfo import asyncpg -from msgspec import Struct, field +from msgspec import Struct from msgspec.json import decode as json_decode, encode as json_encode -from ..db import Connection_alias # noqa: TID252 +from .db import Connection_alias # noqa: TID252 class BotLike(Protocol): @@ -31,19 +31,70 @@ def dispatch(self, event_name: str, /, *args: object, **kwargs: object) -> None: async def wait_until_ready(self) -> None: ... +def _uuid7gen() -> Callable[[], str]: + """UUIDv7 has been accepted as part of rfc9562 + + This is intended to be a compliant implementation, but I am not advertising it + in public, exported APIs as such *yet* + + In particular, this is: + UUIDv7 as described in rfc9562 section 5.7 utilizing the + optional sub-millisecond timestamp fraction described in section 6.2 method 3 + """ + _last_timestamp: int | None = None + + def uuid7() -> str: + """This is unique identifer generator + + This was chosen to increase performance of indexing and + to pick something likely to get specific database support + for this to be a portably efficient choice should someone + decide to have this be backed by something other than sqlite + + This should not be relied on as always generating valid UUIDs of + any version or variant at this time. The current intent is that + this is a UUIDv7 in str form, but this should not be relied + on outside of this library and may be changed in the future for + better performance within this library. + """ + nonlocal _last_timestamp + nanoseconds = time.time_ns() + if _last_timestamp is not None and nanoseconds <= _last_timestamp: + nanoseconds = _last_timestamp + 1 + _last_timestamp = nanoseconds + timestamp_s, timestamp_ns = divmod(nanoseconds, 10**9) + subsec_a = timestamp_ns >> 18 + subsec_b = (timestamp_ns >> 6) & 0x0FFF + subsec_seq_node = (timestamp_ns & 0x3F) << 56 + subsec_seq_node += random.SystemRandom().getrandbits(56) + uuid_int = (timestamp_s & 0x0FFFFFFFFF) << 92 + uuid_int += subsec_a << 80 + uuid_int += subsec_b << 64 + uuid_int += subsec_seq_node + uuid_int &= ~(0xC000 << 48) + uuid_int |= 0x8000 << 48 + uuid_int &= ~(0xF000 << 64) + uuid_int |= 7 << 76 + return f"{uuid_int:032x}" + + return uuid7 + + +_uuid7 = _uuid7gen() + __all__ = ("DiscordBotScheduler", "ScheduledDispatch", "Scheduler") SQLROW_TYPE = tuple[str, str, str, str, int | None, int | None, bytes | None] DATE_FMT = r"%Y-%m-%d %H:%M" -_c = count() +# Requires a postgres extension: https://github.com/fboulnois/pg_uuidv7 INITIALIZATION_STATEMENTS = """ CREATE TABLE IF NOT EXISTS scheduled_dispatches ( - task_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - dispatch_name TEXT NOT NULL, - dispatch_time TIMESTAMP WITH TIME ZONE NOT NULL, - dispatch_zone TEXT NOT NULL, + task_id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), + dispatch_name TEXT NOT NULL, + dispatch_time TIMESTAMP NOT NULL, + dispatch_zone TEXT NOT NULL, associated_guild BIGINT, associated_user BIGINT, dispatch_extra JSONB @@ -161,23 +212,22 @@ class ScheduledDispatch(Struct, frozen=True, gc=False): associated_guild: int | None associated_user: int | None dispatch_extra: bytes | None - _count: int = field(default_factory=lambda: next(_c)) def __eq__(self, other: object) -> bool: return self is other - def __lt__(self, other: object) -> bool: - if isinstance(other, type(self)): - return (self.get_arrow_time(), self._count) < (other.get_arrow_time(), other._count) + def __lt__(self, other: Self) -> bool: + if type(self) is type(other): + return (self.get_arrow_time(), self.task_id) < (other.get_arrow_time(), self.task_id) return False - def __gt__(self, other: object) -> bool: - if isinstance(other, type(self)): - return (self.get_arrow_time(), self._count) > (other.get_arrow_time(), other._count) + def __gt__(self, other: Self) -> bool: + if type(self) is type(other): + return (self.get_arrow_time(), self.task_id) > (other.get_arrow_time(), self.task_id) return False @classmethod - def from_pg_row(cls: type[Self], row: asyncpg.Record) -> Self: + def from_pg_row(cls, row: asyncpg.Record) -> Self: tid, name, time, zone, guild, user, extra_bytes = row return cls(tid, name, time, zone, guild, user, extra_bytes) @@ -196,7 +246,7 @@ def from_exposed_api( if extra is not None: f = json_encode(extra) packed = f - return cls(uuid4().hex, name, time, zone, guild, user, packed) + return cls(_uuid7(), name, time, zone, guild, user, packed) def to_pg_row(self) -> SQLROW_TYPE: return ( @@ -213,7 +263,7 @@ def get_arrow_time(self) -> datetime: return datetime.strptime(self.dispatch_time, DATE_FMT).replace(tzinfo=ZoneInfo(self.dispatch_zone)) def unpack_extra(self) -> object | None: - if self.dispatch_extra: + if self.dispatch_extra is not None: return json_decode(self.dispatch_extra, strict=True) return None @@ -316,7 +366,7 @@ async def _loop(self) -> None: for s in scheduled: await self._queue.put(s) - async def __aexit__(self, exc_type: type[BaseException], exc_value: BaseException, traceback: TracebackType): + async def __aexit__(self, *exc_info: object): if not self._closing: msg = "Exiting without use of stop_gracefully may cause loss of tasks" warn(msg, stacklevel=2)