diff --git a/pyproject.toml b/pyproject.toml index 3f78a4c..68bfab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,8 +73,10 @@ ignore = [ "ISC001", "ISC002", # == Project-specific ignores. - "S311", # No need for cryptographically secure number generation in this use case; it's just dice rolls. - # "PLR", # Allow complexity. + "PLR0912", # Allow more branches + "PLR0913", # Allow more parameters + "S311", # No need for cryptographically secure number generation in this use case; it's just dice rolls. + ] unfixable = [ "ERA", # Don't want erroneous deletion of comments. @@ -100,6 +102,5 @@ reportCallInDefaultInitializer = "warning" reportImportCycles = "warning" reportPropertyTypeMismatch = "warning" reportShadowedImports = "error" -reportUninitializedInstanceVariable = "warning" +# reportUninitializedInstanceVariable = "warning" reportUnnecessaryTypeIgnoreComment = "warning" - diff --git a/requirements.txt b/requirements.txt index 133c6df..2d75302 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,18 @@ aiohttp -ao3.py @ git+https://github.com/Sachaa-Thanasius/ao3.py@main +ao3.py @ git+https://github.com/Sachaa-Thanasius/ao3.py arsenic async-lru asyncpg==0.29.0 asyncpg-stubs==0.29.1 -atlas-api @ https://github.com/Sachaa-Thanasius/atlas-api-wrapper +atlas-api @ git+https://github.com/Sachaa-Thanasius/atlas-api-wrapper discord.py[speed,voice]>=2.4.0 -fichub-api @ https://github.com/Sachaa-Thanasius/fichub-api -jishaku @ git+https://github.com/Gorialis/jishaku@a6661e2813124fbfe53326913e54f7c91e5d0dec +fichub-api @ git+https://github.com/Sachaa-Thanasius/fichub-api +jishaku @ git+https://github.com/Gorialis/jishaku lxml>=4.9.3 markdownify -msgspec[toml] +msgspec openpyxl Pillow>=10.0.0 types-lxml +tzdata; platform_system == "Windows" wavelink>=3.4.0 - -# To be used later: -# parsedatetime -# parsedatetime-stubs @ git+https://github.com/Sachaa-Thanasius/parsedatetime-stubs -# python-dateutil \ No newline at end of file diff --git a/src/beira/bot.py b/src/beira/bot.py index c24dc59..bbb55a8 100644 --- a/src/beira/bot.py +++ b/src/beira/bot.py @@ -5,7 +5,7 @@ import sys import time import traceback -from typing import Any, Self, overload +from typing import Any, overload from zoneinfo import ZoneInfo, ZoneInfoNotFoundError import aiohttp @@ -18,10 +18,11 @@ import wavelink from discord.ext import commands from discord.utils import MISSING -from exts import EXTENSIONS from .checks import is_blocked from .config import Config, load_config +from .exts import EXTENSIONS +from .scheduler import Scheduler from .tree import HookableTree from .utils import LoggingManager, Pool_alias, conn_init, copy_annotations @@ -97,6 +98,7 @@ def __init__( db_pool: Pool_alias, web_session: aiohttp.ClientSession, logging_manager: LoggingManager, + scheduler: Scheduler, initial_extensions: list[str] | None = None, **kwargs: Any, ) -> None: @@ -105,6 +107,7 @@ def __init__( self.db_pool = db_pool self.web_session = web_session self.logging_manager = logging_manager + self.scheduler = scheduler self.initial_extensions: list[str] = initial_extensions or [] # Various webfiction-related clients. @@ -115,7 +118,8 @@ def __init__( # Things to load before connecting to the Gateway. self.prefix_cache: dict[int, list[str]] = {} - self.blocked_entities_cache: dict[str, set[int]] = {} + self.blocked_guilds: set[int] = set() + self.blocked_users: set[int] = set() # Things that are more convenient to retrieve when established here or filled after connecting to the Gateway. 
self.special_friends: dict[str, int] = {} @@ -123,6 +127,12 @@ def __init__( # Add a global check for blocked members. self.add_check(is_blocked().predicate) + @property + def owner(self) -> discord.User: + """`discord.User`: The user that owns the bot.""" + + return self.app_info.owner + async def on_ready(self) -> None: """Display that the bot is ready.""" @@ -130,6 +140,7 @@ async def on_ready(self) -> None: LOGGER.info("Logged in as %s (ID: %s)", self.user, self.user.id) async def setup_hook(self) -> None: + self.scheduler.start_discord_dispatch(self) await self._load_guild_prefixes() await self._load_blocked_entities() await self._load_extensions() @@ -149,14 +160,15 @@ async def setup_hook(self) -> None: # Cache "friends". self.loop.create_task(self._load_special_friends()) - async def get_prefix(self, message: discord.Message, /) -> list[str] | str: - if not self.prefix_cache: - await self._load_guild_prefixes() + async def close(self) -> None: + await self.scheduler.stop_discord_dispatch() + await super().close() + async def get_prefix(self, message: discord.Message, /) -> list[str] | str: return self.prefix_cache.get(message.guild.id, "$") if message.guild else "$" @overload - async def get_context(self, origin: discord.Message | discord.Interaction, /) -> commands.Context[Self]: ... + async def get_context(self, origin: discord.Message | discord.Interaction, /) -> Context: ... @overload async def get_context[ContextT: commands.Context[Any]]( @@ -239,39 +251,27 @@ async def on_command_error(self, context: Context, exception: commands.CommandEr embed.add_field(name="Channel", value=f"{context.channel}", inline=False) LOGGER.error("Exception in command %s", context.command, exc_info=exception, extra={"embed": embed}) - @property - def owner(self) -> discord.User: - """`discord.User`: The user that owns the bot.""" - - return self.app_info.owner - async def _load_blocked_entities(self) -> None: """Load all blocked users and guilds from the bot database.""" - async with self.db_pool.acquire() as conn, conn.transaction(): - user_records = await conn.fetch("SELECT user_id FROM users WHERE is_blocked;") - guild_records = await conn.fetch("SELECT guild_id FROM guilds WHERE is_blocked;") + user_records = await self.db_pool.fetch("SELECT user_id FROM users WHERE is_blocked;") + self.blocked_users.update([record["user_id"] for record in user_records]) - self.blocked_entities_cache["users"] = {record["user_id"] for record in user_records} - self.blocked_entities_cache["guilds"] = {record["guild_id"] for record in guild_records} + guild_records = await self.db_pool.fetch("SELECT guild_id FROM guilds WHERE is_blocked;") + self.blocked_guilds.update([record["guild_id"] for record in guild_records]) - async def _load_guild_prefixes(self, guild_id: int | None = None) -> None: + async def _load_guild_prefixes(self) -> None: """Load all prefixes from the bot database.""" - query = "SELECT guild_id, prefix FROM guild_prefixes" - if guild_id: - query += " WHERE guild_id = $1" - try: - db_prefixes = await self.db_pool.fetch(query) + db_prefixes = await self.db_pool.fetch("SELECT guild_id, prefix FROM guild_prefixes;") except OSError: - LOGGER.exception("Couldn't load guild prefixes from the database. Ignoring for sake of defaults.") + LOGGER.exception("Couldn't load guild prefixes from the database. Using default(s) instead.") else: for entry in db_prefixes: self.prefix_cache.setdefault(entry["guild_id"], []).append(entry["prefix"]) - msg = f"(Re)loaded guild prefixes for {guild_id}." 
if guild_id else "(Re)loaded all guild prefixes." - LOGGER.info(msg) + LOGGER.info("(Re)loaded all guild prefixes.") async def _load_extensions(self) -> None: """Loads extensions/cogs. @@ -326,25 +326,17 @@ def is_special_friend(self, user: discord.abc.User, /) -> bool: return False - def is_ali(self, user: discord.abc.User, /) -> bool: - """Checks if a `discord.User` or `discord.Member` is Ali.""" - - if len(self.special_friends) > 0: - return user.id == self.special_friends["aeroali"] - - return False - async def main() -> None: """Starts an instance of the bot.""" config = load_config() - # Initialize a connection to a PostgreSQL database, an asynchronous web session, and a custom logger setup. async with ( aiohttp.ClientSession() as web_session, asyncpg.create_pool(dsn=config.database.pg_url, command_timeout=30, init=conn_init) as pool, LoggingManager() as logging_manager, + Scheduler(pool) as scheduler, ): # Set the bot's basic starting parameters. intents = discord.Intents.all() @@ -353,11 +345,12 @@ async def main() -> None: # Initialize and start the bot. async with Beira( - command_prefix=default_prefix, + default_prefix, config=config, db_pool=pool, web_session=web_session, logging_manager=logging_manager, + scheduler=scheduler, intents=intents, tree_cls=HookableTree, ) as bot: diff --git a/src/beira/checks.py b/src/beira/checks.py index feb4729..bd0240f 100644 --- a/src/beira/checks.py +++ b/src/beira/checks.py @@ -115,9 +115,9 @@ def is_blocked() -> "Check[Any]": async def predicate(ctx: Context) -> bool: if not (await ctx.bot.is_owner(ctx.author)): - if ctx.author.id in ctx.bot.blocked_entities_cache["users"]: + if ctx.author.id in ctx.bot.blocked_users: raise UserIsBlocked - if ctx.guild and (ctx.guild.id in ctx.bot.blocked_entities_cache["guilds"]): + if ctx.guild and ctx.guild.id in ctx.bot.blocked_guilds: raise GuildIsBlocked return True diff --git a/src/beira/exts/_dev.py b/src/beira/exts/_dev.py index fbc4d8d..f745695 100644 --- a/src/beira/exts/_dev.py +++ b/src/beira/exts/_dev.py @@ -1,6 +1,5 @@ """A cog that implements commands for reloading and syncing extensions and other commands, at the owner's behest.""" -import contextlib import logging from collections.abc import Generator from time import perf_counter @@ -63,13 +62,13 @@ async def cog_unload(self) -> None: for dev_guild in self.dev_guilds: self.bot.tree.remove_command( self.block_add_ctx_menu.name, - type=self.block_add_ctx_menu.type, guild=dev_guild, + type=self.block_add_ctx_menu.type, ) self.bot.tree.remove_command( self.block_remove_ctx_menu.name, - type=self.block_remove_ctx_menu.type, guild=dev_guild, + type=self.block_remove_ctx_menu.type, ) async def cog_check(self, ctx: beira.Context) -> bool: # type: ignore # Narrowing, and async is allowed. @@ -84,17 +83,11 @@ async def block(self, ctx: beira.Context) -> None: By default, display the users and guilds that are blocked from using the bot. 
""" - users = self.bot.blocked_entities_cache["users"] - guilds = self.bot.blocked_entities_cache["guilds"] + users_descr = "\n".join(str(self.bot.get_user(u) or u) for u in self.bot.blocked_users) + users_embed = discord.Embed(title="Blocked Users", description=users_descr) - users_embed = discord.Embed( - title="Blocked Users", - description="\n".join(str(self.bot.get_user(u) or u) for u in users), - ) - guilds_embed = discord.Embed( - title="Blocked Guilds", - description="\n".join(str(self.bot.get_guild(g) or g) for g in guilds), - ) + guilds_descr = "\n".join(str(self.bot.get_guild(g) or g) for g in self.bot.blocked_guilds) + guilds_embed = discord.Embed(title="Blocked Guilds", description=guilds_descr) await ctx.send(embeds=[users_embed, guilds_embed]) @block.command("add") @@ -127,7 +120,7 @@ async def block_add( SET is_blocked = EXCLUDED.is_blocked; """ await ctx.db.executemany(stmt, [(user.id, True) for user in entities]) - self.bot.blocked_entities_cache["users"].update(user.id for user in entities) + self.bot.blocked_users.update(user.id for user in entities) embed = discord.Embed(title="Users", description="\n".join(str(user) for user in entities)) else: stmt = """\ @@ -138,7 +131,7 @@ async def block_add( SET is_blocked = EXCLUDED.is_blocked; """ await ctx.db.executemany(stmt, [(guild.id, True) for guild in entities]) - self.bot.blocked_entities_cache["guilds"].update(guild.id for guild in entities) + self.bot.blocked_guilds.update(guild.id for guild in entities) embed = discord.Embed(title="Guilds", description="\n".join(str(guild) for guild in entities)) # Display the results. @@ -174,8 +167,8 @@ async def block_remove( SET is_blocked = EXCLUDED.is_blocked; """ await ctx.db.executemany(stmt, [(user.id, False) for user in entities]) - self.bot.blocked_entities_cache["users"].difference_update(user.id for user in entities) - embed = discord.Embed(title="Users", description="\n".join(str(user) for user in entities)) + self.bot.blocked_users.difference_update(user.id for user in entities) + embed = discord.Embed(title="Users", description="\n".join(map(str, entities))) else: stmt = """\ INSERT INTO guilds (guild_id, is_blocked) @@ -185,8 +178,8 @@ async def block_remove( SET is_blocked = EXCLUDED.is_blocked; """ await ctx.db.executemany(stmt, [(guild.id, False) for guild in entities]) - self.bot.blocked_entities_cache["guilds"].difference_update(guild.id for guild in entities) - embed = discord.Embed(title="Guilds", description="\n".join(str(guild) for guild in entities)) + self.bot.blocked_guilds.difference_update(guild.id for guild in entities) + embed = discord.Embed(title="Guilds", description="\n".join(map(str, entities))) # Display the results. 
await ctx.send("Unblocked the following from bot usage:", embed=embed, ephemeral=True) @@ -201,7 +194,7 @@ async def block_change_error(self, ctx: beira.Context, error: commands.CommandEr if ctx.interaction: error = getattr(error, "original", error) - if isinstance(error, PostgresError | PostgresConnectionError): + if isinstance(error, (PostgresError, PostgresConnectionError)): action = "block" if ctx.command.qualified_name == "block add" else "unblock" await ctx.send(f"Unable to {action} these users/guilds at this time.", ephemeral=True) @@ -215,7 +208,7 @@ async def context_menu_block_add(self, interaction: beira.Interaction, user: dis SET is_blocked = EXCLUDED.is_blocked; """ await self.bot.db_pool.execute(stmt, user.id, True) - self.bot.blocked_entities_cache["users"].update((user.id,)) + self.bot.blocked_users.add(user.id) # Display the results. embed = discord.Embed(title="Users", description=str(user)) @@ -235,7 +228,7 @@ async def context_menu_block_remove( SET is_blocked = EXCLUDED.is_blocked; """ await self.bot.db_pool.execute(stmt, user.id, False) - self.bot.blocked_entities_cache["users"].difference_update((user.id,)) + self.bot.blocked_users.difference_update((user.id,)) # Display the results. embed = discord.Embed(title="Users", description=str(user)) @@ -258,17 +251,15 @@ async def walk(self, ctx: beira.Context) -> None: def create_walk_embed(title: str, cmds: list[app_commands.AppCommand]) -> None: """Creates an embed for global and guild command areas and adds it to a collection of embeds.""" - descr = "\n".join([f"**{cmd.mention}**\n" for cmd in cmds]) + descr = "\n\n".join(f"**{cmd.mention}**" for cmd in cmds) walk_embed = discord.Embed(color=0xCCCCCC, title=title, description=descr) all_embeds.append(walk_embed) - global_commands = await self.bot.tree.fetch_commands() - if global_commands: + if global_commands := await self.bot.tree.fetch_commands(): create_walk_embed("Global App Commands Registered", global_commands) for guild in self.bot.guilds: - guild_commands = await self.bot.tree.fetch_commands(guild=guild) - if guild_commands: + if guild_commands := await self.bot.tree.fetch_commands(guild=guild): create_walk_embed(f"Guild App Commands Registered - {guild}", guild_commands) await ctx.reply(embeds=all_embeds, ephemeral=True) @@ -307,7 +298,7 @@ async def load_ext_autocomplete(self, _: beira.Interaction, current: str) -> lis exts_to_load = set(EXTENSIONS).difference(set(self.bot.extensions), set(IGNORE_EXTENSIONS)) return [ - app_commands.Choice(name=ext.rsplit(".", 1)[1], value=ext) + app_commands.Choice(name=ext.rpartition(".")[2], value=ext) for ext in exts_to_load if current.lower() in ext.lower() ][:25] @@ -398,7 +389,7 @@ async def ext_autocomplete(self, _: beira.Interaction, current: str) -> list[app """Autocompletes names for currently loaded extensions.""" return [ - app_commands.Choice(name=ext.rsplit(".", 1)[1], value=ext) + app_commands.Choice(name=ext.rpartition(".")[2], value=ext) for ext in self.bot.extensions if current.lower() in ext.lower() ][:25] @@ -427,7 +418,7 @@ async def sync_( self, ctx: beira.Context, guilds: commands.Greedy[discord.Object] = None, # type: ignore # Can't be type-hinted as optional. - spec: app_commands.Choice[str] | None = None, + spec: str | None = None, ) -> None: """Syncs the command tree in some way based on input. 
@@ -529,7 +520,6 @@ async def sync_error(self, ctx: beira.Context, error: commands.CommandError) -> "Syncing the commands failed due to a user related error, typically because the command has invalid " "data. This is equivalent to an HTTP status code of 400." ) - LOGGER.error("", exc_info=error) elif isinstance(error, discord.Forbidden): embed.description = "The bot does not have the `applications.commands` scope in the guild." elif isinstance(error, app_commands.MissingApplicationID): @@ -540,26 +530,18 @@ async def sync_error(self, ctx: beira.Context, error: commands.CommandError) -> embed.description = "Generic HTTP error: Syncing the commands failed." else: embed.description = "Syncing the commands failed." - LOGGER.exception("Unknown error in sync command", exc_info=error) await ctx.reply(embed=embed) @commands.hybrid_command() async def cmd_tree(self, ctx: beira.Context) -> None: - indent_level = 0 - - @contextlib.contextmanager - def new_indent(num: int = 4) -> Generator[None, object, None]: - nonlocal indent_level - indent_level += num - try: - yield - finally: - indent_level -= num - - def walk_commands_with_indent(group: commands.GroupMixin[Any]) -> Generator[str, object, None]: + """Display all bot commands in a pretty tree-like format.""" + + def walk_commands_with_indent(group: commands.GroupMixin[Any], indent_level: int = 0) -> Generator[str]: + indent_level += 4 + for cmd in group.commands: - if indent_level != 0: # noqa: SIM108 + if indent_level > 4: # noqa: SIM108 indent = (indent_level - 1) * "─" else: indent = "" @@ -567,11 +549,9 @@ def walk_commands_with_indent(group: commands.GroupMixin[Any]) -> Generator[str, yield f"└{indent}{cmd.qualified_name}" if isinstance(cmd, commands.GroupMixin): - with new_indent(): - yield from walk_commands_with_indent(cmd) + yield from walk_commands_with_indent(cmd, indent_level) - result = "\n".join(["Beira", *walk_commands_with_indent(ctx.bot)]) - await ctx.send(f"```\n{result}\n```") + await ctx.send("\n".join(("```", "Beira", *walk_commands_with_indent(ctx.bot), "```"))) async def setup(bot: beira.Beira) -> None: diff --git a/src/beira/exts/_test.py b/src/beira/exts/_test.py index b9d1886..1d29280 100644 --- a/src/beira/exts/_test.py +++ b/src/beira/exts/_test.py @@ -1,5 +1,3 @@ -import logging - import discord from discord import app_commands from discord.ext import commands @@ -8,9 +6,6 @@ from beira.tree import after_app_invoke, before_app_invoke -LOGGER = logging.getLogger(__name__) - - async def example_before_hook(itx: discord.Interaction) -> None: await itx.response.defer() await itx.followup.send("In pre-command hook.") diff --git a/src/beira/exts/admin.py b/src/beira/exts/admin.py index f7715fe..9e5bf1b 100644 --- a/src/beira/exts/admin.py +++ b/src/beira/exts/admin.py @@ -2,8 +2,6 @@ owner's behest. 
""" -import logging - import discord from asyncpg import PostgresError, PostgresWarning from discord.ext import commands @@ -11,9 +9,6 @@ import beira -LOGGER = logging.getLogger(__name__) - - class AdminCog(commands.Cog, name="Administration"): """A cog for handling administrative tasks like adding and removing prefixes.""" diff --git a/src/beira/exts/bot_stats.py b/src/beira/exts/bot_stats.py index 968d555..d451bde 100644 --- a/src/beira/exts/bot_stats.py +++ b/src/beira/exts/bot_stats.py @@ -1,6 +1,5 @@ """A cog for tracking different bot metrics.""" -import logging from datetime import timedelta from typing import Literal @@ -13,9 +12,6 @@ from beira.utils import StatsEmbed -LOGGER = logging.getLogger(__name__) - - class CommandStatsSearchFlags(commands.FlagConverter): """A Discord commands flag converter for queries related to command usage stats.""" @@ -49,7 +45,7 @@ def cog_emoji(self) -> discord.PartialEmoji: async def track_command_use(self, ctx: beira.Context) -> None: """Stores records of command uses in the database after some processing.""" - assert ctx.command is not None + assert ctx.command db = self.bot.db_pool @@ -209,7 +205,7 @@ async def check_usage(self, ctx: beira.Context, *, search_factors: CommandStatsS "\N{FIRST PLACE MEDAL}", "\N{SECOND PLACE MEDAL}", "\N{THIRD PLACE MEDAL}", - *("\N{SPORTS MEDAL}" for _ in range(6)), + *(["\N{SPORTS MEDAL}"] * 6), ] embed.add_leaderboard_fields(ldbd_content=record_tuples, ldbd_emojis=ldbd_emojis) @@ -223,7 +219,7 @@ async def command_autocomplete(self, interaction: beira.Interaction, current: st """Autocompletes with bot command names.""" assert self.bot.help_command - ctx = await self.bot.get_context(interaction, cls=beira.Context) + ctx = await self.bot.get_context(interaction) help_command = self.bot.help_command.copy() help_command.context = ctx diff --git a/src/beira/exts/ff_metadata.py b/src/beira/exts/ff_metadata.py new file mode 100644 index 0000000..8146f88 --- /dev/null +++ b/src/beira/exts/ff_metadata.py @@ -0,0 +1,562 @@ +"""A cog with triggers for retrieving story metadata.""" +# TODO: Account for orphaned fics, anonymous fics, really long embed descriptions, and series with more than 25 fics. 
+
+import logging
+import re
+import textwrap
+from collections.abc import AsyncGenerator
+from typing import Any, Literal, NamedTuple
+
+import ao3
+import atlas_api
+import discord
+import fichub_api
+import lxml.html
+from discord.ext import commands
+
+import beira
+from beira.utils import PaginatedSelectView, html_to_markdown
+
+
+LOGGER = logging.getLogger(__name__)
+
+FANFICFINDER_ID = 779772534040166450
+
+type StoryDataType = atlas_api.Story | fichub_api.Story | ao3.Work | ao3.Series
+
+
+# region -------- Embed Helpers
+
+
+FFN_PATTERN = re.compile(r"(?:www\.|m\.|)fanfiction\.net/s/(?P<ffn_id>\d+)")
+FP_PATTERN = re.compile(r"(?:www\.|m\.|)fictionpress\.com/s/\d+")
+AO3_PATTERN = re.compile(r"(?:www\.|)archiveofourown\.org/(?P<type>works|series)/(?P<ao3_id>\d+)")
+SB_PATTERN = re.compile(r"forums\.spacebattles\.com/threads/\S*")
+SV_PATTERN = re.compile(r"forums\.sufficientvelocity\.com/threads/\S*")
+QQ_PATTERN = re.compile(r"forums\.questionablequesting\.com/threads/\S*")
+SIYE_PATTERN = re.compile(r"(?:www\.|)siye\.co\.uk/(?:siye/|)viewstory\.php\?sid=\d+")
+
+FFN_ICON = "https://www.fanfiction.net/static/icons3/ff-icon-128.png"
+FP_ICON = "https://www.fanfiction.net/static/icons3/ff-icon-128.png"
+AO3_ICON = ao3.utils.AO3_LOGO_URL
+SB_ICON = "https://forums.spacebattles.com/data/svg/2/1/1682578744/2022_favicon_192x192.png"
+SV_ICON = "https://forums.sufficientvelocity.com/favicon-96x96.png?v=69wyvmQdJN"
+QQ_ICON = "https://forums.questionablequesting.com/favicon.ico"
+SIYE_ICON = "https://www.siye.co.uk/siye/favicon.ico"
+
+
+class StoryWebsite(NamedTuple):
+    name: str
+    acronym: str
+    story_regex: re.Pattern[str]
+    icon_url: str
+
+
+STORY_WEBSITE_STORE: dict[str, StoryWebsite] = {
+    "FFN": StoryWebsite("FanFiction.Net", "FFN", FFN_PATTERN, FFN_ICON),
+    "FP": StoryWebsite("FictionPress", "FP", FP_PATTERN, FP_ICON),
+    "AO3": StoryWebsite("Archive of Our Own", "AO3", AO3_PATTERN, AO3_ICON),
+    "SB": StoryWebsite("SpaceBattles", "SB", SB_PATTERN, SB_ICON),
+    "SV": StoryWebsite("Sufficient Velocity", "SV", SV_PATTERN, SV_ICON),
+    "QQ": StoryWebsite("Questionable Questing", "QQ", QQ_PATTERN, QQ_ICON),
+    "SIYE": StoryWebsite("Sink Into Your Eyes", "SIYE", SIYE_PATTERN, SIYE_ICON),
+}
+
+STORY_WEBSITE_REGEX = re.compile(
+    r"(?:http://|https://|)"
+    + "|".join(f"(?P<{key}>{value.story_regex.pattern})" for key, value in STORY_WEBSITE_STORE.items()),
+)
+
+
+def create_ao3_work_embed(work: ao3.Work) -> discord.Embed:
+    """Create an embed that holds all the relevant metadata for an Archive of Our Own work."""
+
+    # Format the relevant information.
+    if work.date_updated:
+        updated = work.date_updated.strftime("%B %d, %Y") + (" (Complete)" if work.is_complete else "")
+    else:
+        updated = "Unknown"
+    author_names = ", ".join(str(author.name) for author in work.authors)
+    fandoms = textwrap.shorten(", ".join(work.fandoms), 100, placeholder="...")
+    categories = textwrap.shorten(", ".join(work.categories), 100, placeholder="...")
+    characters = textwrap.shorten(", ".join(work.characters), 100, placeholder="...")
+    details = " • ".join((fandoms, categories, characters))
+    stats_str = " • ".join(
+        (
+            f"**Comments:** {work.ncomments:,d}",
+            f"**Kudos:** {work.nkudos:,d}",
+            f"**Bookmarks:** {work.nbookmarks:,d}",
+            f"**Hits:** {work.nhits:,d}",
+        ),
+    )
+
+    # Add the info in the embed appropriately.
+ author_url = f"https://archiveofourown.org/users/{work.authors[0].name}" + ao3_embed = ( + discord.Embed(title=work.title, url=work.url, description=work.summary, timestamp=discord.utils.utcnow()) + .set_author(name=author_names, url=author_url, icon_url=STORY_WEBSITE_STORE["AO3"].icon_url) + .add_field(name="\N{SCROLL} Last Updated", value=f"{updated}") + .add_field(name="\N{OPEN BOOK} Length", value=f"{work.nwords:,d} words in {work.nchapters} chapter(s)") + .add_field(name=f"\N{BOOKMARK} Rating: {work.rating}", value=details, inline=False) + .add_field(name="\N{BAR CHART} Stats", value=stats_str, inline=False) + .set_footer(text="A substitute for displaying AO3 information.") + ) + + # Use the remaining space in the embed for the truncated description. + if len(ao3_embed) > 6000: + ao3_embed.description = work.summary[: 6000 - len(ao3_embed) - 3] + "..." + return ao3_embed + + +def create_ao3_series_embed(series: ao3.Series) -> discord.Embed: + """Create an embed that holds all the relevant metadata for an Archive of Our Own series.""" + + author_url = f"https://archiveofourown.org/users/{series.creators[0].name}" + + # Format the relevant information. + if series.date_updated: + updated = series.date_updated.strftime("%B %d, %Y") + (" (Complete)" if series.is_complete else "") + else: + updated = "Unknown" + author_names = ", ".join(name for creator in series.creators if (name := creator.name)) + work_links = "\N{BOOKS} **Works:**\n" + "\n".join(f"[{work.title}]({work.url})" for work in series.works_list) + + # Add the info in the embed appropriately. + ao3_embed = ( + discord.Embed(title=series.name, url=series.url, description=work_links, timestamp=discord.utils.utcnow()) + .set_author(name=author_names, url=author_url, icon_url=STORY_WEBSITE_STORE["AO3"].icon_url) + .add_field(name="\N{SCROLL} Last Updated", value=updated) + .add_field(name="\N{OPEN BOOK} Length", value=f"{series.nwords:,d} words in {series.nworks} work(s)") + .set_footer(text="A substitute for displaying AO3 information.") + ) + + # Use the remaining space in the embed for the truncated description. + if len(ao3_embed) > 6000: + series_descr = series.description[: 6000 - len(ao3_embed) - 5] + "...\n\n" + ao3_embed.description = series_descr + (ao3_embed.description or "") + return ao3_embed + + +def create_atlas_ffn_embed(story: atlas_api.Story) -> discord.Embed: + """Create an embed that holds all the relevant metadata for a FanFiction.Net story.""" + + # Format the relevant information. + update_date = story.updated if story.updated else story.published + updated = update_date.strftime("%B %d, %Y") + (" (Complete)" if story.is_complete else "") + fandoms = textwrap.shorten(", ".join(story.fandoms), 100, placeholder="...") + genres = textwrap.shorten("/".join(story.genres), 100, placeholder="...") + characters = textwrap.shorten(", ".join(story.characters), 100, placeholder="...") + details = " • ".join((fandoms, genres, characters)) + stats = f"**Reviews:** {story.reviews:,d} • **Faves:** {story.favorites:,d} • **Follows:** {story.follows:,d}" + + # Add the info to the embed appropriately. 
+ ffn_embed = ( + discord.Embed(title=story.title, url=story.url, description=story.description, timestamp=discord.utils.utcnow()) + .set_author(name=story.author.name, url=story.author.url, icon_url=STORY_WEBSITE_STORE["FFN"].icon_url) + .add_field(name="\N{SCROLL} Last Updated", value=updated) + .add_field(name="\N{OPEN BOOK} Length", value=f"{story.words:,d} words in {story.chapters} chapter(s)") + .add_field(name=f"\N{BOOKMARK} Rating: Fiction {story.rating}", value=details, inline=False) + .add_field(name="\N{BAR CHART} Stats", value=stats, inline=False) + .set_footer(text="Made using iris's Atlas API. Some results may be out of date or unavailable.") + ) + + # Use the remaining space in the embed for the truncated description. + if len(ffn_embed) > 6000: + ffn_embed.description = story.description[: 6000 - len(ffn_embed) - 3] + "..." + return ffn_embed + + +def create_fichub_embed(story: fichub_api.Story) -> discord.Embed: + """Create an embed that holds all the relevant metadata for a few different types of online fiction story.""" + + # Format the relevant information. + updated = story.updated.strftime("%B %d, %Y") + fandoms = textwrap.shorten(", ".join(story.fandoms), 100, placeholder="...") + categories_list = story.tags.category if isinstance(story, fichub_api.AO3Story) else () + categories = textwrap.shorten(", ".join(categories_list), 100, placeholder="...") + characters = textwrap.shorten(", ".join(story.characters), 100, placeholder="...") + details = " • ".join((fandoms, categories, characters)) + + # Get site-specific information, since FicHub works for multiple websites. + icon_url = next( + (value.icon_url for value in STORY_WEBSITE_STORE.values() if re.search(value.story_regex, story.url)), + None, + ) + + if isinstance(story, fichub_api.FFNStory): + stats_names = ("reviews", "favorites", "follows") + stats_str = " • ".join(f"**{name.capitalize()}:** {getattr(story.stats, name):,d}" for name in stats_names) + elif isinstance(story, fichub_api.AO3Story): + stats_names = ("comments", "kudos", "bookmarks", "hits") + stats_str = " • ".join(f"**{name.capitalize()}:** {getattr(story.stats, name):,d}" for name in stats_names) + else: + stats_str = "No stats available at this time." + + md_description = html_to_markdown(lxml.html.fromstring(story.description)) + + # Add the info to the embed appropriately. + story_embed = ( + discord.Embed(title=story.title, url=story.url, description=md_description, timestamp=discord.utils.utcnow()) + .set_author(name=story.author.name, url=story.author.url, icon_url=icon_url) + .add_field(name="\N{SCROLL} Last Updated", value=f"{updated} ({story.status.capitalize()})") + .add_field(name="\N{OPEN BOOK} Length", value=f"{story.words:,d} words in {story.chapters} chapter(s)") + .add_field(name=f"\N{BOOKMARK} Rating: {story.rating}", value=details, inline=False) + .add_field(name="\N{BAR CHART} Stats", value=stats_str, inline=False) + .set_footer(text="Made using the FicHub API. Some results may be out of date or unavailable.") + ) + + # Use the remaining space in the embed for the truncated description. + if len(story_embed) > 6000: + story_embed.description = md_description[: 6000 - len(story_embed) - 3] + "..." 
+ return story_embed + + +def ff_embed_factory(story_data: Any | None) -> discord.Embed | None: + match story_data: + case atlas_api.Story(): + return create_atlas_ffn_embed(story_data) + case fichub_api.AO3Story() | fichub_api.FFNStory() | fichub_api.OtherStory(): + return create_fichub_embed(story_data) + case ao3.Work(): + return create_ao3_work_embed(story_data) + case ao3.Series(): + return create_ao3_series_embed(story_data) + case _: + return None + + +class AO3SeriesView(PaginatedSelectView[tuple[ao3.Work, ...]]): + """A view that wraps a dropdown item for AO3 works. + + Parameters + ---------- + author_id: int + The Discord ID of the user that triggered this view. No one else can use it. + series: ao3.Series + The object holding metadata about an AO3 series and the works within. + timeout: float | None, optional + Timeout in seconds from last interaction with the UI before no longer accepting input. + If ``None`` then there is no timeout. + + Attributes + ---------- + series: ao3.Series + The object holding metadata about an AO3 series and the works within. + """ + + def __init__(self, author_id: int, series: ao3.Series, *, timeout: float | None = 180) -> None: + self.series = series + super().__init__(author_id, series.works_list, timeout=timeout) + + async def on_timeout(self) -> None: + """Disables all items on timeout.""" + + for item in self.children: + item.disabled = True # type: ignore + + await self.message.edit(view=self) + self.stop() + + def populate_select(self) -> None: + self.select_page.placeholder = "Choose the work here..." + descr = textwrap.shorten(self.series.description, 100, placeholder="...") + self.select_page.add_option(label=self.series.name, value="0", description=descr, emoji="\N{BOOKS}") + + for i, work in enumerate(self.pages, start=1): + descr = textwrap.shorten(work.summary, 100, placeholder="...") + self.select_page.add_option( + label=f"{i}. {work.title}", + value=str(i), + description=descr, + emoji="\N{OPEN BOOK}", + ) + + def format_page(self) -> discord.Embed: + """Makes the series/work 'page' that the user will see.""" + + if self.page_index != 0: + embed_page = create_ao3_work_embed(self.pages[self.page_index - 1]) + else: + embed_page = create_ao3_series_embed(self.series) + return embed_page + + +# endregion + + +class FFMetadataCog(commands.GroupCog, name="Fanfiction Metadata Search", group_name="ff"): + """A cog with triggers and commands for retrieving story metadata.""" + + def __init__(self, bot: beira.Beira) -> None: + self.bot = bot + + self.atlas_client = bot.atlas_client + self.fichub_client = bot.fichub_client + self.ao3_client = bot.ao3_client + self.aci100_guild_id: int = bot.config.discord.important_guilds["prod"][0] + self.allowed_channels_cache: dict[int, set[int]] = {} + + @property + def cog_emoji(self) -> discord.PartialEmoji: + """`discord.PartialEmoji`: A partial emoji representing this cog.""" + + return discord.PartialEmoji(name="\N{BAR CHART}") + + async def cog_load(self) -> None: + # FIXME: Set up logging into AO3 via ao3.py. + # Load a cache of channels to auto-respond in. + records = await self.bot.db_pool.fetch("SELECT guild_id, channel_id FROM fanfic_autoresponse_settings;") + for record in records: + self.allowed_channels_cache.setdefault(record["guild_id"], set()).add(record["channel_id"]) + + @commands.Cog.listener("on_message") + async def on_posted_fanfic_link(self, message: discord.Message) -> None: + """Send informational embeds about a story if the user sends a fanfiction link. 
+ + Must be triggered in an allowed channel. + """ + + if (message.author == self.bot.user) or (not message.guild) or message.guild.id == self.aci100_guild_id: + return + + # Listen to the allowed channels in the allowed guilds for valid fanfic links. + if ( + (channels_cache := self.allowed_channels_cache.get(message.guild.id, set())) + and (message.channel.id in channels_cache) + and re.search(STORY_WEBSITE_REGEX, message.content) + ): + # Only show typing indicator on valid messages. + async with message.channel.typing(): + # Send an embed for every valid link. + async for story_data in self.get_ff_data_from_links(message.content, message.guild.id): + if story_data is not None and (embed := ff_embed_factory(story_data)): + await message.channel.send(embed=embed) + + @commands.Cog.listener("on_message") + async def on_fanficfinder_nothing_found_message(self, message: discord.Message) -> None: + # Listen to the allowed channels in the allowed guilds. + + if bool( + message.guild + and (message.guild.id == self.aci100_guild_id) + and (message.author.id == FANFICFINDER_ID) + and message.embeds + and (embed := message.embeds[0]) + and embed.description is not None + and "fanfiction not found" in embed.description.lower(), + ): + await message.delete() + + @commands.hybrid_group(fallback="get") + @commands.guild_only() + async def autoresponse(self, ctx: beira.GuildContext) -> None: + """Autoresponse-related commands for automatically responding to fanfiction links in certain channels. + + By default, display the channels in the server set to autorespond. + """ + + async with ctx.typing(): + embed = discord.Embed( + title="Autoresponse Channels for Fanfic Links", + description="\n".join( + f"<#{channel}>" for channel in self.allowed_channels_cache.get(ctx.guild.id, set()) + ), + ) + await ctx.send(embed=embed) + + @autoresponse.command("add") + async def autoresponse_add( + self, + ctx: beira.GuildContext, + *, + channels: commands.Greedy[discord.abc.GuildChannel], + ) -> None: + """Set the bot to listen for AO3/FFN/other ff site links posted in the given channels. + + If allowed, the bot will respond automatically with an informational embed. + + Parameters + ---------- + ctx: `beira.GuildContext` + The invocation context. + channels: `commands.Greedy[discord.abc.GuildChannel]` + A list of channels to add, separated by spaces. + """ + + async with ctx.typing(): + # Update the database. + async with self.bot.db_pool.acquire() as conn: + stmt = """\ + INSERT INTO fanfic_autoresponse_settings (guild_id, channel_id) + VALUES ($1, $2) + ON CONFLICT (guild_id, channel_id) DO NOTHING; + """ + await conn.executemany(stmt, [(ctx.guild.id, channel.id) for channel in channels]) + + query = "SELECT channel_id FROM fanfic_autoresponse_settings WHERE guild_id = $1;" + records = await conn.fetch(query, ctx.guild.id) + + # Update the cache. + channel_ids: list[int] = [record[0] for record in records] + self.allowed_channels_cache.setdefault(ctx.guild.id, set()).update(channel_ids) + embed = discord.Embed( + title="Adjusted Autoresponse Channels for Fanfic Links", + description="\n".join(f"<#{id_}>" for id_ in channel_ids), + ) + await ctx.send(embed=embed) + + @autoresponse.command("remove") + async def autoresponse_remove( + self, + ctx: beira.GuildContext, + *, + channels: commands.Greedy[discord.abc.GuildChannel], + ) -> None: + """Set the bot to not listen for AO3/FFN/other ff site links posted in the given channels. + + The bot will no longer automatically respond to links with information embeds. 
+ + Parameters + ---------- + ctx: `beira.GuildContext` + The invocation context. + channels: `commands.Greedy[discord.abc.GuildChannel]` + A list of channels to remove, separated by spaces. + """ + + async with ctx.typing(): + # Update the database. + async with self.bot.db_pool.acquire() as con: + stmt = "DELETE FROM fanfic_autoresponse_settings WHERE channel_id = $1;" + await con.executemany(stmt, [(channel.id,) for channel in channels]) + + query = "SELECT channel_id FROM fanfic_autoresponse_settings WHERE guild_id = $1;" + records = await con.fetch(query, ctx.guild.id) + + # Update the cache. + self.allowed_channels_cache.setdefault(ctx.guild.id, set()).intersection_update( + record["channel_id"] for record in records + ) + embed = discord.Embed( + title="Adjusted Autoresponse Channels for Fanfic Links", + description="\n".join(f"<#{record[0]}>" for record in records), + ) + await ctx.send(embed=embed) + + @commands.hybrid_command() + async def ff_search( + self, + ctx: beira.Context, + platform: Literal["ao3", "ffn", "other"], + *, + name_or_url: str, + ) -> None: + """Search available platforms for a fic with a certain title or url. + + Note: Only urls are accepted for `other`. + + Parameters + ---------- + ctx: `beira.Context` + The invocation context. + platform: `Literal["ao3", "ffn", "other"]` + The platform to search. + name_or_url: `str` + The search string for the story title, or the story url. + """ + + async with ctx.typing(): + if platform == "ao3": + story_data = await self.search_ao3(name_or_url) + elif platform == "ffn": + story_data = await self.search_ffn(name_or_url) + else: + story_data = await self.search_other(name_or_url) + + embed = ff_embed_factory(story_data) + if embed is None: + embed = discord.Embed( + title="No Results", + description="No results found. You may need to edit your search.", + timestamp=discord.utils.utcnow(), + ) + + if isinstance(story_data, ao3.Series): + view = AO3SeriesView(ctx.author.id, story_data) + view.message = await ctx.send(embed=embed, view=view) + else: + await ctx.send(embed=embed) + + async def search_ao3(self, name_or_url: str) -> ao3.Work | ao3.Series | fichub_api.Story | None: + """More generically search AO3 for works based on a partial title or full url.""" + + if match := re.search(STORY_WEBSITE_STORE["AO3"].story_regex, name_or_url): + if match.group("type") == "series": + try: + series_id = match.group("ao3_id") + story_data = await self.ao3_client.get_series(int(series_id)) + except ao3.AO3Exception: + LOGGER.exception("") + story_data = None + else: + try: + url = match.group(0) + story_data = await self.fichub_client.get_story_metadata(url) + except fichub_api.FicHubException as err: + msg = "Retrieval with Fichub client failed. Trying the AO3 scraping library now." + LOGGER.warning(msg, exc_info=err) + try: + work_id = match.group("ao3_id") + story_data = await self.ao3_client.get_work(int(work_id)) + except ao3.AO3Exception as err: + msg = "Retrieval with Fichub client and AO3 scraping library failed. Returning None." 
+ LOGGER.warning(msg, exc_info=err) + story_data = None + else: + search_options = ao3.WorkSearchOptions(any_field=name_or_url) + search = await self.ao3_client.search_works(search_options) + story_data = results[0] if (results := search.results) else None + + return story_data + + async def search_ffn(self, name_or_url: str) -> atlas_api.Story | fichub_api.Story | None: + """More generically search FFN for works based on a partial title or full url.""" + + if fic_id := atlas_api.extract_fic_id(name_or_url): + try: + story_data = await self.atlas_client.get_story_metadata(fic_id) + except atlas_api.AtlasException as err: + msg = "Retrieval with Atlas client failed. Trying FicHub now." + LOGGER.warning(msg, exc_info=err) + try: + story_data = await self.fichub_client.get_story_metadata(name_or_url) + except fichub_api.FicHubException as err: + msg = "Retrieval with Atlas and Fichub clients failed. Returning None." + LOGGER.warning(msg, exc_info=err) + story_data = None + else: + results = await self.atlas_client.get_bulk_metadata(title_ilike=f"%{name_or_url}%", limit=1) + story_data = results[0] if results else None + + return story_data + + async def search_other(self, url: str) -> fichub_api.Story | None: + """More generically search for the metadata of other works based on a full url.""" + + return await self.fichub_client.get_story_metadata(url) + + async def get_ff_data_from_links(self, text: str, guild_id: int) -> AsyncGenerator[StoryDataType | None]: + for match_obj in re.finditer(STORY_WEBSITE_REGEX, text): + # Attempt to get the story data from whatever method. + if match_obj.lastgroup == "FFN": + yield await self.atlas_client.get_story_metadata(int(match_obj.group("ffn_id"))) + elif match_obj.lastgroup == "AO3": + yield await self.search_ao3(match_obj.group(0)) + elif match_obj.lastgroup: + yield await self.search_other(match_obj.group(0)) + else: + yield None + + +async def setup(bot: beira.Beira) -> None: + await bot.add_cog(FFMetadataCog(bot)) diff --git a/src/beira/exts/ff_metadata/__init__.py b/src/beira/exts/ff_metadata/__init__.py deleted file mode 100644 index 77459fc..0000000 --- a/src/beira/exts/ff_metadata/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -import beira - -from .ff_metadata import FFMetadataCog - - -async def setup(bot: beira.Beira) -> None: - await bot.add_cog(FFMetadataCog(bot)) diff --git a/src/beira/exts/ff_metadata/ff_metadata.py b/src/beira/exts/ff_metadata/ff_metadata.py deleted file mode 100644 index 61d59eb..0000000 --- a/src/beira/exts/ff_metadata/ff_metadata.py +++ /dev/null @@ -1,305 +0,0 @@ -"""A cog with triggers for retrieving story metadata.""" -# TODO: Account for orphaned fics, anonymous fics, really long embed descriptions, and series with more than 25 fics. 
- -import logging -import re -from collections.abc import AsyncGenerator -from typing import Literal - -import ao3 -import atlas_api -import discord -import fichub_api -from discord.ext import commands - -import beira - -from .utils import ( - STORY_WEBSITE_REGEX, - STORY_WEBSITE_STORE, - AO3SeriesView, - ff_embed_factory, -) - - -type StoryDataType = atlas_api.Story | fichub_api.Story | ao3.Work | ao3.Series - - -LOGGER = logging.getLogger(__name__) -FANFICFINDER_ID = 779772534040166450 - - -class FFMetadataCog(commands.GroupCog, name="Fanfiction Metadata Search", group_name="ff"): - """A cog with triggers and commands for retrieving story metadata.""" - - def __init__(self, bot: beira.Beira) -> None: - self.bot = bot - - self.atlas_client = bot.atlas_client - self.fichub_client = bot.fichub_client - self.ao3_client = bot.ao3_client - self.aci100_guild_id: int = bot.config.discord.important_guilds["prod"][0] - self.allowed_channels_cache: dict[int, set[int]] = {} - - @property - def cog_emoji(self) -> discord.PartialEmoji: - """`discord.PartialEmoji`: A partial emoji representing this cog.""" - - return discord.PartialEmoji(name="\N{BAR CHART}") - - async def cog_load(self) -> None: - # FIXME: Setup logging into AO3 via ao3.py. - # Load a cache of channels to auto-respond in. - records = await self.bot.db_pool.fetch("SELECT guild_id, channel_id FROM fanfic_autoresponse_settings;") - for record in records: - self.allowed_channels_cache.setdefault(record["guild_id"], set()).add(record["channel_id"]) - - @commands.Cog.listener("on_message") - async def on_posted_fanfic_link(self, message: discord.Message) -> None: - """Send informational embeds about a story if the user sends a fanfiction link. - - Must be triggered in an allowed channel. - """ - - if (message.author == self.bot.user) or (not message.guild) or message.guild.id == self.aci100_guild_id: - return - - # Listen to the allowed channels in the allowed guilds for valid fanfic links. - if ( - (channels_cache := self.allowed_channels_cache.get(message.guild.id, set())) - and (message.channel.id in channels_cache) - and re.search(STORY_WEBSITE_REGEX, message.content) - ): - # Only show typing indicator on valid messages. - async with message.channel.typing(): - # Send an embed for every valid link. - async for story_data in self.get_ff_data_from_links(message.content, message.guild.id): - if story_data is not None: - embed = ff_embed_factory(story_data) - if embed is not None: - await message.channel.send(embed=embed) - - @commands.Cog.listener("on_message") - async def on_fanficfinder_nothing_found_message(self, message: discord.Message) -> None: - # Listen to the allowed channels in the allowed guilds. - - if bool( - message.guild - and (message.guild.id == self.aci100_guild_id) - and (message.author.id == FANFICFINDER_ID) - and message.embeds - and (embed := message.embeds[0]) - and embed.description is not None - and "fanfiction not found" in embed.description.lower(), - ): - await message.delete() - - @commands.hybrid_group(fallback="get") - @commands.guild_only() - async def autoresponse(self, ctx: beira.GuildContext) -> None: - """Autoresponse-related commands for automatically responding to fanfiction links in certain channels. - - By default, display the channels in the server set to autorespond. 
- """ - - async with ctx.typing(): - embed = discord.Embed( - title="Autoresponse Channels for Fanfic Links", - description="\n".join( - f"<#{channel}>" for channel in self.allowed_channels_cache.get(ctx.guild.id, set()) - ), - ) - await ctx.send(embed=embed) - - @autoresponse.command("add") - async def autoresponse_add( - self, - ctx: beira.GuildContext, - *, - channels: commands.Greedy[discord.abc.GuildChannel], - ) -> None: - """Set the bot to listen for AO3/FFN/other ff site links posted in the given channels. - - If allowed, the bot will respond automatically with an informational embed. - - Parameters - ---------- - ctx: `beira.GuildContext` - The invocation context. - channels: `commands.Greedy[discord.abc.GuildChannel]` - A list of channels to add, separated by spaces. - """ - - async with ctx.typing(): - # Update the database. - async with self.bot.db_pool.acquire() as conn: - stmt = """\ - INSERT INTO fanfic_autoresponse_settings (guild_id, channel_id) - VALUES ($1, $2) - ON CONFLICT (guild_id, channel_id) DO NOTHING; - """ - await conn.executemany(stmt, [(ctx.guild.id, channel.id) for channel in channels]) - - query = "SELECT channel_id FROM fanfic_autoresponse_settings WHERE guild_id = $1;" - records = await conn.fetch(query, ctx.guild.id) - - # Update the cache. - self.allowed_channels_cache.setdefault(ctx.guild.id, set()).update(record[0] for record in records) - embed = discord.Embed( - title="Adjusted Autoresponse Channels for Fanfic Links", - description="\n".join(f"<#{record[0]}>" for record in records), - ) - await ctx.send(embed=embed) - - @autoresponse.command("remove") - async def autoresponse_remove( - self, - ctx: beira.GuildContext, - *, - channels: commands.Greedy[discord.abc.GuildChannel], - ) -> None: - """Set the bot to not listen for AO3/FFN/other ff site links posted in the given channels. - - The bot will no longer automatically respond to links with information embeds. - - Parameters - ---------- - ctx: `beira.GuildContext` - The invocation context. - channels: `commands.Greedy[discord.abc.GuildChannel]` - A list of channels to remove, separated by spaces. - """ - - async with ctx.typing(): - # Update the database. - async with self.bot.db_pool.acquire() as con: - stmt = "DELETE FROM fanfic_autoresponse_settings WHERE channel_id = $1;" - await con.executemany(stmt, [(channel.id,) for channel in channels]) - - query = "SELECT channel_id FROM fanfic_autoresponse_settings WHERE guild_id = $1;" - records = await con.fetch(query, ctx.guild.id) - - # Update the cache. - self.allowed_channels_cache.setdefault(ctx.guild.id, set()).intersection_update( - record["channel_id"] for record in records - ) - embed = discord.Embed( - title="Adjusted Autoresponse Channels for Fanfic Links", - description="\n".join(f"<#{record[0]}>" for record in records), - ) - await ctx.send(embed=embed) - - @commands.hybrid_command() - async def ff_search( - self, - ctx: beira.Context, - platform: Literal["ao3", "ffn", "other"], - *, - name_or_url: str, - ) -> None: - """Search available platforms for a fic with a certain title or url. - - Note: Only urls are accepted for `other`. - - Parameters - ---------- - ctx: `beira.Context` - The invocation context. - platform: `Literal["ao3", "ffn", "other"]` - The platform to search. - name_or_url: `str` - The search string for the story title, or the story url. 
- """ - - async with ctx.typing(): - if platform == "ao3": - story_data = await self.search_ao3(name_or_url) - elif platform == "ffn": - story_data = await self.search_ffn(name_or_url) - else: - story_data = await self.search_other(name_or_url) - - embed = ff_embed_factory(story_data) - if embed is None: - embed = discord.Embed( - title="No Results", - description="No results found. You may need to edit your search.", - timestamp=discord.utils.utcnow(), - ) - - if isinstance(story_data, ao3.Series): - view = AO3SeriesView(ctx.author.id, story_data) - view.message = await ctx.send(embed=embed, view=view) - else: - await ctx.send(embed=embed) - - async def search_ao3(self, name_or_url: str) -> ao3.Work | ao3.Series | fichub_api.Story | None: - """More generically search AO3 for works based on a partial title or full url.""" - - if match := re.search(STORY_WEBSITE_STORE["AO3"].story_regex, name_or_url): - if match.group("type") == "series": - try: - series_id = match.group("ao3_id") - story_data = await self.ao3_client.get_series(int(series_id)) - except ao3.AO3Exception: - LOGGER.exception("") - story_data = None - else: - try: - url = match.group(0) - story_data = await self.fichub_client.get_story_metadata(url) - except fichub_api.FicHubException as err: - msg = "Retrieval with Fichub client failed. Trying the AO3 scraping library now." - LOGGER.warning(msg, exc_info=err) - try: - work_id = match.group("ao3_id") - story_data = await self.ao3_client.get_work(int(work_id)) - except ao3.AO3Exception as err: - msg = "Retrieval with Fichub client and AO3 scraping library failed. Returning None." - LOGGER.warning(msg, exc_info=err) - story_data = None - else: - search_options = ao3.WorkSearchOptions(any_field=name_or_url) - search = await self.ao3_client.search_works(search_options) - story_data = results[0] if (results := search.results) else None - - return story_data - - async def search_ffn(self, name_or_url: str) -> atlas_api.Story | fichub_api.Story | None: - """More generically search FFN for works based on a partial title or full url.""" - - if fic_id := atlas_api.extract_fic_id(name_or_url): - try: - story_data = await self.atlas_client.get_story_metadata(fic_id) - except atlas_api.AtlasException as err: - msg = "Retrieval with Atlas client failed. Trying FicHub now." - LOGGER.warning(msg, exc_info=err) - try: - story_data = await self.fichub_client.get_story_metadata(name_or_url) - except fichub_api.FicHubException as err: - msg = "Retrieval with Atlas and Fichub clients failed. Returning None." - LOGGER.warning(msg, exc_info=err) - story_data = None - else: - results = await self.atlas_client.get_bulk_metadata(title_ilike=f"%{name_or_url}%", limit=1) - story_data = results[0] if results else None - - return story_data - - async def search_other(self, url: str) -> fichub_api.Story | None: - """More generically search for the metadata of other works based on a full url.""" - - return await self.fichub_client.get_story_metadata(url) - - async def get_ff_data_from_links(self, text: str, guild_id: int) -> AsyncGenerator[StoryDataType | None, None]: - for match_obj in re.finditer(STORY_WEBSITE_REGEX, text): - # Attempt to get the story data from whatever method. 
-            if match_obj.lastgroup == "FFN":
-                story_data = await self.atlas_client.get_story_metadata(int(match_obj.group("ffn_id")))
-            elif match_obj.lastgroup == "AO3":
-                story_data = await self.search_ao3(match_obj.group(0))
-            elif match_obj.lastgroup and (match_obj.lastgroup != "AO3"):
-                story_data = await self.search_other(match_obj.group(0))
-            else:
-                story_data = None
-            yield story_data
diff --git a/src/beira/exts/ff_metadata/utils.py b/src/beira/exts/ff_metadata/utils.py
deleted file mode 100644
index c4fa2c1..0000000
--- a/src/beira/exts/ff_metadata/utils.py
+++ /dev/null
@@ -1,275 +0,0 @@
-import re
-import textwrap
-from typing import Any, NamedTuple
-
-import ao3
-import atlas_api
-import discord
-import fichub_api
-import lxml.html
-
-from beira.utils import PaginatedSelectView, html_to_markdown
-
-
-__all__ = (
-    "STORY_WEBSITE_STORE",
-    "STORY_WEBSITE_REGEX",
-    "create_ao3_work_embed",
-    "create_ao3_series_embed",
-    "create_fichub_embed",
-    "create_atlas_ffn_embed",
-    "ff_embed_factory",
-    "AO3SeriesView",
-)
-
-FFN_PATTERN = re.compile(r"(?:www\.|m\.|)fanfiction\.net/s/(?P<ffn_id>\d+)")
-FP_PATTERN = re.compile(r"(?:www\.|m\.|)fictionpress\.com/s/\d+")
-AO3_PATTERN = re.compile(r"(?:www\.|)archiveofourown\.org/(?P<type>works|series)/(?P<ao3_id>\d+)")
-SB_PATTERN = re.compile(r"forums\.spacebattles\.com/threads/\S*")
-SV_PATTERN = re.compile(r"forums\.sufficientvelocity\.com/threads/\S*")
-QQ_PATTERN = re.compile(r"forums\.questionablequesting\.com/threads/\S*")
-SIYE_PATTERN = re.compile(r"(?:www\.|)siye\.co\.uk/(?:siye/|)viewstory\.php\?sid=\d+")
-
-FFN_ICON = "https://www.fanfiction.net/static/icons3/ff-icon-128.png"
-FP_ICON = "https://www.fanfiction.net/static/icons3/ff-icon-128.png"
-AO3_ICON = ao3.utils.AO3_LOGO_URL
-SB_ICON = "https://forums.spacebattles.com/data/svg/2/1/1682578744/2022_favicon_192x192.png"
-SV_ICON = "https://forums.sufficientvelocity.com/favicon-96x96.png?v=69wyvmQdJN"
-QQ_ICON = "https://forums.questionablequesting.com/favicon.ico"
-SIYE_ICON = "https://www.siye.co.uk/siye/favicon.ico"
-
-
-class StoryWebsite(NamedTuple):
-    name: str
-    acronym: str
-    story_regex: re.Pattern[str]
-    icon_url: str
-
-
-STORY_WEBSITE_STORE: dict[str, StoryWebsite] = {
-    "FFN": StoryWebsite("FanFiction.Net", "FFN", FFN_PATTERN, FFN_ICON),
-    "FP": StoryWebsite("FictionPress", "FP", FP_PATTERN, FP_ICON),
-    "AO3": StoryWebsite("Archive of Our Own", "AO3", AO3_PATTERN, AO3_ICON),
-    "SB": StoryWebsite("SpaceBattles", "SB", SB_PATTERN, SB_ICON),
-    "SV": StoryWebsite("Sufficient Velocity", "SV", SV_PATTERN, SV_ICON),
-    "QQ": StoryWebsite("Questionable Questing", "QQ", QQ_PATTERN, QQ_ICON),
-    "SIYE": StoryWebsite("Sink Into Your Eyes", "SIYE", SIYE_PATTERN, SIYE_ICON),
-}
-
-STORY_WEBSITE_REGEX = re.compile(
-    r"(?:http://|https://|)"
-    + "|".join(f"(?P<{key}>{value.story_regex.pattern})" for key, value in STORY_WEBSITE_STORE.items()),
-)
-
-
-def create_ao3_work_embed(work: ao3.Work) -> discord.Embed:
-    """Create an embed that holds all the relevant metadata for an Archive of Our Own work."""
-
-    # Format the relevant information.
- if work.date_updated: - updated = work.date_updated.strftime("%B %d, %Y") + (" (Complete)" if work.is_complete else "") - else: - updated = "Unknown" - author_names = ", ".join(str(author.name) for author in work.authors) - fandoms = textwrap.shorten(", ".join(work.fandoms), 100, placeholder="...") - categories = textwrap.shorten(", ".join(work.categories), 100, placeholder="...") - characters = textwrap.shorten(", ".join(work.characters), 100, placeholder="...") - details = " • ".join((fandoms, categories, characters)) - stats_str = " • ".join( - ( - f"**Comments:** {work.ncomments:,d}", - f"**Kudos:** {work.nkudos:,d}", - f"**Bookmarks:** {work.nbookmarks:,d}", - f"**Hits:** {work.nhits:,d}", - ), - ) - - # Add the info in the embed appropriately. - author_url = f"https://archiveofourown.org/users/{work.authors[0].name}" - ao3_embed = ( - discord.Embed(title=work.title, url=work.url, description=work.summary, timestamp=discord.utils.utcnow()) - .set_author(name=author_names, url=author_url, icon_url=STORY_WEBSITE_STORE["AO3"].icon_url) - .add_field(name="\N{SCROLL} Last Updated", value=f"{updated}") - .add_field(name="\N{OPEN BOOK} Length", value=f"{work.nwords:,d} words in {work.nchapters} chapter(s)") - .add_field(name=f"\N{BOOKMARK} Rating: {work.rating}", value=details, inline=False) - .add_field(name="\N{BAR CHART} Stats", value=stats_str, inline=False) - .set_footer(text="A substitute for displaying AO3 information.") - ) - - # Use the remaining space in the embed for the truncated description. - if len(ao3_embed) > 6000: - ao3_embed.description = work.summary[: 6000 - len(ao3_embed) - 3] + "..." - return ao3_embed - - -def create_ao3_series_embed(series: ao3.Series) -> discord.Embed: - """Create an embed that holds all the relevant metadata for an Archive of Our Own series.""" - - author_url = f"https://archiveofourown.org/users/{series.creators[0].name}" - - # Format the relevant information. - if series.date_updated: - updated = series.date_updated.strftime("%B %d, %Y") + (" (Complete)" if series.is_complete else "") - else: - updated = "Unknown" - author_names = ", ".join(name for creator in series.creators if (name := creator.name)) - work_links = "\N{BOOKS} **Works:**\n" + "\n".join(f"[{work.title}]({work.url})" for work in series.works_list) - - # Add the info in the embed appropriately. - ao3_embed = ( - discord.Embed(title=series.name, url=series.url, description=work_links, timestamp=discord.utils.utcnow()) - .set_author(name=author_names, url=author_url, icon_url=STORY_WEBSITE_STORE["AO3"].icon_url) - .add_field(name="\N{SCROLL} Last Updated", value=updated) - .add_field(name="\N{OPEN BOOK} Length", value=f"{series.nwords:,d} words in {series.nworks} work(s)") - .set_footer(text="A substitute for displaying AO3 information.") - ) - - # Use the remaining space in the embed for the truncated description. - if len(ao3_embed) > 6000: - series_descr = series.description[: 6000 - len(ao3_embed) - 5] + "...\n\n" - ao3_embed.description = series_descr + (ao3_embed.description or "") - return ao3_embed - - -def create_atlas_ffn_embed(story: atlas_api.Story) -> discord.Embed: - """Create an embed that holds all the relevant metadata for a FanFiction.Net story.""" - - # Format the relevant information. 
- update_date = story.updated if story.updated else story.published - updated = update_date.strftime("%B %d, %Y") + (" (Complete)" if story.is_complete else "") - fandoms = textwrap.shorten(", ".join(story.fandoms), 100, placeholder="...") - genres = textwrap.shorten("/".join(story.genres), 100, placeholder="...") - characters = textwrap.shorten(", ".join(story.characters), 100, placeholder="...") - details = " • ".join((fandoms, genres, characters)) - stats = f"**Reviews:** {story.reviews:,d} • **Faves:** {story.favorites:,d} • **Follows:** {story.follows:,d}" - - # Add the info to the embed appropriately. - ffn_embed = ( - discord.Embed(title=story.title, url=story.url, description=story.description, timestamp=discord.utils.utcnow()) - .set_author(name=story.author.name, url=story.author.url, icon_url=STORY_WEBSITE_STORE["FFN"].icon_url) - .add_field(name="\N{SCROLL} Last Updated", value=updated) - .add_field(name="\N{OPEN BOOK} Length", value=f"{story.words:,d} words in {story.chapters} chapter(s)") - .add_field(name=f"\N{BOOKMARK} Rating: Fiction {story.rating}", value=details, inline=False) - .add_field(name="\N{BAR CHART} Stats", value=stats, inline=False) - .set_footer(text="Made using iris's Atlas API. Some results may be out of date or unavailable.") - ) - - # Use the remaining space in the embed for the truncated description. - if len(ffn_embed) > 6000: - ffn_embed.description = story.description[: 6000 - len(ffn_embed) - 3] + "..." - return ffn_embed - - -def create_fichub_embed(story: fichub_api.Story) -> discord.Embed: - """Create an embed that holds all the relevant metadata for a few different types of online fiction story.""" - - # Format the relevant information. - updated = story.updated.strftime("%B %d, %Y") - fandoms = textwrap.shorten(", ".join(story.fandoms), 100, placeholder="...") - categories_list = story.tags.category if isinstance(story, fichub_api.AO3Story) else () - categories = textwrap.shorten(", ".join(categories_list), 100, placeholder="...") - characters = textwrap.shorten(", ".join(story.characters), 100, placeholder="...") - details = " • ".join((fandoms, categories, characters)) - - # Get site-specific information, since FicHub works for multiple websites. - icon_url = next( - (value.icon_url for value in STORY_WEBSITE_STORE.values() if re.search(value.story_regex, story.url)), - None, - ) - - if isinstance(story, fichub_api.FFNStory): - stats_names = ("reviews", "favorites", "follows") - stats_str = " • ".join(f"**{name.capitalize()}:** {getattr(story.stats, name):,d}" for name in stats_names) - elif isinstance(story, fichub_api.AO3Story): - stats_names = ("comments", "kudos", "bookmarks", "hits") - stats_str = " • ".join(f"**{name.capitalize()}:** {getattr(story.stats, name):,d}" for name in stats_names) - else: - stats_str = "No stats available at this time." - - md_description = html_to_markdown(lxml.html.fromstring(story.description)) - - # Add the info to the embed appropriately. 
- story_embed = ( - discord.Embed(title=story.title, url=story.url, description=md_description, timestamp=discord.utils.utcnow()) - .set_author(name=story.author.name, url=story.author.url, icon_url=icon_url) - .add_field(name="\N{SCROLL} Last Updated", value=f"{updated} ({story.status.capitalize()})") - .add_field(name="\N{OPEN BOOK} Length", value=f"{story.words:,d} words in {story.chapters} chapter(s)") - .add_field(name=f"\N{BOOKMARK} Rating: {story.rating}", value=details, inline=False) - .add_field(name="\N{BAR CHART} Stats", value=stats_str, inline=False) - .set_footer(text="Made using the FicHub API. Some results may be out of date or unavailable.") - ) - - # Use the remaining space in the embed for the truncated description. - if len(story_embed) > 6000: - story_embed.description = md_description[: 6000 - len(story_embed) - 3] + "..." - return story_embed - - -def ff_embed_factory(story_data: Any | None) -> discord.Embed | None: - match story_data: - case atlas_api.Story(): - return create_atlas_ffn_embed(story_data) - case fichub_api.AO3Story() | fichub_api.FFNStory() | fichub_api.OtherStory(): - return create_fichub_embed(story_data) - case ao3.Work(): - return create_ao3_work_embed(story_data) - case ao3.Series(): - return create_ao3_series_embed(story_data) - case _: - return None - - -class AO3SeriesView(PaginatedSelectView[ao3.Work]): - """A view that wraps a dropdown item for AO3 works. - - Parameters - ---------- - author_id: int - The Discord ID of the user that triggered this view. No one else can use it. - series: ao3.Series - The object holding metadata about an AO3 series and the works within. - timeout: float | None, optional - Timeout in seconds from last interaction with the UI before no longer accepting input. - If ``None`` then there is no timeout. - - Attributes - ---------- - series: ao3.Series - The object holding metadata about an AO3 series and the works within. - """ - - def __init__(self, author_id: int, series: ao3.Series, *, timeout: float | None = 180) -> None: - self.series = series - super().__init__(author_id, series.works_list, timeout=timeout) - - async def on_timeout(self) -> None: - """Disables all items on timeout.""" - - for item in self.children: - item.disabled = True # type: ignore - - await self.message.edit(view=self) - self.stop() - - def populate_select(self) -> None: - self.select_page.placeholder = "Choose the work here..." - descr = textwrap.shorten(self.series.description, 100, placeholder="...") - self.select_page.add_option(label=self.series.name, value="0", description=descr, emoji="\N{BOOKS}") - - for i, work in enumerate(self.pages, start=1): - descr = textwrap.shorten(work.summary, 100, placeholder="...") - self.select_page.add_option( - label=f"{i}. 
{work.title}", - value=str(i), - description=descr, - emoji="\N{OPEN BOOK}", - ) - - def format_page(self) -> discord.Embed: - """Makes the series/work 'page' that the user will see.""" - - if self.page_index != 0: - embed_page = create_ao3_work_embed(self.pages[self.page_index - 1]) - else: - embed_page = create_ao3_series_embed(self.series) - return embed_page diff --git a/src/beira/exts/help.py b/src/beira/exts/help.py index 4e141ee..99df63f 100644 --- a/src/beira/exts/help.py +++ b/src/beira/exts/help.py @@ -3,7 +3,6 @@ The implementation is based off of this guide: https://gist.github.com/InterStella0/b78488fb28cadf279dfd3164b9f0cf96 """ -import logging import re from collections.abc import Mapping from typing import Any @@ -16,8 +15,6 @@ from beira.utils import PaginatedEmbedView -LOGGER = logging.getLogger(__name__) - HELP_COLOR = 0x16A75D @@ -240,7 +237,7 @@ async def help_(self, interaction: beira.Interaction, command: str | None = None A name to match to a bot command. If unfilled, default to the generic help dialog. """ - ctx = await self.bot.get_context(interaction, cls=beira.Context) + ctx = await self.bot.get_context(interaction) if command is not None: await ctx.send_help(command) @@ -258,7 +255,7 @@ async def command_autocomplete( """Autocompletes the help command.""" assert self.bot.help_command - ctx = await self.bot.get_context(interaction, cls=beira.Context) + ctx = await self.bot.get_context(interaction) help_command = self.bot.help_command.copy() help_command.context = ctx @@ -277,6 +274,4 @@ async def command_autocomplete( async def setup(bot: beira.Beira) -> None: - """Connects cog to bot.""" - await bot.add_cog(HelpCog(bot)) diff --git a/src/beira/exts/lol.py b/src/beira/exts/lol.py index 88cf925..09f3484 100644 --- a/src/beira/exts/lol.py +++ b/src/beira/exts/lol.py @@ -5,7 +5,6 @@ import asyncio import itertools -import logging from pathlib import Path from typing import Any, Self from urllib.parse import quote, urljoin @@ -20,8 +19,6 @@ from beira.utils import StatsEmbed -LOGGER = logging.getLogger(__name__) - GECKODRIVER = Path().resolve().joinpath("drivers/geckodriver/geckodriver.exe") GECKODRIVER_LOGS = Path().resolve().joinpath("logs/geckodriver.log") diff --git a/src/beira/exts/other.py b/src/beira/exts/other.py index 37a9ef7..a072580 100644 --- a/src/beira/exts/other.py +++ b/src/beira/exts/other.py @@ -3,7 +3,6 @@ import asyncio import colorsys import importlib.metadata -import logging import math import random import re @@ -19,8 +18,6 @@ import beira -LOGGER = logging.getLogger(__name__) - INSPIROBOT_API_URL = "https://inspirobot.me/api" INSPIROBOT_ICON_URL = "https://pbs.twimg.com/profile_images/815624354876760064/zPmAZWP4_400x400.jpg" diff --git a/src/beira/exts/patreon.py b/src/beira/exts/patreon.py index 4c3c27c..f0c5220 100644 --- a/src/beira/exts/patreon.py +++ b/src/beira/exts/patreon.py @@ -50,7 +50,7 @@ def from_record(cls, record: asyncpg.Record) -> Self: return cls(*(record[attr] for attr in attrs_), emoji=discord.PartialEmoji.from_str(record["tier_emoji"])) -class PatreonTierSelectView(PaginatedSelectView[PatreonTierInfo]): +class PatreonTierSelectView(PaginatedSelectView[list[PatreonTierInfo]]): """A view that displays Patreon tiers and benefits as pages.""" async def on_timeout(self) -> None: diff --git a/src/beira/exts/snowball/snowball.py b/src/beira/exts/snowball/snowball.py index 6c77dfe..48406f1 100644 --- a/src/beira/exts/snowball/snowball.py +++ b/src/beira/exts/snowball/snowball.py @@ -10,14 +10,16 @@ import logging import random 
from itertools import cycle, islice +from typing import Self import asyncpg import discord +import msgspec from discord import app_commands from discord.ext import commands import beira -from beira.utils import EMOJI_STOCK, StatsEmbed +from beira.utils import EMOJI_STOCK, Connection_alias, Pool_alias, StatsEmbed from .snow_text import ( COLLECT_FAIL_IMGS, @@ -29,14 +31,6 @@ SNOW_CODE_NOTE, SNOW_INSPO_NOTE, ) -from .utils import ( - GuildSnowballSettings, - SnowballRecord, - SnowballSettingsView, - collect_cooldown, - steal_cooldown, - transfer_cooldown, -) LOGGER = logging.getLogger(__name__) @@ -44,6 +38,293 @@ LEADERBOARD_MAX = 10 # Number of people shown on one leaderboard at a time. +class SnowballRecord(msgspec.Struct): + """Record-like structure that represents a member's snowball record. + + Attributes + ---------- + hits: int + The number of snowballs used that the member just hit people with. + misses: int + The number of snowballs used the member just tried to hit someone with and missed. + kos: int + The number of hits the member just took. + stock: int + The *change* in how many snowballs the member has in stock. + """ + + hits: int + misses: int + kos: int + stock: int + + @classmethod + def from_record(cls, record: asyncpg.Record) -> Self: + return cls(*(record[val] for val in ("hits", "misses", "kos", "stock"))) + + +class GuildSnowballSettings(msgspec.Struct): + """Record-like structure to hold a guild's snowball settings. + + Attributes + ---------- + guild_id: int, default=0 + The guild these settings apply to. Defaults to 0. + hit_odds: float, default=0.6 + Chance of hitting someone with a snowball. Defaults to 0.6. + stock_cap: int, default=100 + Maximum number of snowballs regular members can hold in their inventory. Defaults to 100. + transfer_cap: int, default=10 + Maximum number of snowballs that can be gifted or stolen. Defaults to 10. + """ + + guild_id: int = 0 + hit_odds: float = 0.6 + stock_cap: int = 100 + transfer_cap: int = 10 + + @classmethod + def from_record(cls, record: asyncpg.Record) -> Self: + return cls(record["guild_id"], record["hit_odds"], record["stock_cap"], record["transfer_cap"]) + + +async def update_user_snow_record( + conn: Pool_alias | Connection_alias, + member: discord.Member, + hits: int = 0, + misses: int = 0, + kos: int = 0, + stock: int = 0, +) -> SnowballRecord | None: + """Upsert a user's snowball stats based on the given stat parameters.""" + + # Upsert the relevant users and guilds to the database before adding a snowball record. 
+ stmt = """\ + INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING; + INSERT INTO users (user_id) VALUES ($2) ON CONFLICT (user_id) DO NOTHING; + INSERT INTO members (guild_id, user_id) VALUES ($1, $2) ON CONFLICT (guild_id, user_id) DO NOTHING; + + INSERT INTO snowball_stats (user_id, guild_id, hits, misses, kos, stock) + VALUES ($2, $1, $3, $4, $5, $6) + ON CONFLICT (user_id, guild_id) DO UPDATE + SET hits = snowball_stats.hits + EXCLUDED.hits, + misses = snowball_stats.misses + EXCLUDED.misses, + kos = snowball_stats.kos + EXCLUDED.kos, + stock = snowball_stats.stock + $7 + RETURNING *; + """ + + values = (member.id, member.guild.id, hits, misses, kos, max(stock, 0), stock) + record = await conn.fetchrow(stmt, *values) + return SnowballRecord.from_record(record) if record else None + + +async def update_guild_snow_settings(conn: Pool_alias | Connection_alias, settings: GuildSnowballSettings) -> None: + """Upsert these snowball settings into the database.""" + + stmt = """\ + INSERT INTO snowball_settings (guild_id, hit_odds, stock_cap, transfer_cap) + VALUES ($1, $2, $3, $4) + ON CONFLICT(guild_id) + DO UPDATE + SET hit_odds = EXCLUDED.hit_odds, + stock_cap = EXCLUDED.stock_cap, + transfer_cap = EXCLUDED.transfer_cap; + """ + await conn.execute(stmt, settings.guild_id, settings.hit_odds, settings.stock_cap, settings.transfer_cap) + + +# region -------- Views + + +class SnowballSettingsModal(discord.ui.Modal): + """Custom modal for changing the guild-specific settings of the snowball game. + + Parameters + ---------- + default_settings: SnowballSettings + The current snowball-related settings for the guild. + + Attributes + ---------- + hit_odds_input: discord.ui.TextInput + An editable text field showing the current hit odds for this guild. + stock_cap_input: discord.ui.TextInput + An editable text field showing the current stock cap for this guild. + transfer_cap_input: discord.ui.TextInput + An editable text field showing the current transfer cap for this guild. + default_settings: SnowballSettings + The current snowball-related settings for the guild. + new_settings: SnowballSettings, optional + The new snowball-related settings for this guild from user input. + """ + + def __init__(self, default_settings: GuildSnowballSettings) -> None: + super().__init__(title="This Guild's Snowball Settings") + + # Create the items. + self.hit_odds_input: discord.ui.TextInput[Self] = discord.ui.TextInput( + label="The chance of hitting a person (0.0-1.0)", + placeholder=f"Current: {default_settings.hit_odds:.2}", + default=f"{default_settings.hit_odds:.2}", + required=False, + ) + self.stock_cap_input: discord.ui.TextInput[Self] = discord.ui.TextInput( + label="Max snowballs a member can hold (no commas)", + placeholder=f"Current: {default_settings.stock_cap}", + default=str(default_settings.stock_cap), + required=False, + ) + self.transfer_cap_input: discord.ui.TextInput[Self] = discord.ui.TextInput( + label="Max snowballs that can be gifted/stolen", + placeholder=f"Current: {default_settings.transfer_cap}", + default=str(default_settings.transfer_cap), + required=False, + ) + + # Add the items. + for item in (self.hit_odds_input, self.stock_cap_input, self.transfer_cap_input): + self.add_item(item) + + # Save the settings. + self.default_settings: GuildSnowballSettings = default_settings + self.new_settings: GuildSnowballSettings | None = None + + async def on_submit(self, interaction: beira.Interaction, /) -> None: # type: ignore # Narrowing. 
+ """Verify changes and update the snowball settings in the database appropriately.""" + + guild_id = self.default_settings.guild_id + + # Get the new settings values and verify that they are be the right types. + new_odds_val = self.default_settings.hit_odds + try: + temp = float(self.hit_odds_input.value) + except ValueError: + pass + else: + if 0.0 <= temp <= 1.0: + new_odds_val = temp + + new_stock_val = self.default_settings.stock_cap + try: + temp = int(self.stock_cap_input.value) + except ValueError: + pass + else: + if temp >= 0: + new_stock_val = temp + + new_transfer_val = self.default_settings.transfer_cap + try: + temp = int(self.transfer_cap_input.value) + except ValueError: + pass + else: + if temp >= 0: + new_transfer_val = temp + + # Update the record in the database if there's been a change. + self.new_settings = GuildSnowballSettings(guild_id, new_odds_val, new_stock_val, new_transfer_val) + if self.new_settings != self.default_settings: + await update_guild_snow_settings(interaction.client.db_pool, self.new_settings) + await interaction.response.send_message("Settings updated!") + + +class SnowballSettingsView(discord.ui.View): + """A view with a button that allows server administrators and bot owners to change snowball-related settings. + + Parameters + ---------- + guild_settings: SnowballSettings + The current snowball-related settings for the guild. + + Attributes + ---------- + settings: SnowballSettings + The current snowball-related settings for the guild. + message: discord.Message + The message an instance of this view is attached to. + """ + + message: discord.Message + + def __init__(self, guild_name: str, guild_settings: GuildSnowballSettings) -> None: + super().__init__() + self.guild_name = guild_name + self.settings: GuildSnowballSettings = guild_settings + + async def on_timeout(self) -> None: + # Disable everything on timeout. + + for item in self.children: + item.disabled = True # type: ignore + + await self.message.edit(view=self) + + async def interaction_check(self, interaction: beira.Interaction, /) -> bool: # type: ignore # Needed narrowing. + """Ensure people interacting with this view are only server administrators or bot owners.""" + + # This should only ever be called in a guild context. + assert interaction.guild + assert isinstance(interaction.user, discord.Member) + + user = interaction.user + check = bool(user.guild_permissions.administrator or interaction.client.owner_id == user.id) + + if not check: + await interaction.response.send_message("You can't change that unless you're a guild admin.") + return check + + def format_embed(self) -> discord.Embed: + return ( + discord.Embed( + color=0x5E9A40, + title=f"Snowball Settings in {self.guild_name}", + description=( + "Below are the settings for the bot's snowball hit rate, stock maximum, and more. Settings can be " + "added on a per-guild basis, but currently don't have any effect. Fix coming soon." 
+ ), + ) + .add_field( + name=f"Odds = {self.settings.hit_odds}", + value="The odds of landing a snowball on someone.", + inline=False, + ) + .add_field( + name=f"Default Stock Cap = {self.settings.stock_cap}", + value="The maximum number of snowballs the average member can hold at once.", + inline=False, + ) + .add_field( + name=f"Transfer Cap = {self.settings.transfer_cap}", + value="The maximum number of snowballs that can be gifted or stolen at once.", + inline=False, + ) + ) + + @discord.ui.button(label="Update", emoji="⚙") + async def change_settings_button(self, interaction: beira.Interaction, _: discord.ui.Button[Self]) -> None: + """Send a modal that allows the user to edit the snowball settings for this guild.""" + + # Get inputs from a modal. + modal = SnowballSettingsModal(self.settings) + await interaction.response.send_modal(modal) + modal_timed_out = await modal.wait() + + if modal_timed_out or self.is_finished(): + return + + # Update the known settings. + if modal.new_settings is not None and modal.new_settings != self.settings: + self.settings = modal.new_settings + + # Edit the embed with the settings information. + await interaction.edit_original_response(embed=self.format_embed()) + + +# endregion + + class SnowballCog(commands.Cog, name="Snowball"): """A cog that implements all snowball fight-related commands, like Discord's 2021 Snowball bot game.""" @@ -120,7 +401,8 @@ async def settings(self, ctx: beira.GuildContext) -> None: """Show what the settings are for the snowballs in this server.""" # Get the settings for the guild and make an embed display. - guild_settings = await GuildSnowballSettings.from_database(ctx.db, ctx.guild.id) + record = await ctx.db.fetchrow("SELECT * FROM snowball_settings WHERE guild_id = $1;", ctx.guild.id) + guild_settings = GuildSnowballSettings.from_record(record) if record else GuildSnowballSettings(ctx.guild.id) view = SnowballSettingsView(ctx.guild.name, guild_settings) embed = view.format_embed() @@ -130,6 +412,24 @@ async def settings(self, ctx: beira.GuildContext) -> None: else: await ctx.send(embed=embed) + @staticmethod + def collect_cooldown(ctx: beira.Context) -> commands.Cooldown | None: + """Sets cooldown for SnowballCog.collect() command. 10 seconds by default. + + Bot owner and friends get less time. + """ + + rate, per = 1.0, 15.0 # Default cooldown + exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"]] + testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] + + if ctx.author.id in exempt: + return None + if ctx.guild and (ctx.guild.id in testing_guild_ids): + per = 1.0 + + return commands.Cooldown(rate, per) + @snow.command() @commands.guild_only() @commands.dynamic_cooldown(collect_cooldown, commands.cooldowns.BucketType.user) # type: ignore @@ -141,10 +441,10 @@ async def collect(self, ctx: beira.GuildContext) -> None: base_stock_cap = guild_snow_settings.stock_cap # Only special people get the higher snowball limit. 
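# The *_cooldown staticmethods in this cog feed commands.dynamic_cooldown():
# the factory runs on every invocation, returning None bypasses the cooldown
# for that caller, and the returned Cooldown(rate, per) is what gets enforced
# for the chosen bucket (here, per user). A hedged sketch with a hypothetical
# command and a made-up exempt ID:
from discord.ext import commands

EXEMPT_IDS = {123456789012345678}  # hypothetical owner/friend IDs


def example_cooldown(ctx: commands.Context) -> commands.Cooldown | None:
    if ctx.author.id in EXEMPT_IDS:
        return None  # no cooldown at all for exempt users
    return commands.Cooldown(1, 15.0)  # everyone else: once per 15 seconds


@commands.command()
@commands.dynamic_cooldown(example_cooldown, commands.cooldowns.BucketType.user)  # type: ignore
async def example(ctx: commands.Context) -> None:
    await ctx.send("ok")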
- privilege_check = bool(ctx.author.id == self.bot.owner_id or self.bot.is_ali(ctx.author)) + privilege_check = ctx.author.id in (self.bot.owner_id, ctx.bot.special_friends["aeroali"]) stock_limit = base_stock_cap * 2 if privilege_check else base_stock_cap - record = await SnowballRecord.upsert_record(ctx.db, ctx.author, stock=1) + record = await update_user_snow_record(ctx.db, ctx.author, stock=1) embed = discord.Embed(color=0x5E62D3) if record: @@ -200,15 +500,15 @@ async def throw(self, ctx: beira.GuildContext, *, target: discord.Member) -> Non # Update the database records and prepare the response message and embed based on the outcome. if roll < base_hit_odds: async with ctx.db.acquire() as conn, conn.transaction(): - await SnowballRecord.upsert_record(conn, ctx.author, hits=1, stock=-1) - await SnowballRecord.upsert_record(conn, target, kos=1) + await update_user_snow_record(conn, ctx.author, hits=1, stock=-1) + await update_user_snow_record(conn, target, kos=1) embed.description = random.choice(HIT_NOTES).format(target.mention) embed.set_image(url=random.choice(HIT_IMGS)) message = target.mention else: - await SnowballRecord.upsert_record(ctx.db, ctx.author, misses=1) + await update_user_snow_record(ctx.db, ctx.author, misses=1) misses_text = random.choice(MISS_NOTES) embed.colour = 0xFFA600 @@ -223,6 +523,21 @@ async def throw(self, ctx: beira.GuildContext, *, target: discord.Member) -> Non await ctx.send(content=message, embed=embed, ephemeral=ephemeral) + @staticmethod + def transfer_cooldown(ctx: beira.Context) -> commands.Cooldown | None: + """Sets cooldown for transfer command. 60 seconds by default, less for bot owner and friends.""" + + rate, per = 1.0, 60.0 # Default cooldown + exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"]] + testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] + + if ctx.author.id in exempt: + return None + if ctx.guild and (ctx.guild.id in testing_guild_ids): + per = 2.0 + + return commands.Cooldown(rate, per) + @snow.command() @commands.guild_only() @commands.dynamic_cooldown(transfer_cooldown, commands.cooldowns.BucketType.user) # type: ignore @@ -251,7 +566,7 @@ async def transfer(self, ctx: beira.GuildContext, amount: int, *, receiver: disc base_stock_cap = guild_snow_settings.stock_cap # Only special people get the higher snowball limit. - privilege_check = bool(ctx.author.id == self.bot.owner_id or self.bot.is_ali(ctx.author)) + privilege_check = ctx.author.id in (self.bot.owner_id, ctx.bot.special_friends["aeroali"]) stock_limit = base_stock_cap * 2 if privilege_check else base_stock_cap # Build on an embed. @@ -288,14 +603,29 @@ async def transfer(self, ctx: beira.GuildContext, amount: int, *, receiver: disc # Update the giver and receiver's records. async with ctx.db.acquire() as conn, conn.transaction(): - await SnowballRecord.upsert_record(conn, ctx.author, stock=-amount) - await SnowballRecord.upsert_record(conn, receiver, stock=amount) + await update_user_snow_record(conn, ctx.author, stock=-amount) + await update_user_snow_record(conn, receiver, stock=amount) # Send notification message of successful transfer. def_embed.description = f"Transfer successful! You've given {receiver.mention} {amount} of your snowballs!" message = f"{ctx.author.mention}, {receiver.mention}" await ctx.send(content=message, embed=def_embed, ephemeral=False) + @staticmethod + def steal_cooldown(ctx: beira.Context) -> commands.Cooldown | None: + """Sets cooldown for steal command. 
90 seconds by default, less for bot owner and friends.""" + + rate, per = 1.0, 90.0 # Default cooldown + exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"], ctx.bot.special_friends["athenahope"]] + testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] + + if ctx.author.id in exempt: + return None + if ctx.guild and (ctx.guild.id in testing_guild_ids): + per = 2.0 + + return commands.Cooldown(rate, per) + @snow.command() @commands.guild_only() @beira.is_owner_or_friend() @@ -328,7 +658,7 @@ async def steal(self, ctx: beira.GuildContext, amount: int, *, victim: discord.M base_stock_cap = guild_snow_settings.stock_cap # Only special people get the higher snowball limit. - privilege_check = bool(ctx.author.id == self.bot.owner_id or self.bot.is_ali(ctx.author)) + privilege_check = ctx.author.id in (self.bot.owner_id, ctx.bot.special_friends["aeroali"]) stock_limit = base_stock_cap * 2 if privilege_check else base_stock_cap # Build on an embed. @@ -370,8 +700,8 @@ async def steal(self, ctx: beira.GuildContext, amount: int, *, victim: discord.M async with ctx.db.acquire() as conn, conn.transaction(): assert victim_record is not None amount_to_steal = min(victim_record["stock"], amount) - await SnowballRecord.upsert_record(conn, ctx.author, stock=amount_to_steal) - await SnowballRecord.upsert_record(conn, victim, stock=-amount_to_steal) + await update_user_snow_record(conn, ctx.author, stock=amount_to_steal) + await update_user_snow_record(conn, victim, stock=-amount_to_steal) # Send notification message of successful theft. def_embed.description = f"Thievery successful! You've stolen {amount_to_steal} snowballs from {victim.mention}!" @@ -562,7 +892,12 @@ async def snow_before(self, ctx: beira.GuildContext) -> None: This allows the use of guild-specific limits stored in the db and now temporarily in the context. """ - guild_snow_settings = await GuildSnowballSettings.from_database(ctx.db, ctx.guild.id) + record = await ctx.db.fetchrow("SELECT * FROM snowball_settings WHERE guild_id = $1;", ctx.guild.id) + if record: + guild_snow_settings = GuildSnowballSettings.from_record(record) + else: + guild_snow_settings = GuildSnowballSettings(ctx.guild.id) + ctx.guild_snow_settings = guild_snow_settings # type: ignore # Runtime attribute assignment. @collect.after_invoke @@ -593,11 +928,7 @@ def _get_entity_from_record(record: asyncpg.Record) -> discord.Guild | discord.U special_stars = (EMOJI_STOCK[name] for name in ("orange_star", "blue_star", "pink_star")) # Create a list of emojis to accompany the leaderboard members. - ldbd_places_emojis = ( - "\N{GLOWING STAR}", - "\N{WHITE MEDIUM STAR}", - *tuple(emoji for emoji in islice(cycle(special_stars), 8)), - ) + ldbd_places_emojis = ("\N{GLOWING STAR}", "\N{WHITE MEDIUM STAR}", *islice(cycle(special_stars), 8)) # Assemble each entry's data. 
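# The leaderboard markers above just cycle a short emoji sequence after the
# two fixed stars; with LEADERBOARD_MAX = 10 that yields exactly ten markers:
from itertools import cycle, islice

special = ("orange", "blue", "pink")  # stand-ins for the three star emojis
places = ("\N{GLOWING STAR}", "\N{WHITE MEDIUM STAR}", *islice(cycle(special), 8))
assert len(places) == 10
# -> (glowing star, white star, orange, blue, pink, orange, blue, pink, orange, blue)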
snow_data = [(_get_entity_from_record(row), row["hits"], row["misses"], row["kos"]) for row in records] diff --git a/src/beira/exts/snowball/utils.py b/src/beira/exts/snowball/utils.py deleted file mode 100644 index e24018b..0000000 --- a/src/beira/exts/snowball/utils.py +++ /dev/null @@ -1,368 +0,0 @@ -from typing import Self - -import asyncpg -import discord -import msgspec -from discord.ext import commands - -import beira -from beira.utils.db import Connection_alias, Pool_alias - - -__all__ = ( - "SnowballRecord", - "GuildSnowballSettings", - "SnowballSettingsModal", - "SnowballSettingsView", - "collect_cooldown", - "transfer_cooldown", - "steal_cooldown", -) - - -class SnowballRecord(msgspec.Struct): - """Record-like structure that represents a member's snowball record. - - Attributes - ---------- - hits: int - The number of snowballs used that the member just hit people with. - misses: int - The number of snowballs used the member just tried to hit someone with and missed. - kos: int - The number of hits the member just took. - stock: int - The change in how many snowballs the member has in stock. - """ - - hits: int - misses: int - kos: int - stock: int - - @classmethod - def from_record(cls: type[Self], record: asyncpg.Record | None) -> Self | None: - if record: - return cls(record["hits"], record["misses"], record["kos"], record["stock"]) - return None - - @classmethod - async def upsert_record( - cls, - conn: Pool_alias | Connection_alias, - member: discord.Member, - hits: int = 0, - misses: int = 0, - kos: int = 0, - stock: int = 0, - ) -> Self | None: - """Upserts a user's snowball stats based on the given stat parameters.""" - - # Upsert the relevant users and guilds to the database before adding a snowball record. - user_stmt = "INSERT INTO users (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING;" - await conn.execute(user_stmt, member.id) - guild_stmt = "INSERT INTO guilds (guild_id) VALUES ($1) ON CONFLICT (guild_id) DO NOTHING;" - await conn.execute(guild_stmt, member.guild.id) - member_stmt = ( - "INSERT INTO members (guild_id, user_id) VALUES ($1, $2) ON CONFLICT (guild_id, user_id) DO NOTHING;" - ) - await conn.execute(member_stmt, member.guild.id, member.id) - - snowball_stmt = """\ - INSERT INTO snowball_stats (user_id, guild_id, hits, misses, kos, stock) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (user_id, guild_id) DO UPDATE - SET hits = snowball_stats.hits + EXCLUDED.hits, - misses = snowball_stats.misses + EXCLUDED.misses, - kos = snowball_stats.kos + EXCLUDED.kos, - stock = snowball_stats.stock + $7 - RETURNING *; - """ - values = (member.id, member.guild.id, hits, misses, kos, max(stock, 0), stock) - record = await conn.fetchrow(snowball_stmt, *values) - return cls.from_record(record) - - -class GuildSnowballSettings(msgspec.Struct): - """Record-like structure to hold a guild's snowball settings. - - Attributes - ---------- - guild_id: int, default=0 - The guild these settings apply to. Defaults to 0. - hit_odds: float, default=0.6 - Chance of hitting someone with a snowball. Defaults to 0.6. - stock_cap: int, default=100 - Maximum number of snowballs regular members can hold in their inventory. Defaults to 100. - transfer_cap: int, default=10 - Maximum number of snowballs that can be gifted or stolen. Defaults to 10. 
- """ - - guild_id: int = 0 - hit_odds: float = 0.6 - stock_cap: int = 100 - transfer_cap: int = 10 - - @classmethod - def from_record(cls: type[Self], record: asyncpg.Record) -> Self: - return cls(record["guild_id"], record["hit_odds"], record["stock_cap"], record["transfer_cap"]) - - @classmethod - async def from_database(cls: type[Self], conn: Pool_alias | Connection_alias, guild_id: int) -> Self: - """Query a snowball settings database record for a guild.""" - - record = await conn.fetchrow("SELECT * FROM snowball_settings WHERE guild_id = $1;", guild_id) - return cls.from_record(record) if record else cls(guild_id) - - async def upsert_record(self, conn: Pool_alias | Connection_alias) -> None: - """Upsert these snowball settings into the database.""" - - stmt = """\ - INSERT INTO snowball_settings (guild_id, hit_odds, stock_cap, transfer_cap) - VALUES ($1, $2, $3, $4) - ON CONFLICT(guild_id) - DO UPDATE - SET hit_odds = EXCLUDED.hit_odds, - stock_cap = EXCLUDED.stock_cap, - transfer_cap = EXCLUDED.transfer_cap; - """ - await conn.execute(stmt, self.guild_id, self.hit_odds, self.stock_cap, self.transfer_cap) - - -class SnowballSettingsModal(discord.ui.Modal): - """Custom modal for changing the guild-specific settings of the snowball game. - - Parameters - ---------- - default_settings: SnowballSettings - The current snowball-related settings for the guild. - - Attributes - ---------- - hit_odds_input: discord.ui.TextInput - An editable text field showing the current hit odds for this guild. - stock_cap_input: discord.ui.TextInput - An editable text field showing the current stock cap for this guild. - transfer_cap_input: discord.ui.TextInput - An editable text field showing the current transfer cap for this guild. - default_settings: SnowballSettings - The current snowball-related settings for the guild. - new_settings: SnowballSettings, optional - The new snowball-related settings for this guild from user input. - """ - - def __init__(self, default_settings: GuildSnowballSettings) -> None: - super().__init__(title="This Guild's Snowball Settings") - - # Create the items. - self.hit_odds_input: discord.ui.TextInput[Self] = discord.ui.TextInput( - label="The chance of hitting a person (0.0-1.0)", - placeholder=f"Current: {default_settings.hit_odds:.2}", - default=f"{default_settings.hit_odds:.2}", - required=False, - ) - self.stock_cap_input: discord.ui.TextInput[Self] = discord.ui.TextInput( - label="Max snowballs a member can hold (no commas)", - placeholder=f"Current: {default_settings.stock_cap}", - default=str(default_settings.stock_cap), - required=False, - ) - self.transfer_cap_input: discord.ui.TextInput[Self] = discord.ui.TextInput( - label="Max snowballs that can be gifted/stolen", - placeholder=f"Current: {default_settings.transfer_cap}", - default=str(default_settings.transfer_cap), - required=False, - ) - - # Add the items. - for item in (self.hit_odds_input, self.stock_cap_input, self.transfer_cap_input): - self.add_item(item) - - # Save the settings. - self.default_settings: GuildSnowballSettings = default_settings - self.new_settings: GuildSnowballSettings | None = None - - async def on_submit(self, interaction: beira.Interaction, /) -> None: # type: ignore # Narrowing. - """Verify changes and update the snowball settings in the database appropriately.""" - - guild_id = self.default_settings.guild_id - - # Get the new settings values and verify that they are be the right types. 
- new_odds_val = self.default_settings.hit_odds - try: - temp = float(self.hit_odds_input.value) - except ValueError: - pass - else: - if 0.0 <= temp <= 1.0: - new_odds_val = temp - - new_stock_val = self.default_settings.stock_cap - try: - temp = int(self.stock_cap_input.value) - except ValueError: - pass - else: - if temp >= 0: - new_stock_val = temp - - new_transfer_val = self.default_settings.transfer_cap - try: - temp = int(self.transfer_cap_input.value) - except ValueError: - pass - else: - if temp >= 0: - new_transfer_val = temp - - # Update the record in the database if there's been a change. - self.new_settings = GuildSnowballSettings(guild_id, new_odds_val, new_stock_val, new_transfer_val) - if self.new_settings != self.default_settings: - await self.new_settings.upsert_record(interaction.client.db_pool) - await interaction.response.send_message("Settings updated!") - - -class SnowballSettingsView(discord.ui.View): - """A view with a button that allows server administrators and bot owners to change snowball-related settings. - - Parameters - ---------- - guild_settings: SnowballSettings - The current snowball-related settings for the guild. - - Attributes - ---------- - settings: SnowballSettings - The current snowball-related settings for the guild. - message: discord.Message - The message an instance of this view is attached to. - """ - - message: discord.Message - - def __init__(self, guild_name: str, guild_settings: GuildSnowballSettings) -> None: - super().__init__() - self.guild_name = guild_name - self.settings: GuildSnowballSettings = guild_settings - - async def on_timeout(self) -> None: - # Disable everything on timeout. - - for item in self.children: - item.disabled = True # type: ignore - - await self.message.edit(view=self) - - async def interaction_check(self, interaction: beira.Interaction, /) -> bool: # type: ignore # Needed narrowing. - """Ensure people interacting with this view are only server administrators or bot owners.""" - - # This should only ever be called in a guild context. - assert interaction.guild - assert isinstance(interaction.user, discord.Member) - - user = interaction.user - check = bool(user.guild_permissions.administrator or interaction.client.owner_id == user.id) - - if not check: - await interaction.response.send_message("You can't change that unless you're a guild admin.") - return check - - def format_embed(self) -> discord.Embed: - return ( - discord.Embed( - color=0x5E9A40, - title=f"Snowball Settings in {self.guild_name}", - description=( - "Below are the settings for the bot's snowball hit rate, stock maximum, and more. Settings can be " - "added on a per-guild basis, but currently don't have any effect. Fix coming soon." - ), - ) - .add_field( - name=f"Odds = {self.settings.hit_odds}", - value="The odds of landing a snowball on someone.", - inline=False, - ) - .add_field( - name=f"Default Stock Cap = {self.settings.stock_cap}", - value="The maximum number of snowballs the average member can hold at once.", - inline=False, - ) - .add_field( - name=f"Transfer Cap = {self.settings.transfer_cap}", - value="The maximum number of snowballs that can be gifted or stolen at once.", - inline=False, - ) - ) - - @discord.ui.button(label="Update", emoji="⚙") - async def change_settings_button(self, interaction: beira.Interaction, _: discord.ui.Button[Self]) -> None: - """Send a modal that allows the user to edit the snowball settings for this guild.""" - - # Get inputs from a modal. 
- modal = SnowballSettingsModal(self.settings) - await interaction.response.send_modal(modal) - modal_timed_out = await modal.wait() - - if modal_timed_out or self.is_finished(): - return - - # Update the known settings. - if modal.new_settings is not None and modal.new_settings != self.settings: - self.settings = modal.new_settings - - # Edit the embed with the settings information. - await interaction.edit_original_response(embed=self.format_embed()) - - -def collect_cooldown(ctx: beira.Context) -> commands.Cooldown | None: - """Sets cooldown for SnowballCog.collect() command. 10 seconds by default. - - Bot owner and friends get less time. - """ - - rate, per = 1.0, 15.0 # Default cooldown - exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"]] - testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] - - if ctx.author.id in exempt: - return None - if ctx.guild and (ctx.guild.id in testing_guild_ids): - per = 1.0 - - return commands.Cooldown(rate, per) - - -def transfer_cooldown(ctx: beira.Context) -> commands.Cooldown | None: - """Sets cooldown for SnowballCog.transfer() command. 60 seconds by default. - - Bot owner and friends get less time. - """ - - rate, per = 1.0, 60.0 # Default cooldown - exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"]] - testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] - - if ctx.author.id in exempt: - return None - if ctx.guild and (ctx.guild.id in testing_guild_ids): - per = 2.0 - - return commands.Cooldown(rate, per) - - -def steal_cooldown(ctx: beira.Context) -> commands.Cooldown | None: - """Sets cooldown for SnowballCog.steal() command. 90 seconds by default. - - Bot owner and friends get less time. - """ - - rate, per = 1.0, 90.0 # Default cooldown - exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"], ctx.bot.special_friends["athenahope"]] - testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] - - if ctx.author.id in exempt: - return None - if ctx.guild and (ctx.guild.id in testing_guild_ids): - per = 2.0 - - return commands.Cooldown(rate, per) diff --git a/src/beira/exts/starkid.py b/src/beira/exts/starkid.py index 79177a0..740fa9b 100644 --- a/src/beira/exts/starkid.py +++ b/src/beira/exts/starkid.py @@ -3,17 +3,12 @@ Shoutout to Theo and Ali for inspiration, as well as the whole StarKid server. """ -import logging - import discord from discord.ext import commands import beira -LOGGER = logging.getLogger(__name__) - - class StarKidCog(commands.Cog, name="StarKid"): """A cog for StarKid-related commands and functionality.""" diff --git a/src/beira/exts/timing.py b/src/beira/exts/timing.py index 9a45a22..a683599 100644 --- a/src/beira/exts/timing.py +++ b/src/beira/exts/timing.py @@ -11,12 +11,13 @@ # The supposed schedule table right now. 
+# Requires a postgres extension: https://github.com/fboulnois/pg_uuidv7 """ CREATE TABLE IF NOT EXISTS scheduled_dispatches ( - task_id UUID PRIMARY KEY, - dispatch_name TEXT NOT NULL, - dispatch_time TIMESTAMP WITH TIME ZONE NOT NULL, - dispatch_zone TEXT NOT NULL, + task_id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), + dispatch_name TEXT NOT NULL, + dispatch_time TIMESTAMP NOT NULL, + dispatch_zone TEXT NOT NULL, associated_guild BIGINT, associated_user BIGINT, dispatch_extra JSONB diff --git a/src/beira/exts/todo.py b/src/beira/exts/todo.py index 88f5314..f552e1e 100644 --- a/src/beira/exts/todo.py +++ b/src/beira/exts/todo.py @@ -1,7 +1,6 @@ """A module/cog for handling todo lists made in Discord and stored in a database.""" import datetime -import logging import textwrap from abc import ABC, abstractmethod from typing import Any, Self @@ -15,9 +14,6 @@ from beira.utils import Connection_alias, OwnedView, PaginatedEmbedView, Pool_alias -LOGGER = logging.getLogger(__name__) - - class TodoItem(msgspec.Struct): todo_id: int user_id: int @@ -47,50 +43,6 @@ def from_record(cls, record: asyncpg.Record) -> Self: def generate_deleted(cls) -> Self: return cls(-1, 0, "", discord.utils.utcnow()) - async def change_completion(self, conn: Pool_alias | Connection_alias) -> Self: - """Adds or removes a completion date from the record in the database, giving back the new version of the record. - - This function returns a new instance of the class. - - Parameters - ---------- - conn: asyncpg.Pool | asyncpg.Connection - The connection/pool that will be used to make this database command. - """ - - stmt = "UPDATE todos SET todo_completed_at = $1 WHERE todo_id = $2 RETURNING *;" - new_date = discord.utils.utcnow() if self.completed_at is None else None - record = await conn.fetchrow(stmt, new_date, self.todo_id) - return self.from_record(record) if record else type(self).generate_deleted() - - async def update(self, conn: Pool_alias | Connection_alias, updated_content: str) -> Self: - """Changes the to-do content of the record, giving back the new version of the record. - - This function returns a new instance of the class. - - Parameters - ---------- - conn: asyncpg.Pool | asyncpg.Connection - The connection/pool that will be used to make this database command. - updated_content: str - The new to-do content. - """ - - stmt = "UPDATE todos SET todo_content = $1 WHERE todo_id = $2 RETURNING *;" - record = await conn.fetchrow(stmt, updated_content, self.todo_id) - return self.from_record(record) if record else type(self).generate_deleted() - - async def delete(self, conn: Pool_alias | Connection_alias) -> None: - """Deletes the record from the database. - - Parameters - ---------- - conn: asyncpg.Pool | asyncpg.Connection - The connection/pool that will be used to make this database command. - """ - - await conn.execute("DELETE FROM todos where todo_id = $1;", self.todo_id) - def display_embed(self, *, to_be_deleted: bool = False) -> discord.Embed: """Generates a formatted embed from a to-do record. 
@@ -130,6 +82,29 @@ def display_embed(self, *, to_be_deleted: bool = False) -> discord.Embed: return todo_embed +async def change_todo_completion(conn: Pool_alias | Connection_alias, item: TodoItem) -> TodoItem: + """Adds or removes a completion date from a to-do item in the database, giving back the new version of the item.""" + + stmt = "UPDATE todos SET todo_completed_at = $1 WHERE todo_id = $2 RETURNING *;" + new_date = discord.utils.utcnow() if item.completed_at is None else None + record = await conn.fetchrow(stmt, new_date, item.todo_id) + return TodoItem.from_record(record) if record else TodoItem.generate_deleted() + + +async def update_todo_content(conn: Pool_alias, item: TodoItem, updated_content: str) -> TodoItem: + """Changes the content of the to-do item, giving back the new version of the item.""" + + stmt = "UPDATE todos SET todo_content = $1 WHERE todo_id = $2 RETURNING *;" + record = await conn.fetchrow(stmt, updated_content, item.todo_id) + return TodoItem.from_record(record) if record else TodoItem.generate_deleted() + + +async def delete_todo(conn: Pool_alias, item: TodoItem) -> None: + """Deletes the to-do item from the database.""" + + await conn.execute("DELETE FROM todos where todo_id = $1;", item.todo_id) + + class TodoModal(discord.ui.Modal, title="What do you want to do?"): """A Discord modal for putting in or editing the content of a to-do item. @@ -207,7 +182,7 @@ async def callback(self, interaction: beira.Interaction) -> None: # type: ignor assert self.view is not None # Get a new version of the to-do item after adding a completion date. - updated_todo_item = await self.view.todo_item.change_completion(interaction.client.db_pool) + updated_todo_item = await change_todo_completion(interaction.client.db_pool, self.view.todo_item) # Adjust the button based on the item. if updated_todo_item.completed_at is None: @@ -257,7 +232,11 @@ async def callback(self, interaction: beira.Interaction) -> None: # type: ignor # Adjust the view to have and display the updated to-do item, and let the user know it's updated. if self.view.todo_item.content != modal.content.value: - updated_todo_item = await self.view.todo_item.update(interaction.client.db_pool, modal.content.value) + updated_todo_item = await update_todo_content( + interaction.client.db_pool, + self.view.todo_item, + modal.content.value, + ) await self.view.update_todo(modal.interaction, updated_todo_item) await modal.interaction.followup.send("Todo item edited.", ephemeral=True) else: @@ -285,7 +264,7 @@ async def callback(self, interaction: beira.Interaction) -> None: # type: ignor assert self.view is not None - await self.view.todo_item.delete(interaction.client.db_pool) + await delete_todo(interaction.client.db_pool, self.view.todo_item) await self.view.update_todo(interaction, TodoItem.generate_deleted()) await interaction.followup.send("Todo task deleted!", ephemeral=True) diff --git a/src/beira/exts/triggers/misc_triggers.py b/src/beira/exts/triggers/misc_triggers.py index e905461..d1bc6f5 100644 --- a/src/beira/exts/triggers/misc_triggers.py +++ b/src/beira/exts/triggers/misc_triggers.py @@ -13,6 +13,8 @@ import beira +LOGGER = logging.getLogger(__name__) + type ValidGuildChannel = ( discord.VoiceChannel | discord.StageChannel | discord.ForumChannel | discord.TextChannel | discord.CategoryChannel ) @@ -47,8 +49,6 @@ LOSSY_TWITTER_LINK_PATTERN = re.compile(r"(?:http(?:s)?://|(? 
None: @@ -96,7 +96,7 @@ async def on_leveled_role_member_update(self, before: discord.Member, after: dis if after.joined_at is not None: # Technically, at 8 points every two minutes, it's possible to hit the lowest relevant leveled role in # 20h 50m on ACI, so 21 hours will be the limit. - recently_rejoined = (discord.utils.utcnow() - after.joined_at).total_seconds() < 75600 + recently_rejoined = (discord.utils.utcnow() - after.joined_at).total_seconds() < (60 * 24 * 21) else: recently_rejoined = False diff --git a/src/beira/scheduler.py b/src/beira/scheduler.py new file mode 100644 index 0000000..c52dcb8 --- /dev/null +++ b/src/beira/scheduler.py @@ -0,0 +1,428 @@ +# region License +# Modified from https://github.com/mikeshardmind/discord-scheduler. See the license below: +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# Copyright (C) 2023 Michael Hall +# endregion + +import asyncio +from datetime import datetime, timedelta +from typing import Self +from zoneinfo import ZoneInfo + +import asyncpg +import discord +from discord.ext import commands, tasks +from msgspec import Struct + +from .utils import Pool_alias + + +__all__ = ("ScheduledDispatch", "Scheduler") + + +# Requires a postgres extension: https://github.com/fboulnois/pg_uuidv7 +INITIALIZATION_STATEMENTS = """ +CREATE TABLE IF NOT EXISTS scheduled_dispatches ( + task_id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), + dispatch_name TEXT NOT NULL, + dispatch_time TIMESTAMP NOT NULL, + dispatch_zone TEXT NOT NULL, + associated_guild BIGINT, + associated_user BIGINT, + dispatch_extra JSONB +); +""" + +ZONE_SELECTION_STATEMENT = """ +SELECT DISTINCT dispatch_zone FROM scheduled_dispatches; +""" + +UNSCHEDULE_BY_UUID_STATEMENT = """ +DELETE FROM scheduled_dispatches WHERE task_id = $1; +""" + +UNSCHEDULE_ALL_BY_GUILD_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE associated_guild IS NOT NULL AND associated_guild = $1; +""" + +UNSCHEDULE_ALL_BY_USER_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE associated_user IS NOT NULL AND associated_user = $1; +""" + +UNSCHEDULE_ALL_BY_MEMBER_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE + associated_guild IS NOT NULL + AND associated_user IS NOT NULL + AND associated_guild = $1 + AND associated_user = $2 +; +""" + +UNSCHEDULE_ALL_BY_DISPATCH_NAME_STATEMENT = """ +DELETE FROM scheduled_dispatches WHERE dispatch_name = $1; +""" + +UNSCHEDULE_ALL_BY_NAME_AND_USER_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_user IS NOT NULL + AND associated_user = $2; +""" + +UNSCHEDULE_ALL_BY_NAME_AND_GUILD_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_guild IS NOT NULL + AND associated_guild = $2; +""" + +UNSCHEDULE_ALL_BY_NAME_AND_MEMBER_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_guild IS NOT NULL + AND associated_user IS NOT NULL + AND associated_guild = $2 + AND associated_user = $3 +; +""" + +SELECT_ALL_BY_NAME_STATEMENT = """ +SELECT * FROM scheduled_dispatches WHERE dispatch_name = $1; +""" + +SELECT_ALL_BY_NAME_AND_GUILD_STATEMET = """ +SELECT * FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_guild IS NOT NULL + AND associated_guild = $2; +""" + +SELECT_ALL_BY_NAME_AND_USER_STATEMENT = """ +SELECT * FROM scheduled_dispatches +WHERE + dispatch_name = 
$1 + AND associated_user IS NOT NULL + AND associated_user = $2; +""" + +SELECT_ALL_BY_NAME_AND_MEMBER_STATEMENT = """ +SELECT * FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_guild IS NOT NULL + AND associated_user IS NOT NULL + AND associated_guild = $2 + AND associated_user = $3 +; +""" + +INSERT_SCHEDULE_STATEMENT = """ +INSERT INTO scheduled_dispatches +(dispatch_name, dispatch_time, dispatch_zone, associated_guild, associated_user, dispatch_extra) +VALUES ($1, $2, $3, $4, $5, $6::jsonb) +RETURNING task_id; +""" + +DELETE_RETURNING_UPCOMING_IN_ZONE_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE (dispatch_time AT TIME ZONE 'UTC' AT TIME ZONE dispatch_zone) < (CURRENT_TIMESTAMP + $1::interval) +RETURNING *; +""" + + +class ScheduledDispatch(Struct, frozen=True, gc=False): + task_id: str + dispatch_name: str + dispatch_time: datetime + dispatch_zone: str + associated_guild: int | None + associated_user: int | None + dispatch_extra: dict[str, object] | None + + def __eq__(self, other: object) -> bool: + return self is other + + def __lt__(self, other: Self) -> bool: + if type(self) is type(other): + return (self.dispatch_time, self.task_id) < (other.dispatch_time, self.task_id) + return False + + def __gt__(self, other: Self) -> bool: + if type(self) is type(other): + return (self.dispatch_time, self.task_id) > (other.dispatch_time, self.task_id) + return False + + @classmethod + def from_row(cls, row: asyncpg.Record) -> Self: + tid, name, time, zone, guild, user, extra = row + time = time.replace(tzinfo=ZoneInfo(zone)) + return cls(tid, name, time, zone, guild, user, extra) + + def to_row(self) -> tuple[str, str, datetime, str, int | None, int | None, dict[str, object] | None]: + return ( + self.task_id, + self.dispatch_name, + self.dispatch_time, + self.dispatch_zone, + self.associated_guild, + self.associated_user, + self.dispatch_extra, + ) + + +async def _get_scheduled(pool: Pool_alias, granularity: int) -> list[ScheduledDispatch]: + async with pool.acquire() as conn, conn.transaction(): + rows = await conn.fetch(DELETE_RETURNING_UPCOMING_IN_ZONE_STATEMENT, timedelta(minutes=granularity)) + return [ScheduledDispatch.from_row(row) for row in rows] + + +async def _schedule( + pool: Pool_alias, + *, + dispatch_name: str, + dispatch_time: datetime, + dispatch_zone: str, + guild_id: int | None, + user_id: int | None, + dispatch_extra: object | None, +) -> str: + # Normalize the given time to UTC, then remove the time zone to make it "naive". + # This is necessary to ensure consistency among the saved timestamps, as well as to work around asyncpg. + # Let dispatch_zone handle the timezone info. 
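# _schedule() stores a zone-naive UTC timestamp and keeps the IANA zone name
# in its own column (see the asyncpg issue linked just after this); the
# normalization itself is only:
from datetime import datetime
from zoneinfo import ZoneInfo

local = datetime(2024, 1, 23, 13, 15, tzinfo=ZoneInfo("US/Eastern"))
stored = local.astimezone(ZoneInfo("UTC")).replace(tzinfo=None)
assert stored == datetime(2024, 1, 23, 18, 15)  # naive, but known to be UTC
# "US/Eastern" would travel separately in the dispatch_zone column.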
+ # Ref: https://github.com/MagicStack/asyncpg/issues/481 + dispatch_time = dispatch_time.astimezone(ZoneInfo("UTC")).replace(tzinfo=None) + + async with pool.acquire() as conn, conn.transaction(): + task_id = await conn.fetchval( + INSERT_SCHEDULE_STATEMENT, + dispatch_name, + dispatch_time, + dispatch_zone, + guild_id, + user_id, + dispatch_extra, + column=0, + ) + + return task_id # noqa: RET504 + + +async def _query(pool: Pool_alias, query_str: str, params: tuple[object, ...]) -> list[ScheduledDispatch]: + async with pool.acquire() as conn, conn.transaction(): + return [ScheduledDispatch.from_row(row) for row in await conn.fetch(query_str, *params)] + + +class Scheduler: + def __init__(self, pool: Pool_alias) -> None: + self.pool = pool + self._granularity = 1 + + self._schedule_queue: asyncio.PriorityQueue[ScheduledDispatch] = asyncio.PriorityQueue() + self._schedule_queue_lock = asyncio.Lock() + self._discord_dispatch_task: asyncio.Task[None] | None = None + + async def __aenter__(self): + self.scheduler_loop.start() + return self + + async def __aexit__(self, *exc_info: object): + self.scheduler_loop.cancel() + + @tasks.loop(seconds=25) + async def scheduler_loop(self) -> None: + scheduled: list[ScheduledDispatch] = await _get_scheduled(self.pool, self._granularity) + async with self._schedule_queue_lock: + for dispatch in scheduled: + await self._schedule_queue.put(dispatch) + + @scheduler_loop.after_loop + async def scheduler_loop_after(self) -> None: + if self._discord_dispatch_task: + self._discord_dispatch_task.cancel() + + async def get_next(self) -> ScheduledDispatch: + try: + dispatch = await self._schedule_queue.get() + now = datetime.now(ZoneInfo("UTC")) + scheduled_time = dispatch.dispatch_time + if now < scheduled_time: + delay = (now - scheduled_time).total_seconds() + await asyncio.sleep(delay) + return dispatch + finally: + self._schedule_queue.task_done() + + async def _discord_dispatch_loop(self, bot: commands.Bot, *, wait_until_ready: bool) -> None: + if wait_until_ready: + await bot.wait_until_ready() + + try: + while scheduled := await self.get_next(): + bot.dispatch(f"scheduler_{scheduled.dispatch_name}", scheduled) + except (OSError, discord.ConnectionClosed, asyncpg.PostgresConnectionError): + assert self._discord_dispatch_task + self._discord_dispatch_task.cancel() + self.start_discord_dispatch(bot, wait_until_ready=False) + + def start_discord_dispatch(self, bot: commands.Bot, *, wait_until_ready: bool = True) -> None: + self._discord_dispatch_task = asyncio.create_task( + self._discord_dispatch_loop(bot, wait_until_ready=wait_until_ready) + ) + self._discord_dispatch_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None) + + async def stop_discord_dispatch(self) -> None: + async with self._schedule_queue_lock: + await self._schedule_queue.join() + + async def schedule_event( + self, + *, + dispatch_name: str, + dispatch_time: datetime, + dispatch_zone: str, + guild_id: int | None = None, + user_id: int | None = None, + dispatch_extra: object | None = None, + ) -> str: + """Schedule something to be emitted later. + + Parameters + ---------- + dispatch_name: str + The event name to dispatch under. You may drop all events dispatching to the same name (such as when + removing a feature built ontop of this), + dispatch_time: str + A time string matching the format "%Y-%m-%d %H:%M" (eg. "2023-01-23 13:15") + dispatch_zone: str + The name of the zone for the event. 
+ - Use `UTC` for absolute things scheduled by machines for machines + - Use the name of the zone (e.g., US/Eastern) for things scheduled by + humans for machines to do for humans later + + guild_id: int | None + Optionally, an associated guild_id. This can be used with dispatch_name as a means of querying events or to + drop all scheduled events for a guild. + user_id: int | None + Optionally, an associated user_id. This can be used with dispatch_name as a means of querying events or to + drop all scheduled events for a user. + dispatch_extra: object | None + Optionally, extra data to attach to dispatch. This may be any object serializable by msgspec.json.encode + where the result is round-trip decodable with msgspec.json.decode(..., strict=True). + + Returns + ------- + str + A uuid for the task, used for unique cancellation. + """ + + return await _schedule( + self.pool, + dispatch_name=dispatch_name, + dispatch_time=dispatch_time, + dispatch_zone=dispatch_zone, + guild_id=guild_id, + user_id=user_id, + dispatch_extra=dispatch_extra, + ) + + async def unschedule_uuid(self, uuid: str) -> None: + """Unschedule something by uuid. + + This may miss things which should run within the next interval as defined by `granularity`. + Non-existent uuids are silently handled. + """ + + async with self.pool.acquire() as conn, conn.transaction(): + await conn.execute(UNSCHEDULE_BY_UUID_STATEMENT, uuid) + + async def drop_user_schedule(self, user_id: int) -> None: + """Drop all scheduled events for a user (by user_id). + + Intended use case: + removing everything associated with a user who asks for data removal, doesn't exist anymore, or is blacklisted + """ + + async with self.pool.acquire() as conn, conn.transaction(): + await conn.execute(UNSCHEDULE_ALL_BY_USER_STATEMENT, user_id) + + async def drop_event_for_user(self, dispatch_name: str, user_id: int) -> None: + """Drop scheduled events dispatched to `dispatch_name` for user (by user_id). + + Intended use case example: + A reminder system allowing a user to unschedule all reminders + without affecting how other extensions might use this. + """ + + async with self.pool.acquire() as conn, conn.transaction(): + await conn.execute(UNSCHEDULE_ALL_BY_NAME_AND_USER_STATEMENT, dispatch_name, user_id) + + async def drop_guild_schedule(self, guild_id: int) -> None: + """Drop all scheduled events for a guild (by guild_id). + + Intended use case: + clearing scheduled events for a guild when leaving it. + """ + + async with self.pool.acquire() as conn, conn.transaction(): + await conn.execute(UNSCHEDULE_ALL_BY_GUILD_STATEMENT, guild_id) + + async def drop_event_for_guild(self, dispatch_name: str, guild_id: int) -> None: + """Drop scheduled events dispatched to `dispatch_name` for guild (by guild_id). + + Intended use case example: + An admin command allowing clearing all scheduled messages for a guild. + """ + + async with self.pool.acquire() as conn, conn.transaction(): + await conn.execute(UNSCHEDULE_ALL_BY_NAME_AND_GUILD_STATEMENT, dispatch_name, guild_id) + + async def drop_member_schedule(self, guild_id: int, user_id: int) -> None: + """Drop all scheduled events for a member (by guild_id, user_id).
+ + Intended use case: + clearing scheduled events for a member that leaves a guild + """ + + async with self.pool.acquire() as conn, conn.transaction(): + await conn.execute(UNSCHEDULE_ALL_BY_MEMBER_STATEMENT, guild_id, user_id) + + async def drop_event_for_member(self, dispatch_name: str, guild_id: int, user_id: int) -> None: + """Drop scheduled events dispatched to `dispatch_name` for member (by guild_id, user_id). + + Intended use case example: + see user example, but in a guild + """ + + async with self.pool.acquire() as conn, conn.transaction(): + await conn.execute(UNSCHEDULE_ALL_BY_NAME_AND_MEMBER_STATEMENT, dispatch_name, guild_id, user_id) + + async def list_event_schedule_for_user(self, dispatch_name: str, user_id: int) -> list[ScheduledDispatch]: + """List the events of a specified name scheduled for a user (by user_id).""" + + return await _query(self.pool, SELECT_ALL_BY_NAME_AND_USER_STATEMENT, (dispatch_name, user_id)) + + async def list_event_schedule_for_member( + self, + dispatch_name: str, + guild_id: int, + user_id: int, + ) -> list[ScheduledDispatch]: + """List the events of a specified name scheduled for a guild member (by guild_id, user_id).""" + + return await _query(self.pool, SELECT_ALL_BY_NAME_AND_MEMBER_STATEMENT, (dispatch_name, guild_id, user_id)) + + async def list_event_schedule_for_guild(self, dispatch_name: str, guild_id: int) -> list[ScheduledDispatch]: + """List the events of a specified name scheduled for a guild (by guild_id).""" + + return await _query(self.pool, SELECT_ALL_BY_NAME_AND_USER_STATEMENT, (dispatch_name, guild_id)) diff --git a/src/beira/tree.py b/src/beira/tree.py index ceda873..c12684a 100644 --- a/src/beira/tree.py +++ b/src/beira/tree.py @@ -22,13 +22,14 @@ ClientT_co = TypeVar("ClientT_co", bound=Client, covariant=True) +LOGGER = logging.getLogger(__name__) + type Coro[T] = Coroutine[Any, Any, T] type CoroFunc = Callable[..., Coro[Any]] type AppHook[GroupT: (Group | commands.Cog)] = ( Callable[[GroupT, Interaction[Any]], Coro[Any]] | Callable[[Interaction[Any]], Coro[Any]] ) -LOGGER = logging.getLogger(__name__) __all__ = ("before_app_invoke", "after_app_invoke", "HookableTree") diff --git a/src/beira/utils/db.py b/src/beira/utils/db.py index d9fe7c6..246bb20 100644 --- a/src/beira/utils/db.py +++ b/src/beira/utils/db.py @@ -2,9 +2,9 @@ from typing import TYPE_CHECKING -import msgspec from asyncpg import Connection, Pool, Record from asyncpg.pool import PoolConnectionProxy +from msgspec.json import decode as json_decode, encode as json_encode __all__ = ("Connection_alias", "Pool_alias", "conn_init") @@ -20,9 +20,4 @@ async def conn_init(connection: Connection_alias) -> None: """Sets up codecs for Postgres connection.""" - await connection.set_type_codec( - "jsonb", - schema="pg_catalog", - encoder=msgspec.json.encode, - decoder=msgspec.json.decode, - ) + await connection.set_type_codec("jsonb", schema="pg_catalog", encoder=json_encode, decoder=json_decode) diff --git a/src/beira/utils/pagination.py b/src/beira/utils/pagination.py index cc0cb07..3a3d9ae 100644 --- a/src/beira/utils/pagination.py +++ b/src/beira/utils/pagination.py @@ -286,7 +286,7 @@ async def turn_to_last(self, interaction: discord.Interaction, _: discord.ui.But await self.update_page(interaction) -class PaginatedSelectView[_LT](abc.ABC, OwnedView): +class PaginatedSelectView[_SequenceT: Sequence[Any]](abc.ABC, OwnedView): """A view that handles paginated embeds and page buttons.
Parameters @@ -312,7 +312,7 @@ class PaginatedSelectView[_LT](abc.ABC, OwnedView): message: discord.Message - def __init__(self, author_id: int, pages_content: Sequence[_LT], *, timeout: float | None = 180) -> None: + def __init__(self, author_id: int, pages_content: _SequenceT, *, timeout: float | None = 180) -> None: super().__init__(author_id, timeout=timeout) self.pages = pages_content self.page_index: int = 0 diff --git a/src/beira/utils/scheduler.py b/src/beira/utils/scheduler.py deleted file mode 100644 index 1971e6e..0000000 --- a/src/beira/utils/scheduler.py +++ /dev/null @@ -1,628 +0,0 @@ -# region License -# Vendored from https://github.com/mikeshardmind/discord-scheduler with some modifications to accommodate a different -# backend. See the license below: -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -# -# Copyright (C) 2023 Michael Hall -# endregion - -import asyncio -import random -import time -from collections.abc import Callable -from datetime import datetime, timedelta -from typing import Protocol, Self -from warnings import warn -from zoneinfo import ZoneInfo - -import asyncpg -from msgspec import Struct -from msgspec.json import decode as json_decode, encode as json_encode - -from .db import Connection_alias # noqa: TID252 - - -class BotLike(Protocol): - def dispatch(self, event_name: str, /, *args: object, **kwargs: object) -> None: ... - - async def wait_until_ready(self) -> None: ... - - -def _uuid7gen() -> Callable[[], str]: - """UUIDv7 has been accepted as part of rfc9562 - - This is intended to be a compliant implementation, but I am not advertising it - in public, exported APIs as such *yet* - - In particular, this is: - UUIDv7 as described in rfc9562 section 5.7 utilizing the - optional sub-millisecond timestamp fraction described in section 6.2 method 3 - """ - _last_timestamp: int | None = None - - def uuid7() -> str: - """This is unique identifer generator - - This was chosen to increase performance of indexing and - to pick something likely to get specific database support - for this to be a portably efficient choice should someone - decide to have this be backed by something other than sqlite - - This should not be relied on as always generating valid UUIDs of - any version or variant at this time. The current intent is that - this is a UUIDv7 in str form, but this should not be relied - on outside of this library and may be changed in the future for - better performance within this library. 
- """ - nonlocal _last_timestamp - nanoseconds = time.time_ns() - if _last_timestamp is not None and nanoseconds <= _last_timestamp: - nanoseconds = _last_timestamp + 1 - _last_timestamp = nanoseconds - timestamp_s, timestamp_ns = divmod(nanoseconds, 10**9) - subsec_a = timestamp_ns >> 18 - subsec_b = (timestamp_ns >> 6) & 0x0FFF - subsec_seq_node = (timestamp_ns & 0x3F) << 56 - subsec_seq_node += random.SystemRandom().getrandbits(56) - uuid_int = (timestamp_s & 0x0FFFFFFFFF) << 92 - uuid_int += subsec_a << 80 - uuid_int += subsec_b << 64 - uuid_int += subsec_seq_node - uuid_int &= ~(0xC000 << 48) - uuid_int |= 0x8000 << 48 - uuid_int &= ~(0xF000 << 64) - uuid_int |= 7 << 76 - return f"{uuid_int:032x}" - - return uuid7 - - -_uuid7 = _uuid7gen() - -__all__ = ("DiscordBotScheduler", "ScheduledDispatch", "Scheduler") - -SQLROW_TYPE = tuple[str, str, str, str, int | None, int | None, bytes | None] -DATE_FMT = r"%Y-%m-%d %H:%M" - - -# Requires a postgres extension: https://github.com/fboulnois/pg_uuidv7 -INITIALIZATION_STATEMENTS = """ -CREATE TABLE IF NOT EXISTS scheduled_dispatches ( - task_id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), - dispatch_name TEXT NOT NULL, - dispatch_time TIMESTAMP NOT NULL, - dispatch_zone TEXT NOT NULL, - associated_guild BIGINT, - associated_user BIGINT, - dispatch_extra JSONB -); -""" - -ZONE_SELECTION_STATEMENT = """ -SELECT DISTINCT dispatch_zone FROM scheduled_dispatches; -""" - -UNSCHEDULE_BY_UUID_STATEMENT = """ -DELETE FROM scheduled_dispatches WHERE task_id = $1; -""" - -UNSCHEDULE_ALL_BY_GUILD_STATEMENT = """ -DELETE FROM scheduled_dispatches -WHERE associated_guild IS NOT NULL AND associated_guild = $1; -""" - -UNSCHEDULE_ALL_BY_USER_STATEMENT = """ -DELETE FROM scheduled_dispatches -WHERE associated_user IS NOT NULL AND associated_user = $1; -""" - -UNSCHEDULE_ALL_BY_MEMBER_STATEMENT = """ -DELETE FROM scheduled_dispatches -WHERE - associated_guild IS NOT NULL - AND associated_user IS NOT NULL - AND associated_guild = $1 - AND associated_user = $2 -; -""" - -UNSCHEDULE_ALL_BY_DISPATCH_NAME_STATEMENT = """ -DELETE FROM scheduled_dispatches WHERE dispatch_name = $1; -""" - -UNSCHEDULE_ALL_BY_NAME_AND_USER_STATEMENT = """ -DELETE FROM scheduled_dispatches -WHERE - dispatch_name = $1 - AND associated_user IS NOT NULL - AND associated_user = $2; -""" - -UNSCHEDULE_ALL_BY_NAME_AND_GUILD_STATEMENT = """ -DELETE FROM scheduled_dispatches -WHERE - dispatch_name = $1 - AND associated_guild IS NOT NULL - AND associated_guild = $2; -""" - -UNSCHEDULE_ALL_BY_NAME_AND_MEMBER_STATEMENT = """ -DELETE FROM scheduled_dispatches -WHERE - dispatch_name = $1 - AND associated_guild IS NOT NULL - AND associated_user IS NOT NULL - AND associated_guild = $2 - AND associated_user = $3 -; -""" - -SELECT_ALL_BY_NAME_STATEMENT = """ -SELECT * FROM scheduled_dispatches WHERE dispatch_name = $1; -""" - -SELECT_ALL_BY_NAME_AND_GUILD_STATEMET = """ -SELECT * FROM scheduled_dispatches -WHERE - dispatch_name = $1 - AND associated_guild IS NOT NULL - AND associated_guild = $2; -""" - -SELECT_ALL_BY_NAME_AND_USER_STATEMENT = """ -SELECT * FROM scheduled_dispatches -WHERE - dispatch_name = $1 - AND associated_user IS NOT NULL - AND associated_user = $2; -""" - -SELECT_ALL_BY_NAME_AND_MEMBER_STATEMENT = """ -SELECT * FROM scheduled_dispatches -WHERE - dispatch_name = $1 - AND associated_guild IS NOT NULL - AND associated_user IS NOT NULL - AND associated_guild = $2 - AND associated_user = $3 -; -""" - -INSERT_SCHEDULE_STATEMENT = """ -INSERT INTO scheduled_dispatches -(task_id, 
dispatch_name, dispatch_time, dispatch_zone, associated_guild, associated_user, dispatch_extra) -VALUES ($1, $2, $3, $4, $5, $6, $7); -""" - -DELETE_RETURNING_UPCOMING_IN_ZONE_STATEMENT = """ -DELETE FROM scheduled_dispatches -WHERE dispatch_time < $1 AND dispatch_zone = $2 -RETURNING *; -""" - - -class ScheduledDispatch(Struct, frozen=True, gc=False): - task_id: str - dispatch_name: str - dispatch_time: str - dispatch_zone: str - associated_guild: int | None - associated_user: int | None - dispatch_extra: bytes | None - - def __eq__(self, other: object) -> bool: - return self is other - - def __lt__(self, other: Self) -> bool: - if type(self) is type(other): - return (self.get_arrow_time(), self.task_id) < (other.get_arrow_time(), self.task_id) - return False - - def __gt__(self, other: Self) -> bool: - if type(self) is type(other): - return (self.get_arrow_time(), self.task_id) > (other.get_arrow_time(), self.task_id) - return False - - @classmethod - def from_pg_row(cls, row: asyncpg.Record) -> Self: - tid, name, time, zone, guild, user, extra_bytes = row - return cls(tid, name, time, zone, guild, user, extra_bytes) - - @classmethod - def from_exposed_api( - cls: type[Self], - *, - name: str, - time: str, - zone: str, - guild: int | None, - user: int | None, - extra: object | None, - ) -> Self: - packed: bytes | None = None - if extra is not None: - f = json_encode(extra) - packed = f - return cls(_uuid7(), name, time, zone, guild, user, packed) - - def to_pg_row(self) -> SQLROW_TYPE: - return ( - self.task_id, - self.dispatch_name, - self.dispatch_time, - self.dispatch_zone, - self.associated_guild, - self.associated_user, - self.dispatch_extra, - ) - - def get_arrow_time(self) -> datetime: - return datetime.strptime(self.dispatch_time, DATE_FMT).replace(tzinfo=ZoneInfo(self.dispatch_zone)) - - def unpack_extra(self) -> object | None: - if self.dispatch_extra is not None: - return json_decode(self.dispatch_extra, strict=True) - return None - - -async def _setup_db(conn: Connection_alias) -> set[str]: - async with conn.transaction(): - await conn.execute(INITIALIZATION_STATEMENTS) - return {row["dispatch_zone"] for row in await conn.fetch(ZONE_SELECTION_STATEMENT)} - - -async def _get_scheduled(conn: Connection_alias, granularity: int, zones: set[str]) -> list[ScheduledDispatch]: - ret: list[ScheduledDispatch] = [] - if not zones: - return ret - - cutoff = datetime.now(ZoneInfo("UTC")) + timedelta(minutes=granularity) - async with conn.transaction(): - for zone in zones: - local_time = cutoff.astimezone(ZoneInfo(zone)).strftime(DATE_FMT) - rows = await conn.fetch(DELETE_RETURNING_UPCOMING_IN_ZONE_STATEMENT, local_time, zone) - ret.extend(map(ScheduledDispatch.from_pg_row, rows)) - - return ret - - -async def _schedule( - conn: Connection_alias, - *, - dispatch_name: str, - dispatch_time: str, - dispatch_zone: str, - guild_id: int | None, - user_id: int | None, - dispatch_extra: object | None, -) -> str: - # do this here, so if it fails, it fails at scheduling - _time = datetime.strptime(dispatch_time, DATE_FMT).replace(tzinfo=ZoneInfo(dispatch_zone)) - obj = ScheduledDispatch.from_exposed_api( - name=dispatch_name, - time=dispatch_time, - zone=dispatch_zone, - guild=guild_id, - user=user_id, - extra=dispatch_extra, - ) - - async with conn.transaction(): - await conn.execute(INSERT_SCHEDULE_STATEMENT, *obj.to_pg_row()) - - return obj.task_id - - -async def _query(conn: Connection_alias, query_str: str, params: tuple[int | str, ...]) -> list[ScheduledDispatch]: - return 
[ScheduledDispatch.from_pg_row(row) for row in await conn.fetch(query_str, *params)] - - -async def _drop(conn: Connection_alias, query_str: str, params: tuple[int | str, ...]) -> None: - async with conn.transaction(): - await conn.execute(query_str, *params) - - -class Scheduler: - def __init__(self, db_pool: asyncpg.Pool[asyncpg.Record], granularity: int = 1): - if granularity < 1: - msg = "Granularity must be a positive iteger number of minutes" - raise ValueError(msg) - asyncio.get_running_loop() - self.granularity = granularity - self._pool = db_pool - self._zones: set[str] = set() # We don't re-narrow this anywhere currently, only expand it. - self._queue: asyncio.PriorityQueue[ScheduledDispatch] = asyncio.PriorityQueue() - self._ready = False - self._closing = False - self._lock = asyncio.Lock() - self._loop_task: asyncio.Task[None] | None = None - self._discord_task: asyncio.Task[None] | None = None - - def stop(self) -> None: - if self._loop_task is None: - msg = "Contextmanager, use it" - raise RuntimeError(msg) - self._loop_task.cancel() - if self._discord_task: - self._discord_task.cancel() - - async def _loop(self) -> None: - # not currently modifiable once running - # differing granularities here, + a delay on retrieving in .get_next() - # ensures closest - sleep_gran = self.granularity * 25 - while (not self._closing) and await asyncio.sleep(sleep_gran, self._ready): - # Lock needed to ensure that once the db is dropping rows - # that a graceful shutdown doesn't drain the queue until entries are in it. - async with self._lock: - # check on both ends of the await that we aren't closing - if self._closing: - return - async with self._pool.acquire() as conn: - scheduled = await _get_scheduled(conn, self.granularity, self._zones) - for s in scheduled: - await self._queue.put(s) - - async def __aexit__(self, *exc_info: object): - if not self._closing: - msg = "Exiting without use of stop_gracefully may cause loss of tasks" - warn(msg, stacklevel=2) - self.stop() - - async def get_next(self) -> ScheduledDispatch: - """ - gets the next scheduled event, waiting if neccessary. - """ - - try: - dispatch = await self._queue.get() - now = datetime.now(ZoneInfo("UTC")) - scheduled_for = dispatch.get_arrow_time() - if now < scheduled_for: - delay = (now - scheduled_for).total_seconds() - await asyncio.sleep(delay) - return dispatch - finally: - self._queue.task_done() - - async def stop_gracefully(self) -> None: - """Notify the internal scheduling loop to stop scheduling and wait for the internal queue to be empty""" - - self._closing = True - # don't remove lock, see note in _loop - async with self._lock: - await self._queue.join() - - async def __aenter__(self) -> Self: - async with self._pool.acquire() as conn: - self._zones = await _setup_db(conn) - self._ready = True - self._loop_task = asyncio.create_task(self._loop()) - self._loop_task.add_done_callback(lambda f: f.exception() if not f.cancelled() else None) - return self - - async def schedule_event( - self, - *, - dispatch_name: str, - dispatch_time: str, - dispatch_zone: str, - guild_id: int | None = None, - user_id: int | None = None, - dispatch_extra: object | None = None, - ) -> str: - """ - Schedule something to be emitted later. - - Parameters - ---------- - dispatch_name: str - The event name to dispatch under. - You may drop all events dispatching to the same name - (such as when removing a feature built ontop of this) - dispatch_time: str - A time string matching the format "%Y-%m-%d %H:%M" (eg. 
"2023-01-23 13:15") - dispatch_zone: str - The name of the zone for the event. - - Use `UTC` for absolute things scheduled by machines for machines - - Use the name of the zone (eg. US/Eastern) for things scheduled by - humans for machines to do for humans later - - guild_id: int | None - Optionally, an associated guild_id. - This can be used with dispatch_name as a means of querying events - or to drop all scheduled events for a guild. - user_id: int | None - Optionally, an associated user_id. - This can be used with dispatch_name as a means of querying events - or to drop all scheduled events for a user. - dispatch_extra: object | None - Optionally, Extra data to attach to dispatch. - This may be any object serializable by msgspec.msgpack.encode - where the result is round-trip decodable with - msgspec.msgpack.decode(..., strict=True) - - Returns - ------- - str - A uuid for the task, used for unique cancelation. - """ - - self._zones.add(dispatch_zone) - async with self._pool.acquire() as conn: - return await _schedule( - conn, - dispatch_name=dispatch_name, - dispatch_time=dispatch_time, - dispatch_zone=dispatch_zone, - guild_id=guild_id, - user_id=user_id, - dispatch_extra=dispatch_extra, - ) - - async def unschedule_uuid(self, uuid: str) -> None: - """ - Unschedule something by uuid. - This may miss things which should run within the next interval as defined by `granularity` - Non-existent uuids are silently handled. - """ - - async with self._pool.acquire() as conn: - await _drop(conn, UNSCHEDULE_BY_UUID_STATEMENT, (uuid,)) - - async def drop_user_schedule(self, user_id: int) -> None: - """ - Drop all scheduled events for a user (by user_id) - - Intended use case: - removing everything associated to a user who asks for data removal, doesn't exist anymore, or is blacklisted - """ - - async with self._pool.acquire() as conn: - await _drop(conn, UNSCHEDULE_ALL_BY_USER_STATEMENT, (user_id,)) - - async def drop_event_for_user(self, dispatch_name: str, user_id: int) -> None: - """ - Drop scheduled events dispatched to `dispatch_name` for user (by user_id) - - Intended use case example: - A reminder system allowing a user to unschedule all reminders - without effecting how other extensions might use this. - """ - - async with self._pool.acquire() as conn: - await _drop(conn, UNSCHEDULE_ALL_BY_NAME_AND_USER_STATEMENT, (dispatch_name, user_id)) - - async def drop_guild_schedule(self, guild_id: int) -> None: - """ - Drop all scheduled events for a guild (by guild_id) - - Intended use case: - clearing sccheduled events for a guild when leaving it. - """ - - async with self._pool.acquire() as conn: - await _drop(conn, UNSCHEDULE_ALL_BY_GUILD_STATEMENT, (guild_id,)) - - async def drop_event_for_guild(self, dispatch_name: str, guild_id: int) -> None: - """ - Drop scheduled events dispatched to `dispatch_name` for guild (by guild_id) - - Intended use case example: - An admin command allowing clearing all scheduled messages for a guild. 
- """ - - async with self._pool.acquire() as conn: - await _drop(conn, UNSCHEDULE_ALL_BY_NAME_AND_GUILD_STATEMENT, (dispatch_name, guild_id)) - - async def drop_member_schedule(self, guild_id: int, user_id: int) -> None: - """ - Drop all scheduled events for a guild (by guild_id, user_id) - - Intended use case: - clearing sccheduled events for a member that leaves a guild - """ - - async with self._pool.acquire() as conn: - await _drop(conn, UNSCHEDULE_ALL_BY_MEMBER_STATEMENT, (guild_id, user_id)) - - async def drop_event_for_member(self, dispatch_name: str, guild_id: int, user_id: int) -> None: - """ - Drop scheduled events dispatched to `dispatch_name` for member (by guild_id, user_id) - - Intended use case example: - see user example, but in a guild - """ - - async with self._pool.acquire() as conn: - await _drop(conn, UNSCHEDULE_ALL_BY_NAME_AND_MEMBER_STATEMENT, (dispatch_name, guild_id, user_id)) - - async def list_event_schedule_for_user(self, dispatch_name: str, user_id: int) -> list[ScheduledDispatch]: - """ - list the events of a specified name scheduled for a user (by user_id) - """ - - async with self._pool.acquire() as conn: - return await _query(conn, SELECT_ALL_BY_NAME_AND_USER_STATEMENT, (dispatch_name, user_id)) - - async def list_event_schedule_for_member( - self, - dispatch_name: str, - guild_id: int, - user_id: int, - ) -> list[ScheduledDispatch]: - """ - list the events of a specified name scheduled for a member (by guild_id, user_id) - """ - - async with self._pool.acquire() as conn: - return await _query(conn, SELECT_ALL_BY_NAME_AND_MEMBER_STATEMENT, (dispatch_name, guild_id, user_id)) - - async def list_event_schedule_for_guild(self, dispatch_name: str, guild_id: int) -> list[ScheduledDispatch]: - """ - list the events of a specified name scheduled for a guild (by guild_id) - """ - - async with self._pool.acquire() as conn: - return await _query(conn, SELECT_ALL_BY_NAME_AND_USER_STATEMENT, (dispatch_name, guild_id)) - - @staticmethod - def time_str_from_params(year: int, month: int, day: int, hour: int, minute: int) -> str: - """ - A quick helper for people working with other time representations - (if you have a datetime object, just use strftime with "%Y-%m-%d %H:%M") - """ - - return datetime(year, month, day, hour, minute, tzinfo=ZoneInfo("UTC")).strftime(DATE_FMT) - - -class DiscordBotScheduler(Scheduler): - """Scheduler with convienence dispatches compatible with discord.py's commands extenstion - - Note: long-term compatability not guaranteed, dispatch isn't covered by discord.py's version guarantees. - """ - - async def _bot_dispatch_loop(self, bot: BotLike, wait_until_ready: bool) -> None: - if not self._ready: - msg = "context manager, use it" - raise RuntimeError(msg) - - if wait_until_ready: - await bot.wait_until_ready() - - while scheduled := await self.get_next(): - bot.dispatch(f"sinbad_scheduler_{scheduled.dispatch_name}", scheduled) - - def start_dispatch_to_bot(self, bot: BotLike, *, wait_until_ready: bool = True) -> None: - """ - Starts dispatching events to the bot. - - Events will dispatch under a name with the following format: - - sinbad_scheduler_{dispatch_name} - - where dispatch_name is set when submitting events to schedule. - This is done to avoid potential conflicts with existing or future event names, - as well as anyone else building a scheduler on top of bot.dispatch - (hence author name inclusion) and someone deciding to use both. 
- - Listeners get a single object as their argument, `ScheduledDispatch` - - to listen for an event you submit with `reminder` as the name - - @commands.Cog.listener("on_sinbad_scheduler_reminder") - async def some_listener(self, scheduled_object: ScheduledDispatch): - ... - - Events will not start being sent until the bot is considered ready if `wait_until_ready` is True - """ - - if not self._ready: - msg = "context manager, use it" - raise RuntimeError(msg) - - self._discord_task = asyncio.create_task(self._bot_dispatch_loop(bot, wait_until_ready)) - self._discord_task.add_done_callback(lambda f: f.exception() if not f.cancelled() else None)
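
A minimal consumer sketch for review (not part of the patch): assuming the running bot exposes the new scheduler as `bot.scheduler` and that listeners hook the `scheduler_<dispatch_name>` events emitted by `Scheduler.start_discord_dispatch`, a hypothetical reminder extension could look roughly like this. `ReminderCog`, the `remind` command, and the `"reminder"` dispatch name are illustrative only.

from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

import discord
from discord.ext import commands

from beira.scheduler import ScheduledDispatch


class ReminderCog(commands.Cog):
    def __init__(self, bot: commands.Bot) -> None:
        self.bot = bot  # Expected to carry this PR's Scheduler as `bot.scheduler`.

    @commands.command()
    async def remind(self, ctx: commands.Context, minutes: int, *, text: str) -> None:
        # Schedule a dispatch under "reminder"; the scheduler re-emits it as "scheduler_reminder".
        when = datetime.now(ZoneInfo("UTC")) + timedelta(minutes=minutes)
        task_id = await self.bot.scheduler.schedule_event(
            dispatch_name="reminder",
            dispatch_time=when,
            dispatch_zone="UTC",
            guild_id=ctx.guild.id if ctx.guild else None,
            user_id=ctx.author.id,
            dispatch_extra={"channel_id": ctx.channel.id, "text": text},
        )
        await ctx.send(f"Reminder scheduled (task {task_id}).")

    @commands.Cog.listener("on_scheduler_reminder")
    async def on_reminder(self, scheduled: ScheduledDispatch) -> None:
        # dispatch_extra round-trips through the jsonb codec, so it comes back as a plain dict.
        extra = scheduled.dispatch_extra or {}
        channel = self.bot.get_channel(extra.get("channel_id", 0))
        if isinstance(channel, discord.abc.Messageable):
            await channel.send(str(extra.get("text", "Reminder!")))

Per the `schedule_event` docstring, `UTC` is the zone to use for machine-scheduled dispatches like this one; a user-facing "remind me at 9am local time" flow would pass the user's own zone name instead.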