diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index cfaa3f6..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "sqltools.connections": [ - { - "previewLimit": 50, - "server": "192.168.0.23", - "port": 5432, - "driver": "PostgreSQL", - "name": "Beira DB", - "username": "thanos", - "database": "discord_beira_db" - } - ] -} \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py deleted file mode 100644 index 17b83c3..0000000 --- a/core/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import tree as tree -from .bot import Beira as Beira -from .checks import * -from .config import * -from .context import * -from .errors import * diff --git a/core/context.py b/core/context.py deleted file mode 100644 index bf6ef5d..0000000 --- a/core/context.py +++ /dev/null @@ -1,52 +0,0 @@ -"""context.py: For the custom context and interaction subclasses. Mainly used for type narrowing.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -import aiohttp -import discord -import wavelink -from discord.ext import commands - -from .utils.db import Pool_alias - - -if TYPE_CHECKING: - from .bot import Beira - - -__all__ = ("Context", "GuildContext", "Interaction") - -type Interaction = discord.Interaction[Beira] - - -class Context(commands.Context["Beira"]): - """A custom context subclass for Beira. - - Attributes - ---------- - session - db - """ - - voice_client: wavelink.Player | None # type: ignore # Type lie for narrowing - - @property - def session(self) -> aiohttp.ClientSession: - """`ClientSession`: Returns the asynchronous HTTP session used by the bot for HTTP requests.""" - - return self.bot.web_session - - @property - def db(self) -> Pool_alias: - """`Pool`: Returns the asynchronous connection pool used by the bot for database management.""" - - return self.bot.db_pool - - -class GuildContext(Context): - author: discord.Member # type: ignore # Type lie for narrowing - guild: discord.Guild # type: ignore # Type lie for narrowing - channel: discord.abc.GuildChannel | discord.Thread # type: ignore # Type lie for narrowing - me: discord.Member # type: ignore # Type lie for narrowing diff --git a/exts/_dev/__init__.py b/exts/_dev/__init__.py deleted file mode 100644 index 6e65bc1..0000000 --- a/exts/_dev/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -import core - -from ._dev import DevCog -from ._test import TestCog - - -async def setup(bot: core.Beira) -> None: - """Connects cog to bot.""" - - # Can't use the guilds kwarg, as it doesn't currently work for hybrids. It would look like this: - # guilds=[discord.Object(guild_id) for guild_id in CONFIG["discord"]["guilds"]["dev"]]) - await bot.add_cog(DevCog(bot)) - await bot.add_cog(TestCog(bot)) diff --git a/exts/notifications/__init__.py b/exts/notifications/__init__.py deleted file mode 100644 index 934c5ed..0000000 --- a/exts/notifications/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -import core - -from .aci_notifications import make_listeners as make_aci_listeners -from .other_triggers import make_listeners as make_other_listeners -from .rss_notifications import RSSNotificationsCog - - -async def setup(bot: core.Beira) -> None: - """Connects listeners and cog to bot.""" - - listener_info = make_aci_listeners(bot) + make_other_listeners(bot) - for event_name, listener in listener_info: - bot.add_listener(listener, event_name) - - await bot.add_cog(RSSNotificationsCog(bot)) diff --git a/exts/notifications/aci_notifications.py b/exts/notifications/aci_notifications.py deleted file mode 100644 index 20565fd..0000000 --- a/exts/notifications/aci_notifications.py +++ /dev/null @@ -1,210 +0,0 @@ -"""custom_notifications.py: One or more listenerrs for sending custom notifications based on events.""" - -import functools -import logging -import re -from collections.abc import Callable -from typing import Any - -import discord -from discord import CategoryChannel, ForumChannel, StageChannel, TextChannel, VoiceChannel - -import core - - -type ValidGuildChannel = VoiceChannel | StageChannel | ForumChannel | TextChannel | CategoryChannel - -LOGGER = logging.getLogger(__name__) - -# 799077440139034654 would be the actual channel should the delete hooks go into "production". -ACI_DELETE_CHANNEL = 975459460560605204 - -# A list of ids for Tatsu leveled roles to keep track of. -ACI_LEVELED_ROLES = { - 694616299476877382, - 694615984438509636, - 694615108323639377, - 694615102237835324, - 747520979735019572, -} - -# The mod role(s) to ping when sending notifications. -ACI_MOD_ROLE = 780904973004570654 - -ACI_GUILD_ID = core.CONFIG.discord.important_guilds["prod"][0] - -LEAKY_INSTAGRAM_LINK_PATTERN = re.compile(r"(instagram\.com/.*?)&igsh.*==") - - -async def on_server_boost_role_member_update( - log_webhook: discord.Webhook, - before: discord.Member, - after: discord.Member, -) -> None: - """Listener that sends a notification if members of the ACI100 server earn certain roles. - - Condition for activating: - - Boost the server and earn the premium subscriber, or "Server Booster", role. - """ - - # Check if the update is in the right server, a member got new roles, and they got a new "Server Booster" role. - if ( - before.guild.id == ACI_GUILD_ID - and len(new_roles := set(after.roles).difference(before.roles)) > 0 - and after.guild.premium_subscriber_role in new_roles - ): - # Send a message notifying holders of some other role(s) about this new role acquisition. - content = f"<@&{ACI_MOD_ROLE}>, {after.mention} just boosted the server!" - await log_webhook.send(content) - - -async def on_leveled_role_member_update( - log_webhook: discord.Webhook, - before: discord.Member, - after: discord.Member, -) -> None: - """Listener that sends a notification if members of the ACI100 server earn certain roles. - - Condition for activating: - - Earn a Tatsu leveled role above "The Ears". - """ - - # Check if the update is in the right server, a member got new roles, and they got a relevant leveled role. - if ( - before.guild.id == ACI_GUILD_ID - and len(new_roles := set(after.roles).difference(before.roles)) > 0 - and (new_leveled_roles := tuple(role for role in new_roles if (role.id in ACI_LEVELED_ROLES))) - ): - # Ensure the user didn't just rejoin. - if after.joined_at is not None: - # Technically, at 8 points every two minutes, it's possible to hit the lowest relevant leveled role in - # 20h 50m on ACI, so 21 hours will be the limit. - recently_rejoined = (discord.utils.utcnow() - after.joined_at).total_seconds() < 75600 - else: - recently_rejoined = False - - if new_leveled_roles and not recently_rejoined: - # Send a message notifying holders of some other role(s) about this new role acquisition. - role_names = tuple(role.name for role in new_leveled_roles) - content = f"<@&{ACI_MOD_ROLE}>, {after.mention} was given the `{role_names}` role(s)." - await log_webhook.send(content) - - -async def on_bad_twitter_link(bot: core.Beira, message: discord.Message) -> None: - if message.author == bot.user or (not message.guild or message.guild.id != ACI_GUILD_ID): - return - - if links := re.findall(r"(?:http(?:s)?://|(? None: - if (not message.guild) or (message.guild.id != ACI_GUILD_ID): - return - - if not LEAKY_INSTAGRAM_LINK_PATTERN.search(message.content): - return - - cleaned_content = re.sub(LEAKY_INSTAGRAM_LINK_PATTERN, "\1", message.content) - new_content = ( - f"*Cleaned Instagram link(s)*\n" - f"Reposted from {message.author.mention} ({message.author.name} - {message.author.id}):\n" - "————————\n" - "\n" - f"{cleaned_content}" - ) - - send_kwargs: dict[str, Any] = {} - if message.attachments: - send_kwargs["files"] = [await atmt.to_file() for atmt in message.attachments] - - await message.delete() - await message.channel.send(new_content, allowed_mentions=discord.AllowedMentions(users=False), **send_kwargs) - - -async def test_on_any_message_delete(bot: core.Beira, payload: discord.RawMessageDeleteEvent) -> None: - # TODO: Improve. - - # The ID of the guild this listener is for. - aci_guild_id: int = core.CONFIG.discord.important_guilds["prod"][0] - - # Only check in ACI100 server. - if payload.guild_id == aci_guild_id: - # Attempt to get the channel the message was sent in. - try: - channel = bot.get_channel(payload.channel_id) or await bot.fetch_channel(payload.channel_id) - except (discord.InvalidData, discord.HTTPException): - LOGGER.info("Could not find the channel of the deleted message: %s", payload) - return - assert isinstance(channel, ValidGuildChannel | discord.Thread) # Known if we reach this point. - - # Attempt to get the message itself. - message = payload.cached_message - if not message and not isinstance(channel, ForumChannel | CategoryChannel): - try: - message = await channel.fetch_message(payload.message_id) - except discord.HTTPException: - LOGGER.info("Could not find the deleted message: %s", payload) - return - assert message is not None # Known if we reach this point. - - # Create a log embed to represent the deleted message. - extra_attachments: list[str] = [] - embed = ( - discord.Embed( - colour=discord.Colour.dark_blue(), - description=( - f"**Message sent by {message.author.mention} - Deleted in <#{payload.channel_id}>**" - f"\n{message.content}" - ), - timestamp=discord.utils.utcnow(), - ) - .set_author(name=str(message.author), icon_url=message.author.display_avatar.url) - .set_footer(text=f"Author: {message.author.id} | Message ID: {payload.message_id}") - .add_field(name="Sent at:", value=discord.utils.format_dt(message.created_at, style="F"), inline=False) - ) - - # Put attachments in the one log message or in another. - if len(message.attachments) == 1: - if message.attachments[0].content_type in ("gif", "jpg", "png", "webp", "webm", "mp4"): - embed.set_image(url=message.attachments[0].url) - else: - embed.add_field(name="Attachment", value="See below.") - extra_attachments.append(message.attachments[0].url) - elif len(message.attachments) > 1: - embed.add_field(name="Attachments", value="See below.") - extra_attachments.extend(att.url for att in message.attachments) - - # Send the log message(s). - delete_log_channel = bot.get_channel(ACI_DELETE_CHANNEL) - assert isinstance(delete_log_channel, discord.TextChannel) # Known at runtime. - - await delete_log_channel.send(embed=embed) - if extra_attachments: - content = "\n".join(extra_attachments) - await delete_log_channel.send(content) - - -def make_listeners(bot: core.Beira) -> tuple[tuple[str, Callable[..., Any]], ...]: - """Connects listeners to bot.""" - - # The webhook url that will be used to send ACI-related notifications. - aci_webhook_url = core.CONFIG.discord.webhooks[0] - role_log_webhook = discord.Webhook.from_url(aci_webhook_url, session=bot.web_session) - - # Adjust the arguments for the listeners and provide corresponding event name. - return ( - ("on_member_update", functools.partial(on_leveled_role_member_update, role_log_webhook)), - ("on_member_update", functools.partial(on_server_boost_role_member_update, role_log_webhook)), - ("on_message", on_leaky_instagram_link), - # ("on_message", functools.partial(on_bad_twitter_link, bot)), # Twitter works. # noqa: ERA001 - ) diff --git a/exts/notifications/other_triggers.py b/exts/notifications/other_triggers.py deleted file mode 100644 index e259d3e..0000000 --- a/exts/notifications/other_triggers.py +++ /dev/null @@ -1,56 +0,0 @@ -import asyncio -import functools -import re -from collections.abc import Callable -from typing import Any - -import aiohttp -import discord -import lxml.etree -import lxml.html -import msgspec - -import core - - -HEADERS = { - "User-Agent": ( - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " - "AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/104.0.0.0 Safari/537.36" - ) -} -private_guild_with_9gag_links = 1097976528832307271 - - -async def get_9gag_mp4(session: aiohttp.ClientSession, link: str) -> str | None: - async with session.get(link, headers=HEADERS) as response: - data = lxml.html.fromstring(await response.read()) - element = data.find(".//script[@type='application/ld+json']") - if element is not None and element.text: - return msgspec.json.decode(element.text)["video"]["contentUrl"] - return None - - -async def on_bad_9gag_link(bot: core.Beira, message: discord.Message) -> None: - if message.author == bot.user or ((not message.guild) or message.guild.id != private_guild_with_9gag_links): - return - - if links := re.findall(r"(?:http(?:s)?://)9gag\.com/gag/[\S]*", message.content): - tasks = [asyncio.create_task(get_9gag_mp4(bot.web_session, link)) for link in links] - results = await asyncio.gather(*tasks) - new_links = "\n".join(result for result in results if result is not None) - if new_links: - content = ( - f"*Corrected 9gag link(s)*\n" - f"Reposted from {message.author.mention} ({message.author.name} - {message.author.id}):\n\n" - f"{new_links}" - ) - await message.reply(content, allowed_mentions=discord.AllowedMentions(users=False, replied_user=False)) - - -def make_listeners(bot: core.Beira) -> tuple[tuple[str, Callable[..., Any]], ...]: - """Connects listeners to bot.""" - - # Adjust the arguments for the listeners and provide corresponding event name. - return (("on_message", functools.partial(on_bad_9gag_link, bot)),) diff --git a/main.py b/main.py deleted file mode 100644 index 37c879f..0000000 --- a/main.py +++ /dev/null @@ -1,42 +0,0 @@ -import asyncio - -import aiohttp -import asyncpg -import discord - -import core -from core.tree import HookableTree -from core.utils import LoggingManager, conn_init - - -async def main() -> None: - """Starts an instance of the bot.""" - - # Initialize a connection to a PostgreSQL database, an asynchronous web session, and a custom logger setup. - async with ( - aiohttp.ClientSession() as web_session, - asyncpg.create_pool(dsn=core.CONFIG.database.pg_url, command_timeout=30, init=conn_init) as pool, - LoggingManager() as logging_manager, - ): - # Set the bot's basic starting parameters. - intents = discord.Intents.all() - intents.presences = False - default_prefix: str = core.CONFIG.discord.default_prefix - - # Initialize and start the bot. - async with core.Beira( - command_prefix=default_prefix, - db_pool=pool, - web_session=web_session, - intents=intents, - tree_cls=HookableTree, - ) as bot: - bot.logging_manager = logging_manager - await bot.start(core.CONFIG.discord.token) - - # Needed for graceful exit? - await asyncio.sleep(0.1) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 33e3346..357479e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ Homepage = "https://github.com/Sachaa-Thanasius/Beira" [tool.ruff] -include = ["main.py", "core/*", "exts/*", "**/pyproject.toml", "misc/**/*.py"] +include = ["src/beira/**/*.py", "misc/**/*.py"] line-length = 120 target-version = "py312" @@ -94,7 +94,7 @@ lines-after-imports = 2 combine-as-imports = true [tool.pyright] -include = ["main.py", "core", "exts"] +include = ["src/beira"] pythonVersion = "3.12" typeCheckingMode = "strict" diff --git a/requirements.txt b/requirements.txt index d024001..133c6df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,12 @@ aiohttp -ao3.py @ git+https://github.com/Sachaa-Thanasius/ao3.py.git@main +ao3.py @ git+https://github.com/Sachaa-Thanasius/ao3.py@main arsenic async-lru asyncpg==0.29.0 asyncpg-stubs==0.29.1 -atlas-api @ https://github.com/Sachaa-Thanasius/atlas-api-wrapper/releases/download/v0.2.2/atlas_api-0.2.2-py3-none-any.whl +atlas-api @ https://github.com/Sachaa-Thanasius/atlas-api-wrapper discord.py[speed,voice]>=2.4.0 -fichub-api @ https://github.com/Sachaa-Thanasius/fichub-api/releases/download/v0.2.2/fichub_api-0.2.2-py3-none-any.whl -importlib_resources; python_version < "3.12" +fichub-api @ https://github.com/Sachaa-Thanasius/fichub-api jishaku @ git+https://github.com/Gorialis/jishaku@a6661e2813124fbfe53326913e54f7c91e5d0dec lxml>=4.9.3 markdownify diff --git a/src/beira/__init__.py b/src/beira/__init__.py new file mode 100644 index 0000000..121f3c5 --- /dev/null +++ b/src/beira/__init__.py @@ -0,0 +1,4 @@ +from .bot import * +from .checks import * +from .config import * +from .errors import * diff --git a/src/beira/__main__.py b/src/beira/__main__.py new file mode 100644 index 0000000..e69d385 --- /dev/null +++ b/src/beira/__main__.py @@ -0,0 +1,7 @@ +import asyncio + +from . import main + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/core/bot.py b/src/beira/bot.py similarity index 72% rename from core/bot.py rename to src/beira/bot.py index fedbcf0..758c6a8 100644 --- a/core/bot.py +++ b/src/beira/bot.py @@ -1,32 +1,71 @@ -"""bot.py: The main bot code.""" +"""The main bot code.""" +import asyncio import logging import sys import time import traceback -from typing import Any +from typing import Any, Self, overload from zoneinfo import ZoneInfo, ZoneInfoNotFoundError import aiohttp import ao3 import async_lru +import asyncpg import atlas_api import discord import fichub_api import wavelink from discord.ext import commands - +from discord.utils import MISSING from exts import EXTENSIONS from .checks import is_blocked -from .config import CONFIG -from .context import Context -from .utils import LoggingManager, Pool_alias +from .config import Config, load_config +from .tree import HookableTree +from .utils import LoggingManager, Pool_alias, conn_init LOGGER = logging.getLogger(__name__) +__all__ = ("Interaction", "Context", "GuildContext", "Beira", "main") + + +type Interaction = discord.Interaction[Beira] + + +class Context(commands.Context["Beira"]): + """A custom context subclass for Beira. + + Attributes + ---------- + session + db + """ + + voice_client: wavelink.Player | None # type: ignore # Type lie for narrowing + + @property + def session(self) -> aiohttp.ClientSession: + """`ClientSession`: Returns the asynchronous HTTP session used by the bot for HTTP requests.""" + + return self.bot.web_session + + @property + def db(self) -> Pool_alias: + """`Pool`: Returns the asynchronous connection pool used by the bot for database management.""" + + return self.bot.db_pool + + +class GuildContext(Context): + author: discord.Member # type: ignore # Type lie for narrowing + guild: discord.Guild # type: ignore # Type lie for narrowing + channel: discord.abc.GuildChannel | discord.Thread # type: ignore # Type lie for narrowing + me: discord.Member # type: ignore # Type lie for narrowing + + class Beira(commands.Bot): """A personal Discord bot for API experimentation. @@ -49,18 +88,20 @@ class Beira(commands.Bot): def __init__( self, *args: Any, + config: Config, db_pool: Pool_alias, web_session: aiohttp.ClientSession, initial_extensions: list[str] | None = None, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) + self.config = config self.db_pool = db_pool self.web_session = web_session self.initial_extensions: list[str] = initial_extensions or [] # Various webfiction-related clients. - atlas_auth = aiohttp.BasicAuth(CONFIG.atlas.user, CONFIG.atlas.password) + atlas_auth = aiohttp.BasicAuth(config.atlas.user, config.atlas.password) self.atlas_client = atlas_api.Client(auth=atlas_auth, session=self.web_session) self.fichub_client = fichub_api.Client(session=self.web_session) self.ao3_client = ao3.Client(session=self.web_session) @@ -87,7 +128,11 @@ async def setup_hook(self) -> None: await self._load_extensions() # Connect to lavalink node(s). - node = wavelink.Node(uri=CONFIG.lavalink.uri, password=CONFIG.lavalink.password, inactive_player_timeout=600) + node = wavelink.Node( + uri=self.config.lavalink.uri, + password=self.config.lavalink.password, + inactive_player_timeout=600, + ) await wavelink.Pool.connect(client=self, nodes=[node]) # Get information about owner. @@ -103,15 +148,24 @@ async def get_prefix(self, message: discord.Message, /) -> list[str] | str: return self.prefix_cache.get(message.guild.id, "$") if message.guild else "$" - async def get_context( + @overload + async def get_context(self, origin: discord.Message | discord.Interaction, /) -> commands.Context[Self]: ... + + @overload + async def get_context[ContextT: commands.Context[Any]]( + self, origin: discord.Message | discord.Interaction, /, *, cls: type[ContextT] + ) -> ContextT: ... + + async def get_context[ContextT: commands.Context[Any]]( self, origin: discord.Message | discord.Interaction, /, *, - cls: type[commands.Context[commands.Bot]] | None = None, - ) -> Context: - # Figure out if there's a way to type-hint this better to allow cls to actually work. - return await super().get_context(origin, cls=Context) + cls: type[ContextT] = MISSING, + ) -> Any: + if cls is MISSING: + cls = Context # pyright: ignore + return await super().get_context(origin, cls=cls) async def on_error(self, event_method: str, /, *args: object, **kwargs: object) -> None: exc_type, exception, tb = sys.exc_info() @@ -180,12 +234,9 @@ def owner(self) -> discord.User: async def _load_blocked_entities(self) -> None: """Load all blocked users and guilds from the bot database.""" - user_query = "SELECT user_id FROM users WHERE is_blocked;" - guild_query = "SELECT guild_id FROM guilds WHERE is_blocked;" - async with self.db_pool.acquire() as conn, conn.transaction(): - user_records = await conn.fetch(user_query) - guild_records = await conn.fetch(guild_query) + user_records = await conn.fetch("SELECT user_id FROM users WHERE is_blocked;") + guild_records = await conn.fetch("SELECT guild_id FROM guilds WHERE is_blocked;") self.blocked_entities_cache["users"] = {record["user_id"] for record in user_records} self.blocked_entities_cache["guilds"] = {record["guild_id"] for record in guild_records} @@ -194,18 +245,19 @@ async def _load_guild_prefixes(self, guild_id: int | None = None) -> None: """Load all prefixes from the bot database.""" query = "SELECT guild_id, prefix FROM guild_prefixes" - try: - if guild_id: - query += " WHERE guild_id = $1" + if guild_id: + query += " WHERE guild_id = $1" + try: db_prefixes = await self.db_pool.fetch(query) + except OSError: + LOGGER.exception("Couldn't load guild prefixes from the database. Ignoring for sake of defaults.") + else: for entry in db_prefixes: self.prefix_cache.setdefault(entry["guild_id"], []).append(entry["prefix"]) msg = f"(Re)loaded guild prefixes for {guild_id}." if guild_id else "(Re)loaded all guild prefixes." LOGGER.info(msg) - except OSError: - LOGGER.exception("Couldn't load guild prefixes from the database. Ignoring for sake of defaults.") async def _load_extensions(self) -> None: """Loads extensions/cogs. @@ -218,13 +270,13 @@ async def _load_extensions(self) -> None: exts_to_load = self.initial_extensions or EXTENSIONS all_exts_start_time = time.perf_counter() for extension in exts_to_load: + start_time = time.perf_counter() try: - start_time = time.perf_counter() await self.load_extension(extension) - end_time = time.perf_counter() except commands.ExtensionError as err: LOGGER.exception("Failed to load extension: %s", extension, exc_info=err) else: + end_time = time.perf_counter() LOGGER.info("Loaded extension: %s -- Time: %.5f", extension, end_time - start_time) all_exts_end_time = time.perf_counter() LOGGER.info("Total extension loading time: Time: %.5f", all_exts_end_time - all_exts_start_time) @@ -232,7 +284,7 @@ async def _load_extensions(self) -> None: async def _load_special_friends(self) -> None: await self.wait_until_ready() - friends_ids: list[int] = CONFIG.discord.friend_ids + friends_ids: list[int] = self.config.discord.friend_ids for user_id in friends_ids: if user_obj := self.get_user(user_id): self.special_friends[user_obj.name] = user_id @@ -267,3 +319,35 @@ def is_ali(self, user: discord.abc.User, /) -> bool: return user.id == self.special_friends["aeroali"] return False + + +async def main() -> None: + """Starts an instance of the bot.""" + + config = load_config() + + # Initialize a connection to a PostgreSQL database, an asynchronous web session, and a custom logger setup. + async with ( + aiohttp.ClientSession() as web_session, + asyncpg.create_pool(dsn=config.database.pg_url, command_timeout=30, init=conn_init) as pool, + LoggingManager() as logging_manager, + ): + # Set the bot's basic starting parameters. + intents = discord.Intents.all() + intents.presences = False + default_prefix: str = config.discord.default_prefix + + # Initialize and start the bot. + async with Beira( + command_prefix=default_prefix, + config=config, + db_pool=pool, + web_session=web_session, + intents=intents, + tree_cls=HookableTree, + ) as bot: + bot.logging_manager = logging_manager + await bot.start(config.discord.token) + + # Needed for graceful exit? + await asyncio.sleep(0.1) diff --git a/core/checks.py b/src/beira/checks.py similarity index 96% rename from core/checks.py rename to src/beira/checks.py index 7245a2a..5082c1f 100644 --- a/core/checks.py +++ b/src/beira/checks.py @@ -39,7 +39,7 @@ def is_owner_or_friend() -> "Check[Any]": This check raises a special exception, `.NotOwnerOrFriend` that is derived from `commands.CheckFailure`. """ - from .context import Context + from . import Context async def predicate(ctx: Context) -> bool: if not (ctx.bot.is_special_friend(ctx.author) or await ctx.bot.is_owner(ctx.author)): @@ -56,7 +56,7 @@ def is_admin() -> "Check[Any]": This check raises a special exception, `NotAdmin` that is derived from `commands.CheckFailure`. """ - from .context import GuildContext + from . import GuildContext async def predicate(ctx: GuildContext) -> bool: if not ctx.author.guild_permissions.administrator: @@ -73,7 +73,7 @@ def in_bot_vc() -> "Check[Any]": This check raises a special exception, `NotInBotVoiceChannel` that is derived from `commands.CheckFailure`. """ - from .context import GuildContext + from . import GuildContext async def predicate(ctx: GuildContext) -> bool: vc = ctx.voice_client @@ -94,7 +94,7 @@ def in_aci100_guild() -> "Check[Any]": This check raises the exception `commands.CheckFailure`. """ - from .context import GuildContext + from . import GuildContext async def predicate(ctx: GuildContext) -> bool: if ctx.guild.id != 602735169090224139: @@ -111,7 +111,7 @@ def is_blocked() -> "Check[Any]": This check raises the exception `commands.CheckFailure`. """ - from .context import Context + from . import Context async def predicate(ctx: Context) -> bool: if not (await ctx.bot.is_owner(ctx.author)): diff --git a/core/config.py b/src/beira/config.py similarity index 83% rename from core/config.py rename to src/beira/config.py index 2949648..33d266e 100644 --- a/core/config.py +++ b/src/beira/config.py @@ -1,4 +1,4 @@ -"""config.py: Imports configuration information, such as api keys and tokens, default prefixes, etc.""" +"""For loading configuration information, such as api keys and tokens, default prefixes, etc.""" import pathlib from typing import Any @@ -7,8 +7,8 @@ __all__ = ( - "CONFIG", "Config", + "load_config", ) @@ -75,9 +75,10 @@ class Config(Base): def decode(data: bytes | str) -> Config: - """Decode a ``config.toml`` file from TOML.""" + """Decode a TOMl file with the `Config` schema.""" return msgspec.toml.decode(data, type=Config) -CONFIG = decode(pathlib.Path("config.toml").read_text(encoding="utf-8")) +def load_config() -> Config: + return decode(pathlib.Path("config.toml").read_text(encoding="utf-8")) diff --git a/core/errors.py b/src/beira/errors.py similarity index 98% rename from core/errors.py rename to src/beira/errors.py index 5c04ee6..09b4fcb 100644 --- a/core/errors.py +++ b/src/beira/errors.py @@ -1,4 +1,4 @@ -"""errors.py: Custom errors used by the bot.""" +"""Custom errors used by the bot.""" from discord import app_commands from discord.ext import commands diff --git a/exts/__init__.py b/src/beira/exts/__init__.py similarity index 100% rename from exts/__init__.py rename to src/beira/exts/__init__.py diff --git a/exts/_dev/_dev.py b/src/beira/exts/_dev.py similarity index 89% rename from exts/_dev/_dev.py rename to src/beira/exts/_dev.py index 02e5754..6fc32ae 100644 --- a/exts/_dev/_dev.py +++ b/src/beira/exts/_dev.py @@ -1,6 +1,4 @@ -"""_dev.py: A cog that implements commands for reloading and syncing extensions and other commands, at the owner's -behest. -""" +"""A cog that implements commands for reloading and syncing extensions and other commands, at the owner's behest.""" import contextlib import logging @@ -13,8 +11,9 @@ from discord import app_commands from discord.ext import commands -import core -from exts import EXTENSIONS +import beira + +from . import EXTENSIONS LOGGER = logging.getLogger(__name__) @@ -31,11 +30,6 @@ ("[+] —— (D-N-T!) Clear all commands from all guilds and sync, thereby removing all guild commands.", "+"), ] -dev_guilds_objects = [discord.Object(id=guild_id) for guild_id in core.CONFIG.discord.important_guilds["dev"]] - -# Preload the dev-guild-only app commands decorator. -only_dev_guilds = app_commands.guilds(*dev_guilds_objects) - class DevCog(commands.Cog, name="_Dev", command_attrs={"hidden": True}): """A cog for handling bot-related like syncing commands or reloading cogs while live. @@ -43,8 +37,10 @@ class DevCog(commands.Cog, name="_Dev", command_attrs={"hidden": True}): Meant to be used by the bot dev(s) only. """ - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira, dev_guilds: list[discord.Object]) -> None: self.bot = bot + self.dev_guilds = dev_guilds + self.block_add_ctx_menu = app_commands.ContextMenu( name="Bot Block", callback=self.context_menu_block_add, @@ -53,8 +49,9 @@ def __init__(self, bot: core.Beira) -> None: name="Bot Unblock", callback=self.context_menu_block_remove, ) - self.bot.tree.add_command(self.block_add_ctx_menu, guilds=dev_guilds_objects) - self.bot.tree.add_command(self.block_remove_ctx_menu, guilds=dev_guilds_objects) + + self.bot.tree.add_command(self.block_add_ctx_menu, guilds=dev_guilds) + self.bot.tree.add_command(self.block_remove_ctx_menu, guilds=dev_guilds) @property def cog_emoji(self) -> discord.PartialEmoji: @@ -63,7 +60,7 @@ def cog_emoji(self) -> discord.PartialEmoji: return discord.PartialEmoji(name="discord_dev", animated=True, id=1084608963896672256) async def cog_unload(self) -> None: - for dev_guild in dev_guilds_objects: + for dev_guild in self.dev_guilds: self.bot.tree.remove_command( self.block_add_ctx_menu.name, type=self.block_add_ctx_menu.type, @@ -75,12 +72,12 @@ async def cog_unload(self) -> None: guild=dev_guild, ) - async def cog_check(self, ctx: core.Context) -> bool: # type: ignore # Narrowing, and async is allowed. + async def cog_check(self, ctx: beira.Context) -> bool: # type: ignore # Narrowing, and async is allowed. """Set up bot owner check as universal within the cog.""" return await self.bot.is_owner(ctx.author) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing assert ctx.command # Extract the original error. @@ -99,8 +96,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: await ctx.send(embed=embed, ephemeral=True) @commands.hybrid_group(fallback="get") - @only_dev_guilds - async def block(self, ctx: core.Context) -> None: + async def block(self, ctx: beira.Context) -> None: """A group of commands for blocking and unblocking users or guilds from using the bot. By default, display the users and guilds that are blocked from using the bot. @@ -122,7 +118,7 @@ async def block(self, ctx: core.Context) -> None: @block.command("add") async def block_add( self, - ctx: core.Context, + ctx: beira.Context, block_type: Literal["users", "guilds"] = "users", *, entities: commands.Greedy[discord.Object], @@ -131,7 +127,7 @@ async def block_add( Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. block_type: Literal["user", "guild"], default="user" What type of entity or entities are being blocked. Defaults to "user". @@ -167,10 +163,9 @@ async def block_add( await ctx.send("Blocked the following from bot usage:", embed=embed, ephemeral=True) @block.command("remove") - @only_dev_guilds async def block_remove( self, - ctx: core.Context, + ctx: beira.Context, block_type: Literal["users", "guild"] = "users", *, entities: commands.Greedy[discord.Object], @@ -179,7 +174,7 @@ async def block_remove( Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context block_type: Literal["user", "guild"], default="user" What type of entity or entities are being unblocked. Defaults to "user". @@ -216,7 +211,7 @@ async def block_remove( @block_add.error @block_remove.error - async def block_change_error(self, ctx: core.Context, error: commands.CommandError) -> None: + async def block_change_error(self, ctx: beira.Context, error: commands.CommandError) -> None: # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -232,7 +227,7 @@ async def block_change_error(self, ctx: core.Context, error: commands.CommandErr @app_commands.check(lambda interaction: interaction.user.id == interaction.client.owner_id) async def context_menu_block_add( self, - interaction: core.Interaction, + interaction: beira.Interaction, user: discord.User | discord.Member, ) -> None: stmt = """ @@ -252,7 +247,7 @@ async def context_menu_block_add( @app_commands.check(lambda interaction: interaction.user.id == interaction.client.owner_id) async def context_menu_block_remove( self, - interaction: core.Interaction, + interaction: beira.Interaction, user: discord.User | discord.Member, ) -> None: stmt = """ @@ -270,8 +265,7 @@ async def context_menu_block_remove( await interaction.response.send_message("Unlocked the following from bot usage:", embed=embed, ephemeral=True) @commands.hybrid_command() - @only_dev_guilds - async def shutdown(self, ctx: core.Context) -> None: + async def shutdown(self, ctx: beira.Context) -> None: """Shut down the bot.""" LOGGER.info("Shutting down bot with dev command.") @@ -279,13 +273,12 @@ async def shutdown(self, ctx: core.Context) -> None: await self.bot.close() @commands.hybrid_command() - @only_dev_guilds - async def walk(self, ctx: core.Context) -> None: + async def walk(self, ctx: beira.Context) -> None: """Walk through all app commands globally and in every guild to see what is synced and where. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context where the command was called. """ @@ -310,14 +303,13 @@ def create_walk_embed(title: str, cmds: list[app_commands.AppCommand]) -> None: await ctx.reply(embeds=all_embeds, ephemeral=True) @commands.hybrid_command() - @only_dev_guilds @app_commands.describe(extension="The file name of the extension/cog you wish to load, excluding the file type.") - async def load(self, ctx: core.Context, extension: str) -> None: + async def load(self, ctx: beira.Context, extension: str) -> None: """Loads an extension/cog. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. extension: `str` The name of the chosen extension to load, excluding the file type. If activated as a prefix command, the @@ -339,7 +331,7 @@ async def load(self, ctx: core.Context, extension: str) -> None: await ctx.send(embed=embed, ephemeral=True) @load.autocomplete("extension") - async def load_ext_autocomplete(self, _: core.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def load_ext_autocomplete(self, _: beira.Interaction, current: str) -> list[app_commands.Choice[str]]: """Autocompletes names for extensions that are ignored or unloaded.""" exts_to_load = set(EXTENSIONS).difference(set(self.bot.extensions), set(IGNORE_EXTENSIONS)) @@ -350,14 +342,13 @@ async def load_ext_autocomplete(self, _: core.Interaction, current: str) -> list ][:25] @commands.hybrid_command() - @only_dev_guilds @app_commands.describe(extension="The file name of the extension/cog you wish to unload, excluding the file type.") - async def unload(self, ctx: core.Context, extension: str) -> None: + async def unload(self, ctx: beira.Context, extension: str) -> None: """Unloads an extension/cog. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. extension: `str` The name of the chosen extension to unload, excluding the file type. If activated as a prefix command, the @@ -379,14 +370,13 @@ async def unload(self, ctx: core.Context, extension: str) -> None: await ctx.send(embed=embed, ephemeral=True) @commands.hybrid_command() - @only_dev_guilds @app_commands.describe(extension="The file name of the extension/cog you wish to reload, excluding the file type.") - async def reload(self, ctx: core.Context, extension: str) -> None: + async def reload(self, ctx: beira.Context, extension: str) -> None: """Reloads an extension/cog. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. extension: `str` The name of the chosen extension to reload, excluding the file type. If activated as a prefix command, the @@ -433,7 +423,7 @@ async def reload(self, ctx: core.Context, extension: str) -> None: @unload.autocomplete("extension") @reload.autocomplete("extension") - async def ext_autocomplete(self, _: core.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def ext_autocomplete(self, _: beira.Interaction, current: str) -> list[app_commands.Choice[str]]: """Autocompletes names for currently loaded extensions.""" return [ @@ -443,11 +433,10 @@ async def ext_autocomplete(self, _: core.Interaction, current: str) -> list[app_ ][:25] @commands.hybrid_command("sync") - @only_dev_guilds @app_commands.choices(spec=[app_commands.Choice(name=name, value=value) for name, value in SPEC_CHOICES]) async def sync_( self, - ctx: core.Context, + ctx: beira.Context, guilds: commands.Greedy[discord.Object] = None, # type: ignore # Can't be type-hinted as optional. spec: app_commands.Choice[str] | None = None, ) -> None: @@ -457,7 +446,7 @@ async def sync_( Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. guilds: Greedy[`discord.Object`], optional The guilds to sync the app commands if no specification is entered. Converts guild ids to @@ -533,12 +522,12 @@ async def sync_( await ctx.send(f"Synced the tree to {ret}/{len(guilds)}.", ephemeral=True) @sync_.error - async def sync_error(self, ctx: core.Context, error: commands.CommandError) -> None: + async def sync_error(self, ctx: beira.Context, error: commands.CommandError) -> None: """A local error handler for the :meth:`sync_` command. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. error: `commands.CommandError` The error thrown by the command. @@ -573,7 +562,7 @@ async def sync_error(self, ctx: core.Context, error: commands.CommandError) -> N await ctx.reply(embed=embed) @commands.hybrid_command() - async def cmd_tree(self, ctx: core.Context) -> None: + async def cmd_tree(self, ctx: beira.Context) -> None: indent_level = 0 @contextlib.contextmanager @@ -587,7 +576,11 @@ def new_indent(num: int = 4) -> Generator[None, object, None]: def walk_commands_with_indent(group: commands.GroupMixin[Any]) -> Generator[str, object, None]: for cmd in group.commands: - indent = "" if (indent_level == 0) else (indent_level - 1) * "─" + if indent_level != 0: # noqa: SIM108 + indent = (indent_level - 1) * "─" + else: + indent = "" + yield f"└{indent}{cmd.qualified_name}" if isinstance(cmd, commands.GroupMixin): @@ -596,3 +589,19 @@ def walk_commands_with_indent(group: commands.GroupMixin[Any]) -> Generator[str, result = "\n".join(["Beira", *walk_commands_with_indent(ctx.bot)]) await ctx.send(f"```\n{result}\n```") + + +async def setup(bot: beira.Beira) -> None: + """Connects cog to bot.""" + + # Can't use the guilds kwarg in add_cog, as it doesn't currently work for hybrids. + # Ref: https://github.com/Rapptz/discord.py/pull/9428 + dev_guilds_objects = [discord.Object(id=guild_id) for guild_id in bot.config.discord.important_guilds["dev"]] + cog = DevCog(bot, dev_guilds_objects) + for cmd in cog.walk_app_commands(): + if cmd._guild_ids is None: + cmd._guild_ids = [g.id for g in dev_guilds_objects] + else: + cmd._guild_ids.extend(g.id for g in dev_guilds_objects) + + await bot.add_cog(cog) diff --git a/exts/_dev/_test.py b/src/beira/exts/_test.py similarity index 70% rename from exts/_dev/_test.py rename to src/beira/exts/_test.py index a09b7de..4d87b04 100644 --- a/exts/_dev/_test.py +++ b/src/beira/exts/_test.py @@ -4,10 +4,8 @@ from discord import app_commands from discord.ext import commands -import core -from core.tree import after_app_invoke, before_app_invoke - -from ._dev import only_dev_guilds +import beira +from beira.tree import after_app_invoke, before_app_invoke LOGGER = logging.getLogger(__name__) @@ -24,15 +22,15 @@ async def example_after_hook(itx: discord.Interaction) -> None: class TestCog(commands.Cog, name="_Test", command_attrs={"hidden": True}): - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot - async def cog_check(self, ctx: core.Context) -> bool: # type: ignore # Narrowing, and async is allowed. + async def cog_check(self, ctx: beira.Context) -> bool: # type: ignore # Narrowing, and async is allowed. """Set up bot owner check as universal within the cog.""" return await self.bot.is_owner(ctx.author) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -41,27 +39,25 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: LOGGER.exception("", exc_info=error) @commands.command() - async def test_pre(self, ctx: core.Context) -> None: + async def test_pre(self, ctx: beira.Context) -> None: """Test prefix command.""" await ctx.send("Test prefix command.") @commands.hybrid_command() - @only_dev_guilds - async def test_hy(self, ctx: core.Context) -> None: + async def test_hy(self, ctx: beira.Context) -> None: """Test hybrid command.""" await ctx.send("Test hybrid command.") @app_commands.command() - @only_dev_guilds - async def test_sl(self, interaction: core.Interaction) -> None: + async def test_sl(self, interaction: beira.Interaction) -> None: """Test app command.""" await interaction.response.send_message("Test app command.") @commands.command() - async def test_embeds(self, ctx: core.Context) -> None: + async def test_embeds(self, ctx: beira.Context) -> None: """Test multiple images in an embeds.""" await ctx.send("Test hybrid command.") @@ -88,9 +84,24 @@ async def test_embeds(self, ctx: core.Context) -> None: @before_app_invoke(example_before_hook) @after_app_invoke(example_after_hook) @app_commands.command() - @only_dev_guilds async def test_hooks(self, itx: discord.Interaction, arg: str) -> None: """Test the custom pre and post-command hooking mechanism.""" send_msg = itx.response.send_message if not itx.response.is_done() else itx.followup.send await send_msg(f"In command with given argument: {arg}") + + +async def setup(bot: beira.Beira) -> None: + """Connects cog to bot.""" + + # Can't use the guilds kwarg in add_cog, as it doesn't currently work for hybrids. + # Ref: https://github.com/Rapptz/discord.py/pull/9428 + dev_guilds_objects = [discord.Object(id=guild_id) for guild_id in bot.config.discord.important_guilds["dev"]] + cog = TestCog(bot) + for cmd in cog.walk_app_commands(): + if cmd._guild_ids is None: + cmd._guild_ids = [g.id for g in dev_guilds_objects] + else: + cmd._guild_ids.extend(g.id for g in dev_guilds_objects) + + await bot.add_cog(cog) diff --git a/exts/admin.py b/src/beira/exts/admin.py similarity index 85% rename from exts/admin.py rename to src/beira/exts/admin.py index 778a774..4df9244 100644 --- a/exts/admin.py +++ b/src/beira/exts/admin.py @@ -8,7 +8,7 @@ from asyncpg import PostgresError, PostgresWarning from discord.ext import commands -import core +import beira LOGGER = logging.getLogger(__name__) @@ -17,7 +17,7 @@ class AdminCog(commands.Cog, name="Administration"): """A cog for handling administrative tasks like adding and removing prefixes.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot @property @@ -26,7 +26,7 @@ def cog_emoji(self) -> discord.PartialEmoji: return discord.PartialEmoji(name="endless_gears", animated=True, id=1077981366911766549) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -36,7 +36,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: @commands.hybrid_command() @commands.guild_only() - async def get_timeouts(self, ctx: core.GuildContext) -> None: + async def get_timeouts(self, ctx: beira.GuildContext) -> None: """Get all timed out members on the server.""" async with ctx.typing(): @@ -51,7 +51,7 @@ async def get_timeouts(self, ctx: core.GuildContext) -> None: @commands.hybrid_group(fallback="get") @commands.guild_only() - async def prefixes(self, ctx: core.GuildContext) -> None: + async def prefixes(self, ctx: beira.GuildContext) -> None: """View the prefixes set for this bot in this location.""" async with ctx.typing(): @@ -61,13 +61,13 @@ async def prefixes(self, ctx: core.GuildContext) -> None: @prefixes.command("add") @commands.guild_only() - @commands.check_any(commands.is_owner(), core.is_admin()) - async def prefixes_add(self, ctx: core.GuildContext, *, new_prefix: str) -> None: + @commands.check_any(commands.is_owner(), beira.is_admin()) + async def prefixes_add(self, ctx: beira.GuildContext, *, new_prefix: str) -> None: """Set a prefix that you'd like this bot to respond to. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. new_prefix: `str` The prefix to be added. @@ -95,13 +95,13 @@ async def prefixes_add(self, ctx: core.GuildContext, *, new_prefix: str) -> None @prefixes.command("remove") @commands.guild_only() - @commands.check_any(commands.is_owner(), core.is_admin()) - async def prefixes_remove(self, ctx: core.GuildContext, *, old_prefix: str) -> None: + @commands.check_any(commands.is_owner(), beira.is_admin()) + async def prefixes_remove(self, ctx: beira.GuildContext, *, old_prefix: str) -> None: """Remove a prefix that you'd like this bot to no longer respond to. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. old_prefix: `str` The prefix to be removed. @@ -122,13 +122,13 @@ async def prefixes_remove(self, ctx: core.GuildContext, *, old_prefix: str) -> N @prefixes.command("reset") @commands.guild_only() - @commands.check_any(commands.is_owner(), core.is_admin()) - async def prefixes_reset(self, ctx: core.GuildContext) -> None: + @commands.check_any(commands.is_owner(), beira.is_admin()) + async def prefixes_reset(self, ctx: beira.GuildContext) -> None: """Remove all prefixes within this server for the bot to respond to. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. """ @@ -144,7 +144,7 @@ async def prefixes_reset(self, ctx: core.GuildContext) -> None: @prefixes_add.error @prefixes_remove.error @prefixes_reset.error - async def prefixes_subcommands_error(self, ctx: core.Context, error: commands.CommandError) -> None: + async def prefixes_subcommands_error(self, ctx: beira.Context, error: commands.CommandError) -> None: # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -164,7 +164,7 @@ async def prefixes_subcommands_error(self, ctx: core.Context, error: commands.Co LOGGER.exception("", exc_info=error) -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(AdminCog(bot)) diff --git a/exts/bot_stats.py b/src/beira/exts/bot_stats.py similarity index 89% rename from exts/bot_stats.py rename to src/beira/exts/bot_stats.py index faba533..5167b1b 100644 --- a/exts/bot_stats.py +++ b/src/beira/exts/bot_stats.py @@ -9,8 +9,8 @@ from discord.app_commands import Choice from discord.ext import commands -import core -from core.utils import StatsEmbed +import beira +from beira.utils import StatsEmbed LOGGER = logging.getLogger(__name__) @@ -37,7 +37,7 @@ class CommandStatsSearchFlags(commands.FlagConverter): class BotStatsCog(commands.Cog, name="Bot Stats"): """A cog for tracking different bot metrics.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot @property @@ -46,7 +46,7 @@ def cog_emoji(self) -> discord.PartialEmoji: return discord.PartialEmoji(name="\N{CHART WITH UPWARDS TREND}") - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -54,7 +54,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: LOGGER.exception("", exc_info=error) - async def track_command_use(self, ctx: core.Context) -> None: + async def track_command_use(self, ctx: beira.Context) -> None: """Stores records of command uses in the database after some processing.""" assert ctx.command is not None @@ -97,13 +97,13 @@ async def track_command_use(self, ctx: core.Context) -> None: await db.execute(stmt, *cmd, timeout=60.0) @commands.Cog.listener("on_command_completion") - async def track_command_completion(self, ctx: core.Context) -> None: + async def track_command_completion(self, ctx: beira.Context) -> None: """Record prefix and hybrid command usage.""" await self.track_command_use(ctx) @commands.Cog.listener("on_interaction") - async def track_interaction(self, interaction: core.Interaction) -> None: + async def track_interaction(self, interaction: beira.Interaction) -> None: """Record application command usage, ignoring hybrid or other interactions. References @@ -116,12 +116,12 @@ async def track_interaction(self, interaction: core.Interaction) -> None: and interaction.type is discord.InteractionType.application_command and not isinstance(interaction.command, commands.hybrid.HybridAppCommand) ): - ctx = await core.Context.from_interaction(interaction) + ctx = await beira.Context.from_interaction(interaction) ctx.command_failed = interaction.command_failed await self.track_command_use(ctx) @commands.Cog.listener("on_command_error") - async def track_command_error(self, ctx: core.Context, error: commands.CommandError) -> None: + async def track_command_error(self, ctx: beira.Context, error: commands.CommandError) -> None: """Records prefix, hybrid, and application command usage, even if the result is an error.""" if not isinstance(error, commands.CommandNotFound): @@ -135,12 +135,12 @@ async def add_guild_to_db(self, guild: discord.Guild) -> None: await self.bot.db_pool.execute(stmt, guild.id, timeout=60.0) @commands.hybrid_command(name="usage") - async def check_usage(self, ctx: core.Context, *, search_factors: CommandStatsSearchFlags) -> None: + async def check_usage(self, ctx: beira.Context, *, search_factors: CommandStatsSearchFlags) -> None: """Retrieve statistics about bot command usage. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. search_factors: `CommandStatsSearchFlags` A flag converter for taking a few query specifications when searching for usage stats. @@ -227,11 +227,11 @@ async def get_usage( return await self.bot.db_pool.fetch(query, *query_args) @check_usage.autocomplete("command") - async def command_autocomplete(self, interaction: core.Interaction, current: str) -> list[Choice[str]]: + async def command_autocomplete(self, interaction: beira.Interaction, current: str) -> list[Choice[str]]: """Autocompletes with bot command names.""" assert self.bot.help_command - ctx = await self.bot.get_context(interaction, cls=core.Context) + ctx = await self.bot.get_context(interaction, cls=beira.Context) help_command = self.bot.help_command.copy() help_command.context = ctx @@ -243,7 +243,7 @@ async def command_autocomplete(self, interaction: core.Interaction, current: str ][:25] -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(BotStatsCog(bot)) diff --git a/exts/dice.py b/src/beira/exts/dice.py similarity index 97% rename from exts/dice.py rename to src/beira/exts/dice.py index b5fb36a..a7ae9bf 100644 --- a/exts/dice.py +++ b/src/beira/exts/dice.py @@ -15,8 +15,8 @@ from discord import ui from discord.ext import commands -import core -from core.utils import EMOJI_STOCK +import beira +from beira.utils import EMOJI_STOCK LOGGER = logging.getLogger(__name__) @@ -448,7 +448,7 @@ async def on_error(self, _: discord.Interaction, error: Exception, item: ui.Item emoji="\N{HEAVY PLUS SIGN}", row=3, ) - async def set_modifier(self, interaction: core.Interaction, button: ui.Button[Self]) -> None: + async def set_modifier(self, interaction: beira.Interaction, button: ui.Button[Self]) -> None: """Allow the user to set a modifier to add to the result of a roll or series of rolls. Applies to individual buttons and the select menu. Happens once at the end of multiple rolls. @@ -482,7 +482,7 @@ async def set_modifier(self, interaction: core.Interaction, button: ui.Button[Se emoji="\N{HEAVY MULTIPLICATION X}", row=3, ) - async def set_number(self, interaction: core.Interaction, button: ui.Button[Self]) -> None: + async def set_number(self, interaction: beira.Interaction, button: ui.Button[Self]) -> None: """Allow the user to set the number of dice to roll at once. Applies to individual buttons and the select menu. @@ -523,7 +523,7 @@ async def set_number(self, interaction: core.Interaction, button: ui.Button[Self emoji="\N{ABACUS}", row=3, ) - async def set_expression(self, interaction: core.Interaction, _: ui.Button[Self]) -> None: + async def set_expression(self, interaction: beira.Interaction, _: ui.Button[Self]) -> None: """Allow the user to enter a custom dice expression to be evaluated for result.""" # Create and send a modal for user input. @@ -589,12 +589,12 @@ async def run_expression(self, interaction: discord.Interaction, _: ui.Button[Se @commands.hybrid_command() -async def roll(ctx: core.Context, expression: str | None = None) -> None: +async def roll(ctx: beira.Context, expression: str | None = None) -> None: """Send an interface for rolling different dice. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. expression: `str`, optional A custom dice expression to calculate. Optional. @@ -621,7 +621,7 @@ async def roll(ctx: core.Context, expression: str | None = None) -> None: @roll.error # pyright: ignore [reportUnknownMemberType] # discord.py bug: see https://github.com/Rapptz/discord.py/issues/9788. -async def roll_error(ctx: core.Context, error: commands.CommandError) -> None: +async def roll_error(ctx: beira.Context, error: commands.CommandError) -> None: # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -630,7 +630,7 @@ async def roll_error(ctx: core.Context, error: commands.CommandError) -> None: LOGGER.exception("", exc_info=error) -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Add roll command and persistent dice view to bot.""" bot.add_view(DiceView()) diff --git a/exts/emoji_ops.py b/src/beira/exts/emoji_ops.py similarity index 95% rename from exts/emoji_ops.py rename to src/beira/exts/emoji_ops.py index 6b6a750..6803927 100644 --- a/exts/emoji_ops.py +++ b/src/beira/exts/emoji_ops.py @@ -15,7 +15,7 @@ from discord.ext import commands from PIL import Image -import core +import beira LOGGER = logging.getLogger(__name__) @@ -161,7 +161,7 @@ async def on_error(self, interaction: discord.Interaction, error: Exception, ite @app_commands.context_menu(name="Add Emoji(s)") @app_commands.checks.has_permissions(manage_emojis_and_stickers=True) @app_commands.checks.bot_has_permissions(manage_emojis_and_stickers=True) -async def context_menu_emoji_add(interaction: core.Interaction, message: discord.Message) -> None: +async def context_menu_emoji_add(interaction: beira.Interaction, message: discord.Message) -> None: """Context menu command for adding emojis from a message to the guild in context.""" # Regex taken from commands.PartialEmojiConverter. @@ -192,7 +192,7 @@ async def context_menu_emoji_add(interaction: core.Interaction, message: discord @app_commands.context_menu(name="Add Sticker(s)") @app_commands.checks.has_permissions(manage_emojis_and_stickers=True) @app_commands.checks.bot_has_permissions(manage_emojis_and_stickers=True) -async def context_menu_sticker_add(interaction: core.Interaction, message: discord.Message) -> None: +async def context_menu_sticker_add(interaction: beira.Interaction, message: discord.Message) -> None: """Context menu command for adding stickers from a message to the guild in context.""" added_count = 0 @@ -226,7 +226,7 @@ async def context_menu_sticker_add(interaction: core.Interaction, message: disco class EmojiOpsCog(commands.Cog, name="Emoji Operations"): """A cog with commands for performing actions with emojis and stickers.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot self.bot.tree.add_command(context_menu_sticker_add) self.bot.tree.add_command(context_menu_emoji_add) @@ -241,7 +241,7 @@ async def cog_unload(self) -> None: self.bot.tree.remove_command(context_menu_sticker_add.name, type=context_menu_sticker_add.type) self.bot.tree.remove_command(context_menu_sticker_add.name, type=context_menu_sticker_add.type) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing """A local error handler for the emoji and sticker-related commands. Parameters @@ -276,12 +276,12 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: await ctx.send(embed=embed) @staticmethod - async def convert_str_to_emoji(ctx: core.Context, entity: str) -> discord.Emoji | discord.PartialEmoji | None: + async def convert_str_to_emoji(ctx: beira.Context, entity: str) -> discord.Emoji | discord.PartialEmoji | None: """Attempt to convert a string to an emoji or partial emoji. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. entity: `str` The string that might be an emoji or unicode character. @@ -313,18 +313,18 @@ async def convert_str_to_emoji(ctx: core.Context, entity: str) -> discord.Emoji @commands.hybrid_group("emoji") @commands.guild_only() - async def emoji_(self, ctx: core.GuildContext) -> None: + async def emoji_(self, ctx: beira.GuildContext) -> None: """A group of emoji-related commands, like identifying emojis and adding them to a server.""" await ctx.send_help(ctx.command) @emoji_.command("info") - async def emoji_info(self, ctx: core.GuildContext, entity: str, *, ephemeral: bool = True) -> None: + async def emoji_info(self, ctx: beira.GuildContext, entity: str, *, ephemeral: bool = True) -> None: """Identify a particular emoji and see information about it. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. entity: `str` The emoji to provide information about. @@ -372,7 +372,7 @@ async def emoji_info(self, ctx: core.GuildContext, entity: str, *, ephemeral: bo @commands.has_guild_permissions(manage_emojis_and_stickers=True) async def emoji_add( self, - ctx: core.GuildContext, + ctx: beira.GuildContext, name: str, entity: str | None = None, attachment: discord.Attachment | None = None, @@ -381,7 +381,7 @@ async def emoji_add( Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. name: `str` The name of the emoji. @@ -429,18 +429,18 @@ async def emoji_add( @commands.hybrid_group() @commands.guild_only() - async def sticker(self, ctx: core.GuildContext) -> None: + async def sticker(self, ctx: beira.GuildContext) -> None: """A group of sticker-related commands, like adding them to a server.""" await ctx.send_help(ctx.command) @sticker.command("info") - async def sticker_info(self, ctx: core.GuildContext, sticker: str, *, ephemeral: bool = True) -> None: + async def sticker_info(self, ctx: beira.GuildContext, sticker: str, *, ephemeral: bool = True) -> None: """Identify a particular sticker and see information about it. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. sticker: `discord.GuildSticker` The id or name of the sticker to provide information about. @@ -484,7 +484,7 @@ async def sticker_info(self, ctx: core.GuildContext, sticker: str, *, ephemeral: @commands.has_guild_permissions(manage_emojis_and_stickers=True) async def sticker_add( self, - ctx: core.GuildContext, + ctx: beira.GuildContext, sticker: str | None = None, *, sticker_flags: GuildStickerFlags, @@ -493,7 +493,7 @@ async def sticker_add( Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. sticker: `discord.GuildSticker`, optional The name or id of an existing sticker to steal. If filled, no other parameters are necessary. @@ -531,7 +531,7 @@ async def sticker_add( await ctx.send(f"Sticker successfully added: `{new_sticker.name}`.", stickers=[new_sticker]) -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(EmojiOpsCog(bot)) diff --git a/exts/fandom_wiki.py b/src/beira/exts/fandom_wiki.py similarity index 95% rename from exts/fandom_wiki.py rename to src/beira/exts/fandom_wiki.py index 8909227..7f92f67 100644 --- a/exts/fandom_wiki.py +++ b/src/beira/exts/fandom_wiki.py @@ -15,8 +15,8 @@ from discord import app_commands from discord.ext import commands -import core -from core.utils import EMOJI_URL, html_to_markdown +import beira +from beira.utils import EMOJI_URL, html_to_markdown LOGGER = logging.getLogger(__name__) @@ -176,18 +176,18 @@ class FandomWikiSearchCog(commands.Cog, name="Fandom Wiki Search"): Parameters ---------- - bot: `core.Beira` + bot: `beira.Beira` The main Discord bot this cog is a part of. Attributes ---------- - bot: `core.Beira` + bot: `beira.Beira` The main Discord bot this cog is a part of. all_wikis: dict[`str`, dict[`str`, `str`]] The dict containing information for various wikis. """ - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot self.all_wikis: dict[str, dict[str, str]] = {} @@ -206,7 +206,7 @@ async def cog_load(self) -> None: LOGGER.info("All wiki names: %s", list(self.all_wikis.keys())) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -216,12 +216,12 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: @commands.hybrid_command() @commands.cooldown(1, 5, commands.cooldowns.BucketType.user) - async def wiki(self, ctx: core.Context, wiki: str, search_term: str) -> None: + async def wiki(self, ctx: beira.Context, wiki: str, search_term: str) -> None: """Search a selection of pre-indexed Fandom wikis. General purpose. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. wiki: `str` The name of the wiki that's being searched. @@ -233,7 +233,7 @@ async def wiki(self, ctx: core.Context, wiki: str, search_term: str) -> None: await ctx.send(embed=embed) @wiki.autocomplete("wiki") - async def wiki_autocomplete(self, _: core.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def wiki_autocomplete(self, _: beira.Interaction, current: str) -> list[app_commands.Choice[str]]: """Autocomplete callback for the names of different wikis.""" options = self.all_wikis.keys() @@ -244,7 +244,7 @@ async def wiki_autocomplete(self, _: core.Interaction, current: str) -> list[app @wiki.autocomplete("search_term") async def wiki_search_term_autocomplete( self, - interaction: core.Interaction, + interaction: beira.Interaction, current: str, ) -> list[app_commands.Choice[str]]: """Autocomplete callback for the names of different wiki pages. @@ -329,7 +329,7 @@ async def search_wiki(self, wiki_name: str, wiki_query: str) -> discord.Embed: return final_embed -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(FandomWikiSearchCog(bot)) diff --git a/exts/ff_metadata/__init__.py b/src/beira/exts/ff_metadata/__init__.py similarity index 67% rename from exts/ff_metadata/__init__.py rename to src/beira/exts/ff_metadata/__init__.py index b12d56e..12441ba 100644 --- a/exts/ff_metadata/__init__.py +++ b/src/beira/exts/ff_metadata/__init__.py @@ -1,9 +1,9 @@ -import core +import beira from .ff_metadata import FFMetadataCog -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(FFMetadataCog(bot)) diff --git a/exts/ff_metadata/ff_metadata.py b/src/beira/exts/ff_metadata/ff_metadata.py similarity index 94% rename from exts/ff_metadata/ff_metadata.py rename to src/beira/exts/ff_metadata/ff_metadata.py index e37bf2e..2921cdb 100644 --- a/exts/ff_metadata/ff_metadata.py +++ b/src/beira/exts/ff_metadata/ff_metadata.py @@ -12,7 +12,7 @@ import fichub_api from discord.ext import commands -import core +import beira from .utils import ( STORY_WEBSITE_REGEX, @@ -32,12 +32,12 @@ class FFMetadataCog(commands.GroupCog, name="Fanfiction Metadata Search", group_name="ff"): """A cog with triggers and commands for retrieving story metadata.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot self.atlas_client = bot.atlas_client self.fichub_client = bot.fichub_client self.ao3_client = bot.ao3_client - self.aci100_id: int = core.CONFIG.discord.important_guilds["prod"][0] + self.aci100_guild_id: int = bot.config.discord.important_guilds["prod"][0] self.allowed_channels_cache: dict[int, set[int]] = {} @property @@ -53,7 +53,7 @@ async def cog_load(self) -> None: for record in records: self.allowed_channels_cache.setdefault(record["guild_id"], set()).add(record["channel_id"]) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -68,7 +68,7 @@ async def on_posted_fanfic_link(self, message: discord.Message) -> None: Must be triggered in an allowed channel. """ - if (message.author == self.bot.user) or (not message.guild) or message.guild.id == self.aci100_id: + if (message.author == self.bot.user) or (not message.guild) or message.guild.id == self.aci100_guild_id: return # Listen to the allowed channels in the allowed guilds for valid fanfic links. @@ -92,7 +92,7 @@ async def on_fanficfinder_nothing_found_message(self, message: discord.Message) if bool( message.guild - and (message.guild.id == self.aci100_id) + and (message.guild.id == self.aci100_guild_id) and (message.author.id == FANFICFINDER_ID) and message.embeds and (embed := message.embeds[0]) @@ -103,7 +103,7 @@ async def on_fanficfinder_nothing_found_message(self, message: discord.Message) @commands.hybrid_group(fallback="get") @commands.guild_only() - async def autoresponse(self, ctx: core.GuildContext) -> None: + async def autoresponse(self, ctx: beira.GuildContext) -> None: """Autoresponse-related commands for automatically responding to fanfiction links in certain channels. By default, display the channels in the server set to autorespond. @@ -121,7 +121,7 @@ async def autoresponse(self, ctx: core.GuildContext) -> None: @autoresponse.command("add") async def autoresponse_add( self, - ctx: core.GuildContext, + ctx: beira.GuildContext, *, channels: commands.Greedy[discord.abc.GuildChannel], ) -> None: @@ -131,7 +131,7 @@ async def autoresponse_add( Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. channels: `commands.Greedy`[`discord.abc.GuildChannel`] A list of channels to add, separated by spaces. @@ -161,7 +161,7 @@ async def autoresponse_add( @autoresponse.command("remove") async def autoresponse_remove( self, - ctx: core.GuildContext, + ctx: beira.GuildContext, *, channels: commands.Greedy[discord.abc.GuildChannel], ) -> None: @@ -171,7 +171,7 @@ async def autoresponse_remove( Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. channels: `commands.Greedy`[`discord.abc.GuildChannel`] A list of channels to remove, separated by spaces. @@ -197,14 +197,20 @@ async def autoresponse_remove( await ctx.send(embed=embed) @commands.hybrid_command() - async def ff_search(self, ctx: core.Context, platform: Literal["ao3", "ffn", "other"], *, name_or_url: str) -> None: + async def ff_search( + self, + ctx: beira.Context, + platform: Literal["ao3", "ffn", "other"], + *, + name_or_url: str, + ) -> None: """Search available platforms for a fic with a certain title or url. Note: Only urls are accepted for `other`. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. platform: Literal["ao3", "ffn", "other"] The platform to search. diff --git a/exts/ff_metadata/utils.py b/src/beira/exts/ff_metadata/utils.py similarity index 99% rename from exts/ff_metadata/utils.py rename to src/beira/exts/ff_metadata/utils.py index 9f29598..bdedfd9 100644 --- a/exts/ff_metadata/utils.py +++ b/src/beira/exts/ff_metadata/utils.py @@ -8,7 +8,7 @@ import fichub_api import lxml.html -from core.utils import PaginatedSelectView, html_to_markdown +from beira.utils import PaginatedSelectView, html_to_markdown __all__ = ( diff --git a/exts/help.py b/src/beira/exts/help.py similarity index 94% rename from exts/help.py rename to src/beira/exts/help.py index d3167b9..b088af9 100644 --- a/exts/help.py +++ b/src/beira/exts/help.py @@ -14,8 +14,8 @@ from discord import app_commands from discord.ext import commands -import core -from core.utils import PaginatedEmbedView +import beira +from beira.utils import PaginatedEmbedView LOGGER = logging.getLogger(__name__) @@ -219,7 +219,7 @@ def clean_docstring(docstring: str) -> str: class HelpCog(commands.Cog, name="Help"): """A cog that allows more dynamic usage of my custom help command class, `BeiraHelpCommand`.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot self._old_help_command = self.bot.help_command self.bot.help_command = BeiraHelpCommand() @@ -235,10 +235,10 @@ async def cog_command_error(self, ctx: commands.Context[Any], error: Exception) LOGGER.exception("", exc_info=error) @app_commands.command(name="help") - async def help_(self, interaction: core.Interaction, command: str | None = None) -> None: + async def help_(self, interaction: beira.Interaction, command: str | None = None) -> None: """Access the help commands through the slash system.""" - ctx = await self.bot.get_context(interaction, cls=core.Context) + ctx = await self.bot.get_context(interaction, cls=beira.Context) if command is not None: await ctx.send_help(command) @@ -248,11 +248,15 @@ async def help_(self, interaction: core.Interaction, command: str | None = None) await interaction.response.send_message(content="Help dialogue sent!", ephemeral=True) @help_.autocomplete("command") - async def command_autocomplete(self, interaction: core.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def command_autocomplete( + self, + interaction: beira.Interaction, + current: str, + ) -> list[app_commands.Choice[str]]: """Autocompletes the help command.""" assert self.bot.help_command - ctx = await self.bot.get_context(interaction, cls=core.Context) + ctx = await self.bot.get_context(interaction, cls=beira.Context) help_command = self.bot.help_command.copy() help_command.context = ctx @@ -270,7 +274,7 @@ async def command_autocomplete(self, interaction: core.Interaction, current: str ][:25] -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(HelpCog(bot)) diff --git a/exts/lol.py b/src/beira/exts/lol.py similarity index 93% rename from exts/lol.py rename to src/beira/exts/lol.py index 4a8c00f..1a36b7d 100644 --- a/exts/lol.py +++ b/src/beira/exts/lol.py @@ -16,8 +16,8 @@ from arsenic import browsers, errors, get_session, services # type: ignore # Third-party lib typing. from discord.ext import commands -import core -from core.utils import StatsEmbed +import beira +from beira.utils import StatsEmbed LOGGER = logging.getLogger(__name__) @@ -71,7 +71,7 @@ async def interaction_check(self, interaction: discord.Interaction, /) -> bool: return check @discord.ui.button(label="Update", style=discord.ButtonStyle.blurple) - async def update(self, interaction: core.Interaction, button: discord.ui.Button[Self]) -> None: + async def update(self, interaction: beira.Interaction, button: discord.ui.Button[Self]) -> None: """Update the information in the given leaderboard.""" # Change the button to show the update is in progress. @@ -101,7 +101,7 @@ class LoLCog(commands.Cog, name="League of Legends"): Credit to Ralph for the main code; I'm just testing it out to see how it would work in Discord. """ - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot self.default_summoners_list = [ "Real Iron IV", @@ -127,7 +127,7 @@ def cog_emoji(self) -> discord.PartialEmoji: return discord.PartialEmoji(name="ok_lol", id=1077980829315252325) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -136,18 +136,18 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: LOGGER.exception("", exc_info=error) @commands.hybrid_group() - async def lol(self, ctx: core.Context) -> None: + async def lol(self, ctx: beira.Context) -> None: """A group of League of Legends-related commands.""" await ctx.send_help(ctx.command) @lol.command("stats") - async def lol_stats(self, ctx: core.Context, summoner_name: str) -> None: + async def lol_stats(self, ctx: beira.Context, summoner_name: str) -> None: """Gets the League of Legends stats for a summoner. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. summoner_name: `str` The summoner name, or username, of the League of Legends player being queried. @@ -168,12 +168,12 @@ async def lol_stats(self, ctx: core.Context, summoner_name: str) -> None: await ctx.send(embed=embed) @lol.command("leaderboard") - async def lol_leaderboard(self, ctx: core.Context, *, summoner_names: str | None = None) -> None: + async def lol_leaderboard(self, ctx: beira.Context, *, summoner_names: str | None = None) -> None: """Get the League of Legends ranked stats for a group of summoners and display them. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. summoner_names: list[`str`] A string of summoner names to create a leaderboard from. Separate these by spaces. @@ -267,7 +267,7 @@ async def check_lol_stats(self, summoner_name: str) -> tuple[str, str, str]: return summoner_name, winrate, rank -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(LoLCog(bot)) diff --git a/exts/music.py b/src/beira/exts/music.py similarity index 89% rename from exts/music.py rename to src/beira/exts/music.py index bb586db..c2c3bdc 100644 --- a/exts/music.py +++ b/src/beira/exts/music.py @@ -16,8 +16,8 @@ from discord.ext import commands from wavelink.types.filters import FilterPayload -import core -from core.utils import EMOJI_STOCK, PaginatedEmbedView +import beira +from beira.utils import EMOJI_STOCK, PaginatedEmbedView LOGGER = logging.getLogger(__name__) @@ -126,7 +126,7 @@ def __init__( class MusicCog(commands.Cog, name="Music"): """A cog with audio-playing functionality.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot @property @@ -135,7 +135,7 @@ def cog_emoji(self) -> discord.PartialEmoji: return discord.PartialEmoji(name="\N{MUSICAL NOTE}") - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing """Catch errors from commands inside this cog.""" embed = discord.Embed(title="Music Error", description="Something went wrong with this command.") @@ -147,7 +147,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: if isinstance(error, commands.MissingPermissions): embed.description = "You don't have permission to do this." - elif isinstance(error, core.NotInBotVoiceChannel): + elif isinstance(error, beira.NotInBotVoiceChannel): embed.description = "You're not in the same voice channel as the bot." elif isinstance(error, InvalidShortTimeFormat): embed.description = error.message @@ -200,13 +200,13 @@ async def on_wavelink_inactive_player(self, player: wavelink.Player) -> None: @commands.hybrid_group() @commands.guild_only() - async def music(self, ctx: core.GuildContext) -> None: + async def music(self, ctx: beira.GuildContext) -> None: """Music-related commands.""" await ctx.send_help(ctx.command) @music.command() - async def connect(self, ctx: core.GuildContext, channel: discord.VoiceChannel | None = None) -> None: + async def connect(self, ctx: beira.GuildContext, channel: discord.VoiceChannel | None = None) -> None: """Join a voice channel.""" vc: wavelink.Player | None = ctx.voice_client @@ -235,12 +235,12 @@ async def connect(self, ctx: core.GuildContext, channel: discord.VoiceChannel | await ctx.send(f"Joined the {ctx.author.voice.channel} channel.") @music.command() - async def play(self, ctx: core.GuildContext, query: str, _channel: discord.VoiceChannel | None = None) -> None: + async def play(self, ctx: beira.GuildContext, query: str, _channel: discord.VoiceChannel | None = None) -> None: """Play audio from a url or search term. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. query: `str` A search term/url that is converted into a track or playlist. @@ -274,7 +274,7 @@ async def play(self, ctx: core.GuildContext, query: str, _channel: discord.Voice await vc.play(vc.queue.get()) @play.before_invoke - async def play_ensure_voice(self, ctx: core.GuildContext) -> None: + async def play_ensure_voice(self, ctx: beira.GuildContext) -> None: """Ensures that the voice client automatically connects the right channel.""" vc: wavelink.Player | None = ctx.voice_client @@ -295,8 +295,8 @@ async def play_autocomplete(self, _: discord.Interaction, current: str) -> list[ return [app_commands.Choice(name=track.title, value=track.uri or track.title) for track in tracks][:25] @music.command() - @core.in_bot_vc() - async def pause(self, ctx: core.GuildContext) -> None: + @beira.in_bot_vc() + async def pause(self, ctx: beira.GuildContext) -> None: """Pause the audio.""" if vc := ctx.voice_client: @@ -307,8 +307,8 @@ async def pause(self, ctx: core.GuildContext) -> None: await ctx.send("No player to perform this on.") @music.command() - @core.in_bot_vc() - async def resume(self, ctx: core.GuildContext) -> None: + @beira.in_bot_vc() + async def resume(self, ctx: beira.GuildContext) -> None: """Resume the audio if paused.""" if vc := ctx.voice_client: @@ -321,8 +321,8 @@ async def resume(self, ctx: core.GuildContext) -> None: await ctx.send("No player to perform this on.") @music.command(aliases=["disconnect"]) - @core.in_bot_vc() - async def stop(self, ctx: core.GuildContext) -> None: + @beira.in_bot_vc() + async def stop(self, ctx: beira.GuildContext) -> None: """Stop playback and disconnect the bot from voice.""" if vc := ctx.voice_client: @@ -332,7 +332,7 @@ async def stop(self, ctx: core.GuildContext) -> None: await ctx.send("No player to perform this on.") @music.command() - async def current(self, ctx: core.GuildContext) -> None: + async def current(self, ctx: beira.GuildContext) -> None: """Display the current track.""" vc: wavelink.Player | None = ctx.voice_client @@ -349,7 +349,7 @@ async def current(self, ctx: core.GuildContext) -> None: await ctx.send(embed=current_embed) @music.group(fallback="get") - async def queue(self, ctx: core.GuildContext) -> None: + async def queue(self, ctx: beira.GuildContext) -> None: """Music queue-related commands. By default, this displays everything in the queue. Use `play` to add things to the queue. @@ -369,13 +369,13 @@ async def queue(self, ctx: core.GuildContext) -> None: view.message = message @queue.command() - @core.in_bot_vc() - async def remove(self, ctx: core.GuildContext, entry: int) -> None: + @beira.in_bot_vc() + async def remove(self, ctx: beira.GuildContext, entry: int) -> None: """Remove a track from the queue by position. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. entry: `int` The track's position. @@ -391,8 +391,8 @@ async def remove(self, ctx: core.GuildContext, entry: int) -> None: await ctx.send("No player to perform this on.") @queue.command() - @core.in_bot_vc() - async def clear(self, ctx: core.GuildContext) -> None: + @beira.in_bot_vc() + async def clear(self, ctx: beira.GuildContext) -> None: """Empty the queue.""" if vc := ctx.voice_client: @@ -405,13 +405,13 @@ async def clear(self, ctx: core.GuildContext) -> None: await ctx.send("No player to perform this on.") @music.command() - @core.in_bot_vc() - async def move(self, ctx: core.GuildContext, before: int, after: int) -> None: + @beira.in_bot_vc() + async def move(self, ctx: beira.GuildContext, before: int, after: int) -> None: """Move a song from one spot to another within the queue. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. before: `int` The index of the song you want moved. @@ -432,13 +432,13 @@ async def move(self, ctx: core.GuildContext, before: int, after: int) -> None: await ctx.send("No player to perform this on.") @music.command() - @core.in_bot_vc() - async def skip(self, ctx: core.GuildContext, index: int = 1) -> None: + @beira.in_bot_vc() + async def skip(self, ctx: beira.GuildContext, index: int = 1) -> None: """Skip to the numbered track in the queue. If no number is given, skip to the next track. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. index: `int` The place in the queue to skip to. @@ -461,8 +461,8 @@ async def skip(self, ctx: core.GuildContext, index: int = 1) -> None: await ctx.send("No player to perform this on.") @music.command() - @core.in_bot_vc() - async def shuffle(self, ctx: core.GuildContext) -> None: + @beira.in_bot_vc() + async def shuffle(self, ctx: beira.GuildContext) -> None: """Shuffle the tracks in the queue.""" if vc := ctx.voice_client: @@ -475,13 +475,13 @@ async def shuffle(self, ctx: core.GuildContext) -> None: await ctx.send("No player to perform this on.") @music.command() - @core.in_bot_vc() - async def loop(self, ctx: core.GuildContext, loop: Literal["All Tracks", "Current Track", "Off"] = "Off") -> None: + @beira.in_bot_vc() + async def loop(self, ctx: beira.GuildContext, loop: Literal["All Tracks", "Current Track", "Off"] = "Off") -> None: """Loop the current track(s). Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. loop: Literal["All Tracks", "Current Track", "Off"] The loop settings. "All Tracks" loops everything in the queue, "Current Track" loops the playing track, and @@ -502,10 +502,10 @@ async def loop(self, ctx: core.GuildContext, loop: Literal["All Tracks", "Curren await ctx.send("No player to perform this on.") @music.command() - @core.in_bot_vc() + @beira.in_bot_vc() async def seek( self, - ctx: core.GuildContext, + ctx: beira.GuildContext, *, position: app_commands.Transform[datetime.timedelta, ShortDurationTransformer], ) -> None: @@ -513,7 +513,7 @@ async def seek( Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. position: `str` The time to jump to, given in the format `hours:minutes:seconds` or `minutes:seconds`. @@ -536,13 +536,13 @@ async def seek( await ctx.send("No player to perform this on.") @music.command() - @core.in_bot_vc() - async def volume(self, ctx: core.GuildContext, volume: int | None = None) -> None: + @beira.in_bot_vc() + async def volume(self, ctx: beira.GuildContext, volume: int | None = None) -> None: """Show the player's volume. If given a number, you can change it as well, with 1000 as the limit. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. volume: `int`, optional The volume to change to, with a maximum of 1000. @@ -558,13 +558,13 @@ async def volume(self, ctx: core.GuildContext, volume: int | None = None) -> Non await ctx.send("No player to perform this on.") @music.command("filter") - @core.in_bot_vc() - async def muse_filter(self, ctx: core.GuildContext, name: str) -> None: + @beira.in_bot_vc() + async def muse_filter(self, ctx: beira.GuildContext, name: str) -> None: """Set a filter on the incoming audio. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. name: `str` The name of the filter to use. "reset" resets the filters. @@ -588,7 +588,7 @@ async def muse_filter(self, ctx: core.GuildContext, name: str) -> None: await ctx.send("No player to perform this on.") @muse_filter.autocomplete("name") - async def muse_filter_name_autocomplete(self, _: core.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def muse_filter_name_autocomplete(self, _: beira.Interaction, current: str) -> list[app_commands.Choice[str]]: return [ app_commands.Choice(name=name, value=name) for name in COMMON_FILTERS @@ -597,8 +597,8 @@ async def muse_filter_name_autocomplete(self, _: core.Interaction, current: str) @music.command(name="export") @commands.guild_only() - @core.in_bot_vc() - async def muse_export(self, ctx: core.GuildContext) -> None: + @beira.in_bot_vc() + async def muse_export(self, ctx: beira.GuildContext) -> None: """Export the current queue to a file. Can be re-imported later to recreate the queue.""" if vc := ctx.voice_client: @@ -616,12 +616,12 @@ async def muse_export(self, ctx: core.GuildContext) -> None: @music.command(name="import") @commands.guild_only() - async def muse_import(self, ctx: core.GuildContext, import_file: discord.Attachment) -> None: + async def muse_import(self, ctx: beira.GuildContext, import_file: discord.Attachment) -> None: """Import a file with track information to recreate a music queue. May be created with /export. Parameters ---------- - ctx: core.GuildContext + ctx: beira.GuildContext The invocation context. import_file: discord.Attachment A JSON file with track information to recreate the queue with. May be created by /export. @@ -650,7 +650,7 @@ async def muse_import(self, ctx: core.GuildContext, import_file: discord.Attachm await ctx.send("No player to perform this on.") @muse_import.error - async def muse_import_error(self, ctx: core.Context, error: commands.CommandError) -> None: + async def muse_import_error(self, ctx: beira.Context, error: commands.CommandError) -> None: """Error handler for /music import. Provides better error messages for users.""" actual_error = error.__cause__ or error @@ -665,7 +665,7 @@ async def muse_import_error(self, ctx: core.Context, error: commands.CommandErro await ctx.send(error_text) @muse_import.before_invoke - async def import_ensure_voice(self, ctx: core.GuildContext) -> None: + async def import_ensure_voice(self, ctx: beira.GuildContext) -> None: """Ensures that the voice client automatically connects the right channel.""" vc: wavelink.Player | None = ctx.voice_client @@ -681,7 +681,7 @@ async def import_ensure_voice(self, ctx: core.GuildContext) -> None: raise commands.CommandError(msg) -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(MusicCog(bot)) diff --git a/exts/misc.py b/src/beira/exts/other.py similarity index 84% rename from exts/misc.py rename to src/beira/exts/other.py index 2f3e9ed..a4afb31 100644 --- a/exts/misc.py +++ b/src/beira/exts/other.py @@ -2,6 +2,7 @@ import asyncio import colorsys +import importlib.metadata import logging import math import random @@ -13,10 +14,9 @@ import discord import openpyxl import openpyxl.styles -from discord import app_commands from discord.ext import commands -import core +import beira LOGGER = logging.getLogger(__name__) @@ -36,8 +36,8 @@ def capitalize_meow(word: str, reference: str) -> str: return word.lower() # Char-by-char processing. - for cw, cr in zip(word, reference, strict=True): - new_word.write(cw.upper() if cr.isupper() else cw) + for word_char, ref_char in zip(word, reference, strict=True): + new_word.write(word_char.upper() if ref_char.isupper() else word_char) return new_word.getvalue() @@ -61,7 +61,7 @@ def meowify_word(match: re.Match[str]) -> str: internal_len = len(word) - 2 e_len = random.randint(1, internal_len) o_len = internal_len - e_len - temp = "m" + "e" * e_len + "o" * o_len + "w" + temp = f"m{"e" * e_len}{"o" * o_len}w" return capitalize_meow(temp, word) @@ -71,8 +71,8 @@ def meowify_text(text: str) -> str: return re.sub(r"\w+", meowify_word, text) -@app_commands.context_menu(name="Meowify") -async def context_menu_meowify(interaction: core.Interaction, message: discord.Message) -> None: +@discord.app_commands.context_menu(name="Meowify") +async def context_menu_meowify(interaction: beira.Interaction, message: discord.Message) -> None: """Context menu command callback for meowifying the test in a message.""" if len(message.content) > 2000: @@ -99,20 +99,23 @@ def color_step(r: int, g: int, b: int, repetitions: int = 1) -> tuple[int, int, def process_color_data(role_data: list[tuple[str, discord.Colour]]) -> BytesIO: """Format role names and colors in an excel sheet and return that sheet as a bytes stream.""" - headers = ["Role Name", "Role Color (Hex)"] workbook = openpyxl.Workbook() sheet = workbook.active - sheet.append(headers) # type: ignore + assert sheet # openpyxl automatically adds a sheet on initialization if the workbook isn't write-only. + + headers = ["Role Name", "Role Color (Hex)"] + sheet.append(headers) + for i, (name, colour) in enumerate(role_data, start=2): color_value = colour.value str_hex = f"{color_value:#08x}".removeprefix("0x") - sheet.append([name, str_hex]) # type: ignore + sheet.append([name, str_hex]) if color_value != 0: - sheet[f"A{i}"].fill = openpyxl.styles.PatternFill(fill_type="solid", start_color=str_hex) # type: ignore + sheet[f"A{i}"].fill = openpyxl.styles.PatternFill(fill_type="solid", start_color=str_hex) ft = openpyxl.styles.Font(bold=True) - for row in sheet["A1:C1"]: # type: ignore - for cell in row: # type: ignore + for row in sheet["A1:C1"]: + for cell in row: cell.font = ft with tempfile.NamedTemporaryFile() as tmp: @@ -121,10 +124,10 @@ def process_color_data(role_data: list[tuple[str, discord.Colour]]) -> BytesIO: return BytesIO(tmp.read()) -class MiscCog(commands.Cog, name="Misc"): - """A cog with some basic commands, originally used for testing slash and hybrid command functionality.""" +class OtherCog(commands.Cog, name="Other"): + """A cog with some basic or random commands.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot self.bot.tree.add_command(context_menu_meowify) @@ -137,7 +140,7 @@ def cog_emoji(self) -> discord.PartialEmoji: async def cog_unload(self) -> None: self.bot.tree.remove_command(context_menu_meowify.name, type=context_menu_meowify.type) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -146,7 +149,27 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: LOGGER.exception("", exc_info=error) @commands.hybrid_command() - async def about(self, ctx: core.Context) -> None: + async def hello(self, ctx: beira.Context) -> None: + """Get a "Hello, World!" response.""" + + await ctx.send("Hello, World!") + + @commands.hybrid_command() + async def echo(self, ctx: beira.Context, *, arg: str) -> None: + """Echo back the user's input. + + Parameters + ---------- + ctx: `beira.Context` + The invocation context. + arg: `str` + The user input. + """ + + await ctx.send(arg) + + @commands.hybrid_command() + async def about(self, ctx: beira.Context) -> None: """See some basic information about the bot, including its source.""" assert self.bot.user # Known to exist during runtime. @@ -166,37 +189,17 @@ async def about(self, ctx: core.Context) -> None: ) .set_author(name=f"Made by {self.bot.owner}", icon_url=self.bot.owner.display_avatar.url) .set_thumbnail(url=self.bot.user.display_avatar.url) - .set_footer(text=f"Made with discord.py v{discord.__version__}") + .set_footer(text=f"Made with discord.py v{importlib.metadata.version("discord")}") ) await ctx.send(embed=embed) @commands.hybrid_command() - async def hello(self, ctx: core.Context) -> None: - """Get a "Hello, World!" response.""" - - await ctx.send("Hello, World!") - - @commands.hybrid_command() - async def echo(self, ctx: core.Context, *, arg: str) -> None: - """Echo back the user's input. - - Parameters - ---------- - ctx: `core.Context` - The invocation context. - arg: `str` - The user input. - """ - - await ctx.send(arg) - - @commands.hybrid_command() - async def quote(self, ctx: core.Context, *, message: discord.Message) -> None: + async def quote(self, ctx: beira.Context, *, message: discord.Message) -> None: """Display a message's contents, specified with a message link, message ID, or channel-message ID pair. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. message: `discord.Message` The message to be quoted. It can be specified by a message link, message ID, or channel-message ID pair. @@ -210,12 +213,12 @@ async def quote(self, ctx: core.Context, *, message: discord.Message) -> None: await ctx.send(embed=quote_embed) @commands.hybrid_command(name="ping") - async def ping_(self, ctx: core.Context) -> None: + async def ping_(self, ctx: beira.Context) -> None: """Display the time necessary for the bot to communicate with Discord. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. """ @@ -247,12 +250,12 @@ async def ping_(self, ctx: core.Context) -> None: await message.edit(embed=pong_embed) @commands.hybrid_command() - async def meowify(self, ctx: core.Context, *, text: str) -> None: + async def meowify(self, ctx: beira.Context, *, text: str) -> None: """Meowify some text. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. text: `str` The text to convert into meows. @@ -267,12 +270,12 @@ async def meowify(self, ctx: core.Context, *, text: str) -> None: @commands.guild_only() @commands.hybrid_command() - async def role_excel(self, ctx: core.GuildContext, by_color: bool = False) -> None: + async def role_excel(self, ctx: beira.GuildContext, by_color: bool = False) -> None: """Get a spreadsheet with a guild's roles, optionally sorted by color. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context, restricted to a guild. by_color: `bool`, default=False Whether the roles should be sorted by color. If False, sorts by name. Default is False. @@ -292,7 +295,7 @@ def color_key(item: tuple[str, discord.Colour]) -> tuple[int, int, int]: await ctx.send("Created Excel sheet with roles.", file=disc_file) @commands.hybrid_command() - async def inspire_me(self, ctx: core.Context) -> None: + async def inspire_me(self, ctx: beira.Context) -> None: """Generate a random inspirational poster with InspiroBot.""" async with ctx.typing(): @@ -309,7 +312,7 @@ async def inspire_me(self, ctx: core.Context) -> None: await ctx.send(embed=embed) -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" - await bot.add_cog(MiscCog(bot)) + await bot.add_cog(OtherCog(bot)) diff --git a/exts/patreon.py b/src/beira/exts/patreon.py similarity index 94% rename from exts/patreon.py rename to src/beira/exts/patreon.py index c920f64..b3e0a2a 100644 --- a/exts/patreon.py +++ b/src/beira/exts/patreon.py @@ -13,8 +13,8 @@ import msgspec from discord.ext import commands, tasks -import core -from core.utils import PaginatedSelectView +import beira +from beira.utils import PaginatedSelectView LOGGER = logging.getLogger(__name__) @@ -98,9 +98,9 @@ class PatreonCheckCog(commands.Cog, name="Patreon"): In development. """ - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot - self.access_token = core.CONFIG.patreon.creator_access_token + self.access_token = bot.config.patreon.creator_access_token self.patrons_on_discord: dict[str, list[discord.Member]] = {} @property @@ -123,13 +123,13 @@ async def cog_unload(self) -> None: if self.get_current_discord_patrons.is_running(): self.get_current_discord_patrons.stop() - async def cog_check(self, ctx: core.Context) -> bool: # type: ignore # Narrowing, and async allowed. + async def cog_check(self, ctx: beira.Context) -> bool: # type: ignore # Narrowing, and async allowed. """Set up bot owner check as universal within the cog.""" original = commands.is_owner().predicate return await original(ctx) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -169,7 +169,7 @@ async def _get_patreon_roles(self) -> None: self.patreon_tiers_info.insert(0, menu_info) @commands.hybrid_command() - async def patreon_benefits(self, ctx: core.Context) -> None: + async def patreon_benefits(self, ctx: beira.Context) -> None: """See what kind of patreon benefits and tiers ACI100 has to offer.""" async with ctx.typing(): @@ -182,7 +182,7 @@ async def get_current_discord_patrons(self) -> None: LOGGER.info("Checking for new patrons, old patrons, and updated patrons!") - aci100_id = core.CONFIG.patreon.patreon_guild_id + aci100_id = self.bot.config.patreon.patreon_guild_id patreon_guild = self.bot.get_guild(aci100_id) assert patreon_guild is not None @@ -198,7 +198,7 @@ async def before_background_task(self) -> None: async def get_current_actual_patrons(self) -> None: """Get all active patrons from Patreon's API.""" - api_token = core.CONFIG.patreon.creator_access_token + api_token = self.bot.config.patreon.creator_access_token headers = {"Authorization": f"Bearer {api_token}"} # Get campaign data. @@ -271,7 +271,7 @@ async def get_current_actual_patrons(self) -> None: LOGGER.info("Remaining: %s", not_ok_members) -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(PatreonCheckCog(bot)) diff --git a/exts/presence.py b/src/beira/exts/presence.py similarity index 92% rename from exts/presence.py rename to src/beira/exts/presence.py index 3f64a0a..6601d3f 100644 --- a/exts/presence.py +++ b/src/beira/exts/presence.py @@ -3,11 +3,11 @@ import discord from discord.ext import commands, tasks -import core +import beira class PresenceCog(commands.Cog): - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot self.set_custom_presence.start() @@ -29,5 +29,5 @@ async def set_custom_presence_before(self) -> None: await self.bot.wait_until_ready() -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: await bot.add_cog(PresenceCog(bot)) diff --git a/exts/snowball/__init__.py b/src/beira/exts/snowball/__init__.py similarity index 66% rename from exts/snowball/__init__.py rename to src/beira/exts/snowball/__init__.py index 1f2b429..e107e63 100644 --- a/exts/snowball/__init__.py +++ b/src/beira/exts/snowball/__init__.py @@ -1,9 +1,9 @@ -import core +import beira from .snowball import SnowballCog -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(SnowballCog(bot)) diff --git a/exts/snowball/snow_text.py b/src/beira/exts/snowball/snow_text.py similarity index 100% rename from exts/snowball/snow_text.py rename to src/beira/exts/snowball/snow_text.py diff --git a/exts/snowball/snowball.py b/src/beira/exts/snowball/snowball.py similarity index 93% rename from exts/snowball/snowball.py rename to src/beira/exts/snowball/snowball.py index e00ed3a..e1792c0 100644 --- a/exts/snowball/snowball.py +++ b/src/beira/exts/snowball/snowball.py @@ -16,8 +16,8 @@ from discord import app_commands from discord.ext import commands -import core -from core.utils import EMOJI_STOCK, StatsEmbed +import beira +from beira.utils import EMOJI_STOCK, StatsEmbed from .snow_text import ( COLLECT_FAIL_IMGS, @@ -47,7 +47,7 @@ class SnowballCog(commands.Cog, name="Snowball"): """A cog that implements all snowball fight-related commands, like Discord's 2021 Snowball bot game.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot @property @@ -56,7 +56,7 @@ def cog_emoji(self) -> discord.PartialEmoji: return discord.PartialEmoji(name="snowflake", animated=True, id=1077980648867901531) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing """Handles errors that occur within this cog. For example, when using prefix commands, this will tell users if they are missing arguments. Other error cases @@ -64,7 +64,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context where the error happened. error: `Exception` The error that happened. @@ -86,7 +86,7 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: elif isinstance(error, commands.CommandOnCooldown): embed.title = "Command on Cooldown!" embed.description = f"Please wait {error.retry_after:.2f} seconds before trying this command again." - elif isinstance(error, core.CannotTargetSelf): + elif isinstance(error, beira.CannotTargetSelf): embed.title = "No Targeting Yourself!" embed.description = ( "Are you a masochist or do you just like the taste of snow? Regardless, no hitting yourself in the " @@ -103,14 +103,14 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: await ctx.send(embed=embed, ephemeral=True) @commands.hybrid_group() - async def snow(self, ctx: core.Context) -> None: + async def snow(self, ctx: beira.Context) -> None: """A group of snowball-related commands.""" await ctx.send_help(ctx.command) @snow.command() @commands.guild_only() - async def settings(self, ctx: core.GuildContext) -> None: + async def settings(self, ctx: beira.GuildContext) -> None: """Show what the settings are for the snowballs in this server.""" # Get the settings for the guild and make an embed display. @@ -119,7 +119,7 @@ async def settings(self, ctx: core.GuildContext) -> None: embed = view.format_embed() # Only send the view with the embed if invoker has certain perms. - if ctx.author.id == self.bot.owner_id or await core.is_admin().predicate(ctx): + if ctx.author.id == self.bot.owner_id or await beira.is_admin().predicate(ctx): view.message = await ctx.send(embed=embed, view=view) else: await ctx.send(embed=embed) @@ -127,7 +127,7 @@ async def settings(self, ctx: core.GuildContext) -> None: @snow.command() @commands.guild_only() @commands.dynamic_cooldown(collect_cooldown, commands.cooldowns.BucketType.user) # type: ignore - async def collect(self, ctx: core.GuildContext) -> None: + async def collect(self, ctx: beira.GuildContext) -> None: """Collects a snowball.""" # Get the snowball settings for this particular guild. @@ -161,12 +161,12 @@ async def collect(self, ctx: core.GuildContext) -> None: @snow.command() @commands.guild_only() @app_commands.describe(target="Who do you want to throw a snowball at?") - async def throw(self, ctx: core.GuildContext, *, target: discord.Member) -> None: + async def throw(self, ctx: beira.GuildContext, *, target: discord.Member) -> None: """Start a snowball fight with another server member. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. target: `discord.Member` The user to hit with a snowball. @@ -174,7 +174,7 @@ async def throw(self, ctx: core.GuildContext, *, target: discord.Member) -> None if target == ctx.author: msg = "You cannot target yourself with this argument." - raise core.CannotTargetSelf(msg) + raise beira.CannotTargetSelf(msg) # Get the snowball settings for this particular guild. guild_snow_settings = getattr(ctx, "guild_snow_settings", GuildSnowballSettings(ctx.guild.id)) @@ -221,12 +221,12 @@ async def throw(self, ctx: core.GuildContext, *, target: discord.Member) -> None @commands.guild_only() @commands.dynamic_cooldown(transfer_cooldown, commands.cooldowns.BucketType.user) # type: ignore @app_commands.describe(receiver="Who do you want to give some balls? You can't transfer more than 10 at a time.") - async def transfer(self, ctx: core.GuildContext, amount: int, *, receiver: discord.Member) -> None: + async def transfer(self, ctx: beira.GuildContext, amount: int, *, receiver: discord.Member) -> None: """Give another server member some of your snowballs, though no more than 10 at a time. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. amount: `int` The number of snowballs to transfer. If is greater than 10, pushes the receiver's snowball stock past the @@ -237,7 +237,7 @@ async def transfer(self, ctx: core.GuildContext, amount: int, *, receiver: disco if receiver == ctx.author: msg = "You cannot target yourself with this argument." - raise core.CannotTargetSelf(msg) + raise beira.CannotTargetSelf(msg) # Get the snowball settings for this particular guild. guild_snow_settings = getattr(ctx, "guild_snow_settings", GuildSnowballSettings(ctx.guild.id)) @@ -292,18 +292,18 @@ async def transfer(self, ctx: core.GuildContext, amount: int, *, receiver: disco @snow.command() @commands.guild_only() - @core.is_owner_or_friend() + @beira.is_owner_or_friend() @commands.dynamic_cooldown(steal_cooldown, commands.cooldowns.BucketType.user) # type: ignore @app_commands.describe( amount="How much do you want to steal? (No more than 10 at a time)", victim="Who do you want to pilfer some balls from?", ) - async def steal(self, ctx: core.GuildContext, amount: int, *, victim: discord.Member) -> None: + async def steal(self, ctx: beira.GuildContext, amount: int, *, victim: discord.Member) -> None: """Steal snowballs from another server member, though no more than 10 at a time. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. amount: `int` The number of snowballs to steal. If is greater than 10, pushes the receiver's snowball stock past the @@ -314,7 +314,7 @@ async def steal(self, ctx: core.GuildContext, amount: int, *, victim: discord.Me if victim == ctx.author: msg = "You cannot target yourself with this argument." - raise core.CannotTargetSelf(msg) + raise beira.CannotTargetSelf(msg) # Get the snowball settings for this particular guild. guild_snow_settings = getattr(ctx, "guild_snow_settings", GuildSnowballSettings(ctx.guild.id)) @@ -375,12 +375,12 @@ async def steal(self, ctx: core.GuildContext, amount: int, *, victim: discord.Me @snow.group(fallback="get") @commands.guild_only() @app_commands.describe(target="Look up a particular Snowball Sparrer's stats.") - async def stats(self, ctx: core.GuildContext, *, target: discord.User = commands.Author) -> None: + async def stats(self, ctx: beira.GuildContext, *, target: discord.User = commands.Author) -> None: """See who's the best at shooting snow spheres. Parameters ---------- - ctx: `core.GuildContext` + ctx: `beira.GuildContext` The invocation context. target: `discord.User`, default=`commands.Author` The user whose stats are to be displayed. If none, defaults to the caller. Their stats are specifically from @@ -424,12 +424,12 @@ async def stats(self, ctx: core.GuildContext, *, target: discord.User = commands @stats.command(name="global") @app_commands.describe(target="Look up a a player's stats as a summation across all servers.") - async def stats_global(self, ctx: core.Context, *, target: discord.User = commands.Author) -> None: + async def stats_global(self, ctx: beira.Context, *, target: discord.User = commands.Author) -> None: """See who's the best across all Beira servers. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. target: `discord.User`, default=`commands.Author` The user whose stats are to be displayed. If none, defaults to the caller. Their global stats are a @@ -463,7 +463,7 @@ async def stats_global(self, ctx: core.Context, *, target: discord.User = comman @snow.group(fallback="get") @commands.guild_only() - async def leaderboard(self, ctx: core.GuildContext) -> None: + async def leaderboard(self, ctx: beira.GuildContext) -> None: """See who's dominating the Snowball Bot leaderboard in your server.""" query = """\ @@ -494,7 +494,7 @@ async def leaderboard(self, ctx: core.GuildContext) -> None: await ctx.send(embed=embed, ephemeral=False) @leaderboard.command(name="global") - async def leaderboard_global(self, ctx: core.Context) -> None: + async def leaderboard_global(self, ctx: beira.Context) -> None: """See who's dominating the Global Snowball Bot leaderboard across all the servers.""" assert self.bot.user # Known to exist during runtime. @@ -515,7 +515,7 @@ async def leaderboard_global(self, ctx: core.Context) -> None: await ctx.send(embed=embed, ephemeral=False) @leaderboard.command(name="guilds") - async def leaderboard_guilds(self, ctx: core.Context) -> None: + async def leaderboard_guilds(self, ctx: beira.Context) -> None: """See which guild is dominating the Snowball Bot leaderboard.""" assert self.bot.user # Known to exist during runtime. @@ -536,7 +536,7 @@ async def leaderboard_guilds(self, ctx: core.Context) -> None: await ctx.send(embed=embed, ephemeral=False) @snow.command() - async def sources(self, ctx: core.Context) -> None: + async def sources(self, ctx: beira.Context) -> None: """Gives links and credit to the Snowsgiving 2021 Help Center article and to reference code.""" embed = ( @@ -550,7 +550,7 @@ async def sources(self, ctx: core.Context) -> None: @throw.before_invoke @transfer.before_invoke @steal.before_invoke - async def snow_before(self, ctx: core.GuildContext) -> None: + async def snow_before(self, ctx: beira.GuildContext) -> None: """Load the snowball settings from the db for the current guild before certain commands are executed. This allows the use of guild-specific limits stored in the db and now temporarily in the context. @@ -563,7 +563,7 @@ async def snow_before(self, ctx: core.GuildContext) -> None: @throw.after_invoke @transfer.after_invoke @steal.after_invoke - async def snow_after(self, ctx: core.GuildContext) -> None: + async def snow_after(self, ctx: beira.GuildContext) -> None: """Remove the snowball settings from the context. Probably not necessary.""" delattr(ctx, "guild_snow_settings") diff --git a/exts/snowball/utils.py b/src/beira/exts/snowball/utils.py similarity index 92% rename from exts/snowball/utils.py rename to src/beira/exts/snowball/utils.py index a568111..c331b02 100644 --- a/exts/snowball/utils.py +++ b/src/beira/exts/snowball/utils.py @@ -5,8 +5,8 @@ import msgspec from discord.ext import commands -import core -from core.utils.db import Connection_alias, Pool_alias +import beira +from beira.utils.db import Connection_alias, Pool_alias __all__ = ( @@ -78,8 +78,9 @@ async def upsert_record( stock = snowball_stats.stock + $7 RETURNING *; """ - args = (member.id, member.guild.id, hits, misses, kos, max(stock, 0), stock) - return cls.from_record(await conn.fetchrow(snowball_stmt, *args)) + values = (member.id, member.guild.id, hits, misses, kos, max(stock, 0), stock) + record = await conn.fetchrow(snowball_stmt, *values) + return cls.from_record(record) class GuildSnowballSettings(msgspec.Struct): @@ -181,7 +182,7 @@ def __init__(self, default_settings: GuildSnowballSettings) -> None: self.default_settings: GuildSnowballSettings = default_settings self.new_settings: GuildSnowballSettings | None = None - async def on_submit(self, interaction: core.Interaction, /) -> None: # type: ignore # Narrowing. + async def on_submit(self, interaction: beira.Interaction, /) -> None: # type: ignore # Narrowing. """Verify changes and update the snowball settings in the database appropriately.""" guild_id = self.default_settings.guild_id @@ -252,7 +253,7 @@ async def on_timeout(self) -> None: await self.message.edit(view=self) - async def interaction_check(self, interaction: core.Interaction, /) -> bool: # type: ignore # Needed narrowing. + async def interaction_check(self, interaction: beira.Interaction, /) -> bool: # type: ignore # Needed narrowing. """Ensure people interacting with this view are only server administrators or bot owners.""" # This should only ever be called in a guild context. @@ -294,7 +295,7 @@ def format_embed(self) -> discord.Embed: ) @discord.ui.button(label="Update", emoji="⚙") - async def change_settings_button(self, interaction: core.Interaction, _: discord.ui.Button[Self]) -> None: + async def change_settings_button(self, interaction: beira.Interaction, _: discord.ui.Button[Self]) -> None: """Send a modal that allows the user to edit the snowball settings for this guild.""" # Get inputs from a modal. @@ -313,7 +314,7 @@ async def change_settings_button(self, interaction: core.Interaction, _: discord await interaction.edit_original_response(embed=self.format_embed()) -def collect_cooldown(ctx: core.Context) -> commands.Cooldown | None: +def collect_cooldown(ctx: beira.Context) -> commands.Cooldown | None: """Sets cooldown for SnowballCog.collect() command. 10 seconds by default. Bot owner and friends get less time. @@ -321,7 +322,7 @@ def collect_cooldown(ctx: core.Context) -> commands.Cooldown | None: rate, per = 1.0, 15.0 # Default cooldown exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"]] - testing_guild_ids: list[int] = core.CONFIG.discord.important_guilds["dev"] + testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] if ctx.author.id in exempt: return None @@ -331,7 +332,7 @@ def collect_cooldown(ctx: core.Context) -> commands.Cooldown | None: return commands.Cooldown(rate, per) -def transfer_cooldown(ctx: core.Context) -> commands.Cooldown | None: +def transfer_cooldown(ctx: beira.Context) -> commands.Cooldown | None: """Sets cooldown for SnowballCog.transfer() command. 60 seconds by default. Bot owner and friends get less time. @@ -339,7 +340,7 @@ def transfer_cooldown(ctx: core.Context) -> commands.Cooldown | None: rate, per = 1.0, 60.0 # Default cooldown exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"]] - testing_guild_ids: list[int] = core.CONFIG.discord.important_guilds["dev"] + testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] if ctx.author.id in exempt: return None @@ -349,7 +350,7 @@ def transfer_cooldown(ctx: core.Context) -> commands.Cooldown | None: return commands.Cooldown(rate, per) -def steal_cooldown(ctx: core.Context) -> commands.Cooldown | None: +def steal_cooldown(ctx: beira.Context) -> commands.Cooldown | None: """Sets cooldown for SnowballCog.steal() command. 90 seconds by default. Bot owner and friends get less time. @@ -357,7 +358,7 @@ def steal_cooldown(ctx: core.Context) -> commands.Cooldown | None: rate, per = 1.0, 90.0 # Default cooldown exempt = [ctx.bot.owner_id, ctx.bot.special_friends["aeroali"], ctx.bot.special_friends["athenahope"]] - testing_guild_ids: list[int] = core.CONFIG.discord.important_guilds["dev"] + testing_guild_ids: list[int] = ctx.bot.config.discord.important_guilds["dev"] if ctx.author.id in exempt: return None diff --git a/exts/starkid.py b/src/beira/exts/starkid.py similarity index 87% rename from exts/starkid.py rename to src/beira/exts/starkid.py index a2730d3..1956687 100644 --- a/exts/starkid.py +++ b/src/beira/exts/starkid.py @@ -8,7 +8,7 @@ import discord from discord.ext import commands -import core +import beira LOGGER = logging.getLogger(__name__) @@ -17,7 +17,7 @@ class StarKidCog(commands.Cog, name="StarKid"): """A cog for StarKid-related commands and functionality.""" - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot @property @@ -27,7 +27,7 @@ def cog_emoji(self) -> discord.PartialEmoji: return discord.PartialEmoji(name="starkid", id=1077980709802758215) @commands.hybrid_command() - async def nightmare_of_black(self, ctx: core.Context) -> None: + async def nightmare_of_black(self, ctx: beira.Context) -> None: """Bring forth a morphed, warped image of the Lords of Black to prostrate and pray before.""" embed = discord.Embed( @@ -39,7 +39,7 @@ async def nightmare_of_black(self, ctx: core.Context) -> None: await ctx.send(embed=embed) -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(StarKidCog(bot)) diff --git a/exts/story_search.py b/src/beira/exts/story_search.py similarity index 96% rename from exts/story_search.py rename to src/beira/exts/story_search.py index 4278a6d..518ec49 100644 --- a/exts/story_search.py +++ b/src/beira/exts/story_search.py @@ -21,8 +21,8 @@ import msgspec from discord.ext import commands -import core -from core.utils import EMOJI_STOCK, EMOJI_URL, PaginatedEmbedView +import beira +from beira.utils import EMOJI_STOCK, EMOJI_URL, PaginatedEmbedView if TYPE_CHECKING: @@ -254,7 +254,7 @@ class StorySearchCog(commands.Cog, name="Quote Search"): story_records: ClassVar[dict[str, StoryInfo]] = {} - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot @property @@ -274,7 +274,7 @@ async def cog_load(self) -> None: if work.is_file() and work.name.endswith("text.md"): self.load_story_text(work) - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -389,12 +389,12 @@ def _binary_search_text(cls, story: str, list_of_indices: list[int], index: int) return cls.story_records[story].text[actual_index] if actual_index != -1 else "—————" @commands.hybrid_command() - async def random_text(self, ctx: core.Context) -> None: + async def random_text(self, ctx: beira.Context) -> None: """Display a random line from the story. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context where the command was called. """ @@ -430,12 +430,12 @@ async def random_text(self, ctx: core.Context) -> None: discord.app_commands.Choice(name="Perversion of Purity", value="pop"), ], ) - async def search_text(self, ctx: core.Context, story: str, *, query: str) -> None: + async def search_text(self, ctx: beira.Context, story: str, *, query: str) -> None: """Search the works of ACI100 for a word or phrase. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. story: `str` The acronym or abbreviation of a story's title. Currently, there are only four choices. @@ -450,12 +450,12 @@ async def search_text(self, ctx: core.Context, story: str, *, query: str) -> Non view.message = message @commands.hybrid_command() - async def search_cadmean(self, ctx: core.Context, *, query: str) -> None: + async def search_cadmean(self, ctx: beira.Context, *, query: str) -> None: """Search *A Cadmean Victory Remastered* by MJ Bradley for a word or phrase. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. query: `str` The string to search for in the story. @@ -478,7 +478,7 @@ async def search_cadmean(self, ctx: core.Context, *, query: str) -> None: ) async def find_text( self, - ctx: core.Context, + ctx: beira.Context, query: str, known_story: str | None = None, url: str | None = None, @@ -504,7 +504,7 @@ async def find_text( await ctx.send("One of `url` and `known_story` must be filled to perform a text search.") -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Loads story metadata and connects cog to bot.""" story_info_records = await bot.db_pool.fetch("SELECT * FROM story_information;") diff --git a/exts/timing.py b/src/beira/exts/timing.py similarity index 92% rename from exts/timing.py rename to src/beira/exts/timing.py index 91fac11..0fcc1cc 100644 --- a/exts/timing.py +++ b/src/beira/exts/timing.py @@ -7,7 +7,7 @@ import msgspec from discord.ext import commands -import core +import beira # The supposed schedule table right now. @@ -78,15 +78,15 @@ async def parse_bcp47_timezones(session: aiohttp.ClientSession) -> dict[str, str return _timezone_aliases -class TimingCog(commands.Cog): - def __init__(self, bot: core.Beira) -> None: +class TimingCog(commands.Cog, name="Timing"): + def __init__(self, bot: beira.Beira) -> None: self.bot = bot async def cog_load(self) -> None: self.timezone_aliases: dict[str, str] = await parse_bcp47_timezones(self.bot.web_session) @commands.hybrid_group("timezone", fallback="get") - async def timezone_(self, ctx: core.Context) -> None: + async def timezone_(self, ctx: beira.Context) -> None: """Display your timezone if it's been set previously.""" tz_str = await self.bot.get_user_timezone(ctx.db, ctx.author.id) @@ -97,12 +97,12 @@ async def timezone_(self, ctx: core.Context) -> None: await ctx.send(f"Your timezone is {tz_str}. Your current time is {user_time}.", ephemeral=True) @timezone_.command("set") - async def timezone_set(self, ctx: core.Context, tz: str) -> None: + async def timezone_set(self, ctx: beira.Context, tz: str) -> None: """Set your timezone. Parameters ---------- - ctx: core.Context + ctx: beira.Context The command invocation context. tz: str The timezone. @@ -128,7 +128,7 @@ async def timezone_set(self, ctx: core.Context, tz: str) -> None: ) @timezone_.command("clear") - async def timezone_clear(self, ctx: core.Context) -> None: + async def timezone_clear(self, ctx: beira.Context) -> None: """Clear your timezone.""" await ctx.db.execute("UPDATE users SET timezone = NULL WHERE user_id = $1;", ctx.author.id) @@ -136,7 +136,7 @@ async def timezone_clear(self, ctx: core.Context) -> None: await ctx.send("Your timezone has been cleared.", ephemeral=True) @timezone_.command("info") - async def timezone_info(self, ctx: core.Context, tz: str) -> None: + async def timezone_info(self, ctx: beira.Context, tz: str) -> None: try: zone = ZoneInfo(tz) except ZoneInfoNotFoundError: @@ -155,7 +155,7 @@ async def timezone_info(self, ctx: core.Context, tz: str) -> None: @timezone_set.autocomplete("tz") @timezone_info.autocomplete("tz") async def timezone_autocomplete( - self, itx: core.Interaction, current: str + self, itx: beira.Interaction, current: str ) -> list[discord.app_commands.Choice[str]]: if not current: return [ @@ -169,7 +169,7 @@ async def timezone_autocomplete( # TODO: Complete and enable later. -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" # await bot.add_cog(TimingCog(bot)) # noqa: ERA001 diff --git a/exts/todo.py b/src/beira/exts/todo.py similarity index 92% rename from exts/todo.py rename to src/beira/exts/todo.py index 1e77e8e..0772bf7 100644 --- a/exts/todo.py +++ b/src/beira/exts/todo.py @@ -13,9 +13,8 @@ import msgspec from discord.ext import commands -import core -from core.utils import OwnedView, PaginatedEmbedView -from core.utils.db import Connection_alias, Pool_alias +import beira +from beira.utils import Connection_alias, OwnedView, PaginatedEmbedView, Pool_alias LOGGER = logging.getLogger(__name__) @@ -171,7 +170,17 @@ async def on_submit(self, interaction: discord.Interaction, /) -> None: self.interaction = interaction -class TodoCompleteButton(discord.ui.Button["TodoViewABC"]): +class TodoViewABC(ABC, OwnedView): + """An ABC to define a common interface for views with to-do buttons.""" + + todo_item: TodoItem + + @abstractmethod + async def update_todo(self, interaction: discord.Interaction[Any], updated_item: TodoItem) -> None: + raise NotImplementedError + + +class TodoCompleteButton(discord.ui.Button[TodoViewABC]): """A Discord button that marks to-do items in the parent view as (in)complete, and changes visually as a result. Interacts with kwargs for default styling on initialization. @@ -194,7 +203,7 @@ def __init__(self, completed_at: datetime.datetime | None = None, **kwargs: Any) kwargs["label"] = kwargs.get("label", "Mark as incomplete") super().__init__(**kwargs) - async def callback(self, interaction: core.Interaction) -> None: # type: ignore # Necessary narrowing + async def callback(self, interaction: beira.Interaction) -> None: # type: ignore # Necessary narrowing """Changes the button's look, and updates the parent view and its to-do item's completion status.""" assert self.view is not None @@ -217,7 +226,7 @@ async def callback(self, interaction: core.Interaction) -> None: # type: ignore await interaction.followup.send(f"Todo task marked as {completion_status}!", ephemeral=True) -class TodoEditButton(discord.ui.Button["TodoViewABC"]): +class TodoEditButton(discord.ui.Button[TodoViewABC]): """A Discord button sends modals for editing the content of the parent view's to-do item. Interacts with kwargs for default styling on initialization. @@ -233,7 +242,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs["label"] = kwargs.get("label", "Edit") super().__init__(**kwargs) - async def callback(self, interaction: core.Interaction) -> None: # type: ignore # Necessary narrowing + async def callback(self, interaction: beira.Interaction) -> None: # type: ignore # Necessary narrowing """Uses a modal to get the (edited) content for a to-do item, then updates the item and parent view.""" assert self.view is not None @@ -257,7 +266,7 @@ async def callback(self, interaction: core.Interaction) -> None: # type: ignore await modal.interaction.response.send_message("No changes made to the todo item.", ephemeral=True) -class TodoDeleteButton(discord.ui.Button["TodoViewABC"]): +class TodoDeleteButton(discord.ui.Button[TodoViewABC]): """A Discord button that allows users to delete a to-do item. Interacts with kwargs for default styling on initialization. @@ -273,7 +282,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs["label"] = kwargs.get("label", "Delete") super().__init__(**kwargs) - async def callback(self, interaction: core.Interaction) -> None: # type: ignore # Necessary narrowing + async def callback(self, interaction: beira.Interaction) -> None: # type: ignore # Necessary narrowing """Deletes the to-do item, and updates the parent view accordingly.""" assert self.view is not None @@ -283,16 +292,6 @@ async def callback(self, interaction: core.Interaction) -> None: # type: ignore await interaction.followup.send("Todo task deleted!", ephemeral=True) -class TodoViewABC(ABC, OwnedView): - """An ABC to define a common interface for views with to-do buttons.""" - - todo_item: TodoItem - - @abstractmethod - async def update_todo(self, interaction: discord.Interaction[Any], updated_item: TodoItem) -> None: - raise NotImplementedError - - class TodoView(TodoViewABC): """A Discord view for interacting with a single to-do item. @@ -332,12 +331,12 @@ async def on_timeout(self) -> None: await self.message.edit(view=self) self.stop() - async def update_todo(self, interaction: core.Interaction, updated_item: TodoItem) -> None: + async def update_todo(self, interaction: beira.Interaction, updated_item: TodoItem) -> None: """Updates the state of the view, including the to-do item it holds, based on a passed in, new version of it. Parameters ---------- - interaction: `core.Interaction` + interaction: `beira.Interaction` The interaction that caused this state change. updated_record: `TodoItem` The new version of the to-do item for the view to display. @@ -410,12 +409,12 @@ def format_page(self) -> discord.Embed: self.todo_item = self.pages[self.page_index][0] return self.todo_item.display_embed() - async def update_todo(self, interaction: core.Interaction, updated_item: TodoItem) -> None: + async def update_todo(self, interaction: beira.Interaction, updated_item: TodoItem) -> None: """Updates the state of the view, including the to-do item currently in scope, based on a passed in item. Parameters ---------- - interaction: `core.Interaction` + interaction: `beira.Interaction` The interaction that caused this state change. updated_record: `TodoItem` The new version of the to-do item for the view to display. @@ -436,7 +435,7 @@ class TodoCog(commands.Cog, name="Todo"): Inspired by the to-do cogs of RoboDanny and Mipha. """ - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot @property @@ -445,7 +444,7 @@ def cog_emoji(self) -> discord.PartialEmoji: return discord.PartialEmoji(name="\N{SPIRAL NOTE PAD}") - async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: # type: ignore # Narrowing + async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing # Extract the original error. error = getattr(error, "original", error) if ctx.interaction: @@ -454,18 +453,18 @@ async def cog_command_error(self, ctx: core.Context, error: Exception) -> None: LOGGER.exception("", exc_info=error) @commands.hybrid_group() - async def todo(self, ctx: core.Context) -> None: + async def todo(self, ctx: beira.Context) -> None: """Commands to manage your to-do items.""" await ctx.send_help(ctx.command) @todo.command("add") - async def todo_add(self, ctx: core.Context, content: str) -> None: + async def todo_add(self, ctx: beira.Context, content: str) -> None: """Add an item to your to-do list. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. content: `str` The content of the to-do item. @@ -480,12 +479,12 @@ async def todo_add(self, ctx: core.Context, content: str) -> None: await ctx.send("Todo added!", ephemeral=True) @todo.command("delete") - async def todo_delete(self, ctx: core.Context, todo_id: int) -> None: + async def todo_delete(self, ctx: beira.Context, todo_id: int) -> None: """Remove a to-do item based on its id. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. todo_id: `int` The id of the task to do. @@ -496,19 +495,19 @@ async def todo_delete(self, ctx: core.Context, todo_id: int) -> None: await ctx.send(f"To-do item #{todo_id} has been removed.", ephemeral=True) @todo.command("clear") - async def todo_clear(self, ctx: core.Context) -> None: + async def todo_clear(self, ctx: beira.Context) -> None: """Clear all of your to-do items.""" await self.bot.db_pool.execute("DELETE FROM todos where user_id = $1;", ctx.author.id) await ctx.send("All of your todo items have been cleared.", ephemeral=True) @todo.command("show") - async def todo_show(self, ctx: core.Context, todo_id: int) -> None: + async def todo_show(self, ctx: beira.Context, todo_id: int) -> None: """Show information about a to-do item based on its id. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. todo_id: `int` The id of the task to do. @@ -524,12 +523,12 @@ async def todo_show(self, ctx: core.Context, todo_id: int) -> None: await ctx.send("Either this record doesn't exist, or you can't see it.") @todo.command("list") - async def todo_list(self, ctx: core.Context, complete: bool = False, pending: bool = True) -> None: + async def todo_list(self, ctx: beira.Context, complete: bool = False, pending: bool = True) -> None: """Show information about your to-do items. Parameters ---------- - ctx: `core.Context` + ctx: `beira.Context` The invocation context. complete: `bool`, default=False Whether to pull completed to-do items. Defaults to False. @@ -553,7 +552,7 @@ async def todo_list(self, ctx: core.Context, complete: bool = False, pending: bo @todo_show.autocomplete("todo_id") async def todo_id_autocomplete( self, - interaction: core.Interaction, + interaction: beira.Interaction, current: str, ) -> list[discord.app_commands.Choice[int]]: """Autocomplete for to-do items owned by the invoking user.""" @@ -569,7 +568,7 @@ async def todo_id_autocomplete( ][:25] -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connect cog to bot.""" await bot.add_cog(TodoCog(bot)) diff --git a/src/beira/exts/triggers/__init__.py b/src/beira/exts/triggers/__init__.py new file mode 100644 index 0000000..198baaf --- /dev/null +++ b/src/beira/exts/triggers/__init__.py @@ -0,0 +1,11 @@ +import beira + +from .misc_triggers import MiscTriggersCog +from .rss_notifications import RSSNotificationsCog + + +async def setup(bot: beira.Beira) -> None: + """Connects cogs to bot.""" + + await bot.add_cog(MiscTriggersCog(bot)) + await bot.add_cog(RSSNotificationsCog(bot)) diff --git a/src/beira/exts/triggers/misc_triggers.py b/src/beira/exts/triggers/misc_triggers.py new file mode 100644 index 0000000..b61fee0 --- /dev/null +++ b/src/beira/exts/triggers/misc_triggers.py @@ -0,0 +1,247 @@ +"""custom_notifications.py: One or more listeners for sending custom notifications based on events.""" + +import asyncio +import logging +import re + +import discord +import lxml.etree +import lxml.html +import msgspec +from discord.ext import commands + +import beira + + +LOGGER = logging.getLogger(__name__) + +type ValidGuildChannel = ( + discord.VoiceChannel | discord.StageChannel | discord.ForumChannel | discord.TextChannel | discord.CategoryChannel +) + + +HEADERS = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/104.0.0.0 Safari/537.36" + ) +} + +# The channel where deleted messages are logged. Would be 799077440139034654 in "production". +ACI_DELETE_LOG_CHANNEL = 975459460560605204 + +# A list of ids for Tatsu leveled roles to keep track of. +ACI_LEVELED_ROLES = { + 694616299476877382, + 694615984438509636, + 694615108323639377, + 694615102237835324, + 747520979735019572, +} + +# The mod role(s) to ping when sending notifications. +ACI_MOD_ROLE = 780904973004570654 + +PRIVATE_GUILD_WITH_9GAG_LINKS = 1097976528832307271 + +LEAKY_INSTAGRAM_LINK_PATTERN = re.compile(r"(instagram\.com/.*?)&igsh.*==") +LOSSY_TWITTER_LINK_PATTERN = re.compile(r"(?:http(?:s)?://|(? None: + self.bot = bot + + self.aci_guild_id = self.bot.config.discord.important_guilds["prod"][0] + + # The webhook url that will be used to send ACI-related notifications. + aci_webhook_url = self.bot.config.discord.webhooks[0] + self.role_log_webhook = discord.Webhook.from_url(aci_webhook_url, session=bot.web_session) + + @commands.Cog.listener("on_member_update") + async def on_server_boost_role_member_update(self, before: discord.Member, after: discord.Member) -> None: + """Listener that sends a notification if members of the ACI100 server earn certain roles. + + Condition for activating: + - Boost the server and earn the premium subscriber, or "Server Booster", role. + """ + + # Check if the update is in the right server, a member got new roles, and they got a new "Server Booster" role. + if ( + before.guild.id == self.aci_guild_id + and len(new_roles := set(after.roles).difference(before.roles)) > 0 + and after.guild.premium_subscriber_role in new_roles + ): + # Send a message notifying holders of some other role(s) about this new role acquisition. + content = f"<@&{ACI_MOD_ROLE}>, {after.mention} just boosted the server!" + await self.role_log_webhook.send(content) + + @commands.Cog.listener("on_member_update") + async def on_leveled_role_member_update(self, before: discord.Member, after: discord.Member) -> None: + """Listener that sends a notification if members of the ACI100 server earn certain roles. + + Condition for activating: + - Earn a Tatsu leveled role above "The Ears". + """ + + # Check if the update is in the right server, a member got new roles, and they got a relevant leveled role. + if ( + before.guild.id == self.aci_guild_id + and len(new_roles := set(after.roles).difference(before.roles)) > 0 + and (new_leveled_roles := tuple(role for role in new_roles if (role.id in ACI_LEVELED_ROLES))) + ): + # Ensure the user didn't just rejoin. + if after.joined_at is not None: + # Technically, at 8 points every two minutes, it's possible to hit the lowest relevant leveled role in + # 20h 50m on ACI, so 21 hours will be the limit. + recently_rejoined = (discord.utils.utcnow() - after.joined_at).total_seconds() < 75600 + else: + recently_rejoined = False + + if new_leveled_roles and not recently_rejoined: + # Send a message notifying holders of some other role(s) about this new role acquisition. + role_names = tuple(role.name for role in new_leveled_roles) + content = f"<@&{ACI_MOD_ROLE}>, {after.mention} was given the `{role_names}` role(s)." + await self.role_log_webhook.send(content) + + # @commands.Cog.listener("on_message") + async def on_bad_twitter_link(self, message: discord.Message) -> None: + if message.author == self.bot.user or (not message.guild) or message.guild.id != self.aci_guild_id: + return + + if not LOSSY_TWITTER_LINK_PATTERN.search(message.content): + return + + cleaned_content = re.sub(r"twitter\.com/(.+)", r"fxtwitter.com/\1", message.content) + new_content = ( + f"*Corrected Twitter link(s)*\n" + f"Reposted from {message.author.mention} ({message.author.name} - {message.author.id}):\n" + "————————\n" + "\n" + f"{cleaned_content}" + ) + + await message.reply(new_content, allowed_mentions=discord.AllowedMentions(users=False)) + + @commands.Cog.listener("on_message") + async def on_leaky_instagram_link(self, message: discord.Message) -> None: + if message.author == self.bot.user or (not message.guild) or message.guild.id != self.aci_guild_id: + return + + if not LEAKY_INSTAGRAM_LINK_PATTERN.search(message.content): + return + + cleaned_content = re.sub(LEAKY_INSTAGRAM_LINK_PATTERN, "\1", message.content) + new_content = ( + f"*Cleaned Instagram link(s)*\n" + f"Reposted from {message.author.mention} ({message.author.name} - {message.author.id}):\n" + "————————\n" + "\n" + f"{cleaned_content}" + ) + + if message.attachments: + send_msg = message.channel.send( + new_content, + allowed_mentions=discord.AllowedMentions(users=False), + files=[await atmt.to_file() for atmt in message.attachments], + ) + else: + send_msg = message.channel.send(new_content, allowed_mentions=discord.AllowedMentions(users=False)) + + await message.delete() + await send_msg + + @commands.Cog.listener("on_message") + async def on_bad_9gag_link(self, message: discord.Message) -> None: + if message.author == self.bot.user or (not message.guild) or message.guild.id != PRIVATE_GUILD_WITH_9GAG_LINKS: + return + + async def _get_9gag_page(link: str) -> bytes: + async with self.bot.web_session.get(link, headers=HEADERS) as response: + response.raise_for_status() + return await response.read() + + if links := LOSSY_9GAG_LINK_PATTERN.findall(message.content): + tasks = [asyncio.create_task(_get_9gag_page(link)) for link in links] + results = await asyncio.gather(*tasks, return_exceptions=True) + page_data = [page for page in results if not isinstance(page, BaseException)] + + mp4_urls: list[str] = [] + for page in page_data: + element = lxml.html.fromstring(page).find(".//script[@type='application/ld+json']") + if element is not None and element.text: + mp4_urls.append(msgspec.json.decode(element.text)["video"]["contentUrl"]) + + if mp4_urls: + fixed_urls = "\n".join(mp4_urls) + content = ( + f"*Corrected 9gag link(s)*\n" + f"Reposted from {message.author.mention} ({message.author.name} - {message.author.id}):\n" + "————————\n" + "\n" + f"{fixed_urls}" + ) + await message.reply(content, allowed_mentions=discord.AllowedMentions(users=False, replied_user=False)) + + # @commands.Cog.listener("on_message_delete") + async def on_any_message_delete(self, payload: discord.RawMessageDeleteEvent) -> None: + # TODO: Improve. + + # Only check in ACI100 server. + if payload.guild_id == self.aci_guild_id: + # Attempt to get the channel the message was sent in. + try: + channel = self.bot.get_channel(payload.channel_id) or await self.bot.fetch_channel(payload.channel_id) + except (discord.InvalidData, discord.HTTPException): + LOGGER.info("Could not find the channel of the deleted message: %s", payload) + return + assert isinstance(channel, ValidGuildChannel | discord.Thread) # Known if we reach this point. + + # Attempt to get the message itself. + message = payload.cached_message + if not message and not isinstance(channel, (discord.ForumChannel, discord.CategoryChannel)): + try: + message = await channel.fetch_message(payload.message_id) + except discord.HTTPException: + LOGGER.info("Could not find the deleted message: %s", payload) + return + assert message is not None # Known if we reach this point. + + # Create a log embed to represent the deleted message. + extra_attachments: list[str] = [] + embed = ( + discord.Embed( + colour=discord.Colour.dark_blue(), + description=( + f"**Message sent by {message.author.mention} - Deleted in <#{payload.channel_id}>**" + f"\n{message.content}" + ), + timestamp=discord.utils.utcnow(), + ) + .set_author(name=str(message.author), icon_url=message.author.display_avatar.url) + .set_footer(text=f"Author: {message.author.id} | Message ID: {payload.message_id}") + .add_field(name="Sent at:", value=discord.utils.format_dt(message.created_at, style="F"), inline=False) + ) + + # Put attachments in the one log message or in another. + if len(message.attachments) == 1: + if message.attachments[0].content_type in {"gif", "jpg", "png", "webp", "webm", "mp4"}: + embed.set_image(url=message.attachments[0].url) + else: + embed.add_field(name="Attachment", value="See below.") + extra_attachments.append(message.attachments[0].url) + elif len(message.attachments) > 1: + embed.add_field(name="Attachments", value="See below.") + extra_attachments.extend(att.url for att in message.attachments) + + # Send the log message(s). + delete_log_channel = self.bot.get_channel(ACI_DELETE_LOG_CHANNEL) + assert isinstance(delete_log_channel, discord.TextChannel) # Known at runtime. + + await delete_log_channel.send(embed=embed) + if extra_attachments: + content = "\n".join(extra_attachments) + await delete_log_channel.send(content) diff --git a/exts/notifications/rss_notifications.py b/src/beira/exts/triggers/rss_notifications.py similarity index 76% rename from exts/notifications/rss_notifications.py rename to src/beira/exts/triggers/rss_notifications.py index 72a6e0f..1268c53 100644 --- a/exts/notifications/rss_notifications.py +++ b/src/beira/exts/triggers/rss_notifications.py @@ -7,7 +7,7 @@ import msgspec from discord.ext import commands, tasks -import core +import beira class NotificationRecord(msgspec.Struct): @@ -34,10 +34,9 @@ class RSSNotificationsCog(commands.Cog): ); """ - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot - if False: # FIXME: Remove once this works. - self.notification_check_loop.start() + # self.notification_check_loop.start() async def cog_unload(self) -> None: self.notification_check_loop.cancel() @@ -53,12 +52,15 @@ def process_new_item(self, text: str) -> discord.Embed: async def notification_check_loop(self) -> None: """Continuously check urls for updates and send notifications to webhooks accordingly.""" - notif_tasks: list[asyncio.Task[str | None]] = [asyncio.create_task(self.check_url(rec)) for rec in self.records] - results: list[str | None] = await asyncio.gather(*notif_tasks) - to_update = ((result, rec) for result, rec in zip(results, self.records, strict=True) if result is not None) - for result, rec in to_update: - embed = self.process_new_item(result) - await rec.webhook.send(embed=embed) + notif_tasks = [asyncio.create_task(self.check_url(rec)) for rec in self.records] + notif_results = await asyncio.gather(*notif_tasks) + + send_tasks = [ + asyncio.create_task(record.webhook.send(embed=self.process_new_item(result))) + for result, record in zip(notif_results, self.records, strict=True) + if result is not None + ] + await asyncio.gather(*send_tasks) @notification_check_loop.before_loop async def notification_check_loop_before(self) -> None: diff --git a/exts/webhook_logging.py b/src/beira/exts/webhook_logging.py similarity index 79% rename from exts/webhook_logging.py rename to src/beira/exts/webhook_logging.py index 45d10c3..441d9e6 100644 --- a/exts/webhook_logging.py +++ b/src/beira/exts/webhook_logging.py @@ -1,15 +1,17 @@ import discord from discord.ext import commands, tasks -import core +import beira class LoggingCog(commands.Cog): - def __init__(self, bot: core.Beira) -> None: + def __init__(self, bot: beira.Beira) -> None: self.bot = bot - self.webhook = discord.Webhook.from_url(core.CONFIG.discord.logging_webhook, client=bot) + + self.webhook = discord.Webhook.from_url(self.bot.config.discord.logging_webhook, client=bot) self.username = "Beira Logging" self.avatar_url = "https://cdn.dribbble.com/users/1065420/screenshots/3751686/gwen-taking-notes.gif" + self.webhook_logging_loop.start() async def cog_unload(self) -> None: @@ -23,7 +25,7 @@ async def webhook_logging_loop(self) -> None: await self.webhook.send(username=self.username, avatar_url=self.avatar_url, embed=log_embed) -async def setup(bot: core.Beira) -> None: +async def setup(bot: beira.Beira) -> None: """Connects cog to bot.""" await bot.add_cog(LoggingCog(bot)) diff --git a/core/tree.py b/src/beira/tree.py similarity index 100% rename from core/tree.py rename to src/beira/tree.py diff --git a/core/utils/__init__.py b/src/beira/utils/__init__.py similarity index 100% rename from core/utils/__init__.py rename to src/beira/utils/__init__.py diff --git a/core/utils/db.py b/src/beira/utils/db.py similarity index 91% rename from core/utils/db.py rename to src/beira/utils/db.py index 975f5af..c46710f 100644 --- a/core/utils/db.py +++ b/src/beira/utils/db.py @@ -1,4 +1,4 @@ -"""db.py: Utility functions for interacting with the database.""" +"""Utilities for interacting with the database.""" from typing import TYPE_CHECKING diff --git a/core/utils/embeds.py b/src/beira/utils/embeds.py similarity index 93% rename from core/utils/embeds.py rename to src/beira/utils/embeds.py index 82abbaf..387663d 100644 --- a/core/utils/embeds.py +++ b/src/beira/utils/embeds.py @@ -1,6 +1,4 @@ -"""embeds.py: This class provides embeds for user-specific statistics separated into fields.""" - -from __future__ import annotations +"""Embed-related helpers, e.g. a class for displaying user-specific statistics separated into fields.""" import itertools import logging @@ -19,9 +17,8 @@ class StatsEmbed(discord.Embed): - """A subclass of `DTEmbed` that displays given statistics for a user. - - This has a default colour of 0x2f3136 and a default timestamp for right now in UTC. + """A subclass of `discord.Embed` that displays given statistics for a user, with a default colour of 0x2f3136 and a + default timestamp of now in UTC. Parameters ---------- diff --git a/core/utils/emojis.py b/src/beira/utils/emojis.py similarity index 97% rename from core/utils/emojis.py rename to src/beira/utils/emojis.py index 9f272a9..9372290 100644 --- a/core/utils/emojis.py +++ b/src/beira/utils/emojis.py @@ -1,8 +1,12 @@ +"""Emoji-related helpers and shortcuts.""" + __all__ = ( - "EMOJI_STOCK", "EMOJI_URL", + "EMOJI_STOCK", ) +EMOJI_URL = "https://cdn.discordapp.com/emojis/{0}.webp?size=128&quality=lossless" + # fmt: off EMOJI_STOCK: dict[str, str] = { "blue_star": "<:snow_StarB_2021:917859752057376779>", @@ -31,5 +35,3 @@ "ao3": "<:ao3:1229883149136433325>", } # fmt: on - -EMOJI_URL = "https://cdn.discordapp.com/emojis/{0}.webp?size=128&quality=lossless" diff --git a/src/beira/utils/extras/__init__.py b/src/beira/utils/extras/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/beira/utils/extras/formats.py b/src/beira/utils/extras/formats.py new file mode 100644 index 0000000..404d26a --- /dev/null +++ b/src/beira/utils/extras/formats.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from collections.abc import Sequence + + +class plural: + def __init__(self, value: int) -> None: + self.value: int = value + + def __format__(self, format_spec: str) -> str: + v = self.value + singular, _, plural = format_spec.partition("|") + plural = plural or f"{singular}s" + if abs(v) != 1: + return f"{v} {plural}" + return f"{v} {singular}" + + +def human_join(seq: Sequence[str], delim: str = ", ", final: str = "or") -> str: + size = len(seq) + if size == 0: + return "" + if size == 1: + return seq[0] + if size == 2: + return f"{seq[0]} {final} {seq[1]}" + + return delim.join(seq[:-1]) + f" {final} {seq[-1]}" diff --git a/src/beira/utils/extras/scheduler.py b/src/beira/utils/extras/scheduler.py new file mode 100644 index 0000000..eba12aa --- /dev/null +++ b/src/beira/utils/extras/scheduler.py @@ -0,0 +1,578 @@ +# region License +# Vendored from https://github.com/mikeshardmind/discord-scheduler with some modifications to accommodate a different +# backend. See the license below: +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +# +# Copyright (C) 2023 Michael Hall +# endregion + +import asyncio +from datetime import datetime, timedelta +from itertools import count +from types import TracebackType +from typing import Protocol, Self +from uuid import uuid4 +from warnings import warn +from zoneinfo import ZoneInfo + +import asyncpg +from msgspec import Struct, field +from msgspec.json import decode as json_decode, encode as json_encode + +from ..db import Connection_alias # noqa: TID252 + + +class BotLike(Protocol): + def dispatch(self, event_name: str, /, *args: object, **kwargs: object) -> None: ... + + async def wait_until_ready(self) -> None: ... + + +__all__ = ("DiscordBotScheduler", "ScheduledDispatch", "Scheduler") + +SQLROW_TYPE = tuple[str, str, str, str, int | None, int | None, bytes | None] +DATE_FMT = r"%Y-%m-%d %H:%M" + +_c = count() + +INITIALIZATION_STATEMENTS = """ +CREATE TABLE IF NOT EXISTS scheduled_dispatches ( + task_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + dispatch_name TEXT NOT NULL, + dispatch_time TIMESTAMP WITH TIME ZONE NOT NULL, + dispatch_zone TEXT NOT NULL, + associated_guild BIGINT, + associated_user BIGINT, + dispatch_extra JSONB +); +""" + +ZONE_SELECTION_STATEMENT = """ +SELECT DISTINCT dispatch_zone FROM scheduled_dispatches; +""" + +UNSCHEDULE_BY_UUID_STATEMENT = """ +DELETE FROM scheduled_dispatches WHERE task_id = $1; +""" + +UNSCHEDULE_ALL_BY_GUILD_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE associated_guild IS NOT NULL AND associated_guild = $1; +""" + +UNSCHEDULE_ALL_BY_USER_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE associated_user IS NOT NULL AND associated_user = $1; +""" + +UNSCHEDULE_ALL_BY_MEMBER_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE + associated_guild IS NOT NULL + AND associated_user IS NOT NULL + AND associated_guild = $1 + AND associated_user = $2 +; +""" + +UNSCHEDULE_ALL_BY_DISPATCH_NAME_STATEMENT = """ +DELETE FROM scheduled_dispatches WHERE dispatch_name = $1; +""" + +UNSCHEDULE_ALL_BY_NAME_AND_USER_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_user IS NOT NULL + AND associated_user = $2; +""" + +UNSCHEDULE_ALL_BY_NAME_AND_GUILD_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_guild IS NOT NULL + AND associated_guild = $2; +""" + +UNSCHEDULE_ALL_BY_NAME_AND_MEMBER_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_guild IS NOT NULL + AND associated_user IS NOT NULL + AND associated_guild = $2 + AND associated_user = $3 +; +""" + +SELECT_ALL_BY_NAME_STATEMENT = """ +SELECT * FROM scheduled_dispatches WHERE dispatch_name = $1; +""" + +SELECT_ALL_BY_NAME_AND_GUILD_STATEMET = """ +SELECT * FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_guild IS NOT NULL + AND associated_guild = $2; +""" + +SELECT_ALL_BY_NAME_AND_USER_STATEMENT = """ +SELECT * FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_user IS NOT NULL + AND associated_user = $2; +""" + +SELECT_ALL_BY_NAME_AND_MEMBER_STATEMENT = """ +SELECT * FROM scheduled_dispatches +WHERE + dispatch_name = $1 + AND associated_guild IS NOT NULL + AND associated_user IS NOT NULL + AND associated_guild = $2 + AND associated_user = $3 +; +""" + +INSERT_SCHEDULE_STATEMENT = """ +INSERT INTO scheduled_dispatches +(task_id, dispatch_name, dispatch_time, dispatch_zone, associated_guild, associated_user, dispatch_extra) +VALUES ($1, $2, $3, $4, $5, $6, $7); +""" + +DELETE_RETURNING_UPCOMING_IN_ZONE_STATEMENT = """ +DELETE FROM scheduled_dispatches +WHERE dispatch_time < $1 AND dispatch_zone = $2 +RETURNING *; +""" + + +class ScheduledDispatch(Struct, frozen=True, gc=False): + task_id: str + dispatch_name: str + dispatch_time: str + dispatch_zone: str + associated_guild: int | None + associated_user: int | None + dispatch_extra: bytes | None + _count: int = field(default_factory=lambda: next(_c)) + + def __eq__(self, other: object) -> bool: + return self is other + + def __lt__(self, other: object) -> bool: + if isinstance(other, type(self)): + return (self.get_arrow_time(), self._count) < (other.get_arrow_time(), other._count) + return False + + def __gt__(self, other: object) -> bool: + if isinstance(other, type(self)): + return (self.get_arrow_time(), self._count) > (other.get_arrow_time(), other._count) + return False + + @classmethod + def from_pg_row(cls: type[Self], row: asyncpg.Record) -> Self: + tid, name, time, zone, guild, user, extra_bytes = row + return cls(tid, name, time, zone, guild, user, extra_bytes) + + @classmethod + def from_exposed_api( + cls: type[Self], + *, + name: str, + time: str, + zone: str, + guild: int | None, + user: int | None, + extra: object | None, + ) -> Self: + packed: bytes | None = None + if extra is not None: + f = json_encode(extra) + packed = f + return cls(uuid4().hex, name, time, zone, guild, user, packed) + + def to_pg_row(self) -> SQLROW_TYPE: + return ( + self.task_id, + self.dispatch_name, + self.dispatch_time, + self.dispatch_zone, + self.associated_guild, + self.associated_user, + self.dispatch_extra, + ) + + def get_arrow_time(self) -> datetime: + return datetime.strptime(self.dispatch_time, DATE_FMT).replace(tzinfo=ZoneInfo(self.dispatch_zone)) + + def unpack_extra(self) -> object | None: + if self.dispatch_extra: + return json_decode(self.dispatch_extra, strict=True) + return None + + +async def _setup_db(conn: Connection_alias) -> set[str]: + async with conn.transaction(): + await conn.execute(INITIALIZATION_STATEMENTS) + return {row["dispatch_zone"] for row in await conn.fetch(ZONE_SELECTION_STATEMENT)} + + +async def _get_scheduled(conn: Connection_alias, granularity: int, zones: set[str]) -> list[ScheduledDispatch]: + ret: list[ScheduledDispatch] = [] + if not zones: + return ret + + cutoff = datetime.now(ZoneInfo("UTC")) + timedelta(minutes=granularity) + async with conn.transaction(): + for zone in zones: + local_time = cutoff.astimezone(ZoneInfo(zone)).strftime(DATE_FMT) + rows = await conn.fetch(DELETE_RETURNING_UPCOMING_IN_ZONE_STATEMENT, local_time, zone) + ret.extend(map(ScheduledDispatch.from_pg_row, rows)) + + return ret + + +async def _schedule( + conn: Connection_alias, + *, + dispatch_name: str, + dispatch_time: str, + dispatch_zone: str, + guild_id: int | None, + user_id: int | None, + dispatch_extra: object | None, +) -> str: + # do this here, so if it fails, it fails at scheduling + _time = datetime.strptime(dispatch_time, DATE_FMT).replace(tzinfo=ZoneInfo(dispatch_zone)) + obj = ScheduledDispatch.from_exposed_api( + name=dispatch_name, + time=dispatch_time, + zone=dispatch_zone, + guild=guild_id, + user=user_id, + extra=dispatch_extra, + ) + + async with conn.transaction(): + await conn.execute(INSERT_SCHEDULE_STATEMENT, *obj.to_pg_row()) + + return obj.task_id + + +async def _query(conn: Connection_alias, query_str: str, params: tuple[int | str, ...]) -> list[ScheduledDispatch]: + return [ScheduledDispatch.from_pg_row(row) for row in await conn.fetch(query_str, *params)] + + +async def _drop(conn: Connection_alias, query_str: str, params: tuple[int | str, ...]) -> None: + async with conn.transaction(): + await conn.execute(query_str, *params) + + +class Scheduler: + def __init__(self, db_pool: asyncpg.Pool[asyncpg.Record], granularity: int = 1): + if granularity < 1: + msg = "Granularity must be a positive iteger number of minutes" + raise ValueError(msg) + asyncio.get_running_loop() + self.granularity = granularity + self._pool = db_pool + self._zones: set[str] = set() # We don't re-narrow this anywhere currently, only expand it. + self._queue: asyncio.PriorityQueue[ScheduledDispatch] = asyncio.PriorityQueue() + self._ready = False + self._closing = False + self._lock = asyncio.Lock() + self._loop_task: asyncio.Task[None] | None = None + self._discord_task: asyncio.Task[None] | None = None + + def stop(self) -> None: + if self._loop_task is None: + msg = "Contextmanager, use it" + raise RuntimeError(msg) + self._loop_task.cancel() + if self._discord_task: + self._discord_task.cancel() + + async def _loop(self) -> None: + # not currently modifiable once running + # differing granularities here, + a delay on retrieving in .get_next() + # ensures closest + sleep_gran = self.granularity * 25 + while (not self._closing) and await asyncio.sleep(sleep_gran, self._ready): + # Lock needed to ensure that once the db is dropping rows + # that a graceful shutdown doesn't drain the queue until entries are in it. + async with self._lock: + # check on both ends of the await that we aren't closing + if self._closing: + return + async with self._pool.acquire() as conn: + scheduled = await _get_scheduled(conn, self.granularity, self._zones) + for s in scheduled: + await self._queue.put(s) + + async def __aexit__(self, exc_type: type[BaseException], exc_value: BaseException, traceback: TracebackType): + if not self._closing: + msg = "Exiting without use of stop_gracefully may cause loss of tasks" + warn(msg, stacklevel=2) + self.stop() + + async def get_next(self) -> ScheduledDispatch: + """ + gets the next scheduled event, waiting if neccessary. + """ + + try: + dispatch = await self._queue.get() + now = datetime.now(ZoneInfo("UTC")) + scheduled_for = dispatch.get_arrow_time() + if now < scheduled_for: + delay = (now - scheduled_for).total_seconds() + await asyncio.sleep(delay) + return dispatch + finally: + self._queue.task_done() + + async def stop_gracefully(self) -> None: + """Notify the internal scheduling loop to stop scheduling and wait for the internal queue to be empty""" + + self._closing = True + # don't remove lock, see note in _loop + async with self._lock: + await self._queue.join() + + async def __aenter__(self) -> Self: + async with self._pool.acquire() as conn: + self._zones = await _setup_db(conn) + self._ready = True + self._loop_task = asyncio.create_task(self._loop()) + self._loop_task.add_done_callback(lambda f: f.exception() if not f.cancelled() else None) + return self + + async def schedule_event( + self, + *, + dispatch_name: str, + dispatch_time: str, + dispatch_zone: str, + guild_id: int | None = None, + user_id: int | None = None, + dispatch_extra: object | None = None, + ) -> str: + """ + Schedule something to be emitted later. + + Parameters + ---------- + dispatch_name: str + The event name to dispatch under. + You may drop all events dispatching to the same name + (such as when removing a feature built ontop of this) + dispatch_time: str + A time string matching the format "%Y-%m-%d %H:%M" (eg. "2023-01-23 13:15") + dispatch_zone: str + The name of the zone for the event. + - Use `UTC` for absolute things scheduled by machines for machines + - Use the name of the zone (eg. US/Eastern) for things scheduled by + humans for machines to do for humans later + + guild_id: int | None + Optionally, an associated guild_id. + This can be used with dispatch_name as a means of querying events + or to drop all scheduled events for a guild. + user_id: int | None + Optionally, an associated user_id. + This can be used with dispatch_name as a means of querying events + or to drop all scheduled events for a user. + dispatch_extra: object | None + Optionally, Extra data to attach to dispatch. + This may be any object serializable by msgspec.msgpack.encode + where the result is round-trip decodable with + msgspec.msgpack.decode(..., strict=True) + + Returns + ------- + str + A uuid for the task, used for unique cancelation. + """ + + self._zones.add(dispatch_zone) + async with self._pool.acquire() as conn: + return await _schedule( + conn, + dispatch_name=dispatch_name, + dispatch_time=dispatch_time, + dispatch_zone=dispatch_zone, + guild_id=guild_id, + user_id=user_id, + dispatch_extra=dispatch_extra, + ) + + async def unschedule_uuid(self, uuid: str) -> None: + """ + Unschedule something by uuid. + This may miss things which should run within the next interval as defined by `granularity` + Non-existent uuids are silently handled. + """ + + async with self._pool.acquire() as conn: + await _drop(conn, UNSCHEDULE_BY_UUID_STATEMENT, (uuid,)) + + async def drop_user_schedule(self, user_id: int) -> None: + """ + Drop all scheduled events for a user (by user_id) + + Intended use case: + removing everything associated to a user who asks for data removal, doesn't exist anymore, or is blacklisted + """ + + async with self._pool.acquire() as conn: + await _drop(conn, UNSCHEDULE_ALL_BY_USER_STATEMENT, (user_id,)) + + async def drop_event_for_user(self, dispatch_name: str, user_id: int) -> None: + """ + Drop scheduled events dispatched to `dispatch_name` for user (by user_id) + + Intended use case example: + A reminder system allowing a user to unschedule all reminders + without effecting how other extensions might use this. + """ + + async with self._pool.acquire() as conn: + await _drop(conn, UNSCHEDULE_ALL_BY_NAME_AND_USER_STATEMENT, (dispatch_name, user_id)) + + async def drop_guild_schedule(self, guild_id: int) -> None: + """ + Drop all scheduled events for a guild (by guild_id) + + Intended use case: + clearing sccheduled events for a guild when leaving it. + """ + + async with self._pool.acquire() as conn: + await _drop(conn, UNSCHEDULE_ALL_BY_GUILD_STATEMENT, (guild_id,)) + + async def drop_event_for_guild(self, dispatch_name: str, guild_id: int) -> None: + """ + Drop scheduled events dispatched to `dispatch_name` for guild (by guild_id) + + Intended use case example: + An admin command allowing clearing all scheduled messages for a guild. + """ + + async with self._pool.acquire() as conn: + await _drop(conn, UNSCHEDULE_ALL_BY_NAME_AND_GUILD_STATEMENT, (dispatch_name, guild_id)) + + async def drop_member_schedule(self, guild_id: int, user_id: int) -> None: + """ + Drop all scheduled events for a guild (by guild_id, user_id) + + Intended use case: + clearing sccheduled events for a member that leaves a guild + """ + + async with self._pool.acquire() as conn: + await _drop(conn, UNSCHEDULE_ALL_BY_MEMBER_STATEMENT, (guild_id, user_id)) + + async def drop_event_for_member(self, dispatch_name: str, guild_id: int, user_id: int) -> None: + """ + Drop scheduled events dispatched to `dispatch_name` for member (by guild_id, user_id) + + Intended use case example: + see user example, but in a guild + """ + + async with self._pool.acquire() as conn: + await _drop(conn, UNSCHEDULE_ALL_BY_NAME_AND_MEMBER_STATEMENT, (dispatch_name, guild_id, user_id)) + + async def list_event_schedule_for_user(self, dispatch_name: str, user_id: int) -> list[ScheduledDispatch]: + """ + list the events of a specified name scheduled for a user (by user_id) + """ + + async with self._pool.acquire() as conn: + return await _query(conn, SELECT_ALL_BY_NAME_AND_USER_STATEMENT, (dispatch_name, user_id)) + + async def list_event_schedule_for_member( + self, + dispatch_name: str, + guild_id: int, + user_id: int, + ) -> list[ScheduledDispatch]: + """ + list the events of a specified name scheduled for a member (by guild_id, user_id) + """ + + async with self._pool.acquire() as conn: + return await _query(conn, SELECT_ALL_BY_NAME_AND_MEMBER_STATEMENT, (dispatch_name, guild_id, user_id)) + + async def list_event_schedule_for_guild(self, dispatch_name: str, guild_id: int) -> list[ScheduledDispatch]: + """ + list the events of a specified name scheduled for a guild (by guild_id) + """ + + async with self._pool.acquire() as conn: + return await _query(conn, SELECT_ALL_BY_NAME_AND_USER_STATEMENT, (dispatch_name, guild_id)) + + @staticmethod + def time_str_from_params(year: int, month: int, day: int, hour: int, minute: int) -> str: + """ + A quick helper for people working with other time representations + (if you have a datetime object, just use strftime with "%Y-%m-%d %H:%M") + """ + + return datetime(year, month, day, hour, minute, tzinfo=ZoneInfo("UTC")).strftime(DATE_FMT) + + +class DiscordBotScheduler(Scheduler): + """Scheduler with convienence dispatches compatible with discord.py's commands extenstion + + Note: long-term compatability not guaranteed, dispatch isn't covered by discord.py's version guarantees. + """ + + async def _bot_dispatch_loop(self, bot: BotLike, wait_until_ready: bool) -> None: + if not self._ready: + msg = "context manager, use it" + raise RuntimeError(msg) + + if wait_until_ready: + await bot.wait_until_ready() + + while scheduled := await self.get_next(): + bot.dispatch(f"sinbad_scheduler_{scheduled.dispatch_name}", scheduled) + + def start_dispatch_to_bot(self, bot: BotLike, *, wait_until_ready: bool = True) -> None: + """ + Starts dispatching events to the bot. + + Events will dispatch under a name with the following format: + + sinbad_scheduler_{dispatch_name} + + where dispatch_name is set when submitting events to schedule. + This is done to avoid potential conflicts with existing or future event names, + as well as anyone else building a scheduler on top of bot.dispatch + (hence author name inclusion) and someone deciding to use both. + + Listeners get a single object as their argument, `ScheduledDispatch` + + to listen for an event you submit with `reminder` as the name + + @commands.Cog.listener("on_sinbad_scheduler_reminder") + async def some_listener(self, scheduled_object: ScheduledDispatch): + ... + + Events will not start being sent until the bot is considered ready if `wait_until_ready` is True + """ + + if not self._ready: + msg = "context manager, use it" + raise RuntimeError(msg) + + self._discord_task = asyncio.create_task(self._bot_dispatch_loop(bot, wait_until_ready)) + self._discord_task.add_done_callback(lambda f: f.exception() if not f.cancelled() else None) diff --git a/src/beira/utils/extras/time.py b/src/beira/utils/extras/time.py new file mode 100644 index 0000000..a184e77 --- /dev/null +++ b/src/beira/utils/extras/time.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +import datetime +import re +from typing import Any, Self +from zoneinfo import ZoneInfo + +import discord +import parsedatetime as pdt +from dateutil.relativedelta import relativedelta +from discord.ext import commands + +import beira + +from ... import Context # noqa: TID252 +from .formats import human_join, plural + + +# Monkey patch mins and secs into the units +units = pdt.pdtLocales["en_US"].units +units["minutes"].append("mins") +units["seconds"].append("secs") + + +UTC = ZoneInfo("UTC") + + +class ShortTime: + COMPILED = re.compile( + """ + (?:(?P[0-9])(?:years?|y))? # e.g. 2y + (?:(?P[0-9]{1,2})(?:months?|mon?))? # e.g. 2months + (?:(?P[0-9]{1,4})(?:weeks?|w))? # e.g. 10w + (?:(?P[0-9]{1,5})(?:days?|d))? # e.g. 14d + (?:(?P[0-9]{1,5})(?:hours?|hr?s?))? # e.g. 12h + (?:(?P[0-9]{1,5})(?:minutes?|m(?:ins?)?))? # e.g. 10m + (?:(?P[0-9]{1,5})(?:seconds?|s(?:ecs?)?))? # e.g. 15s + """, + re.VERBOSE, + ) + + DISCORD_FMT = re.compile(r"[0-9]+)(?:\:?[RFfDdTt])?>") + + def __init__( + self, + argument: str, + *, + now: datetime.datetime | None = None, + tzinfo: datetime.tzinfo = UTC, + ): + match = self.COMPILED.fullmatch(argument) + if match is None or not match.group(0): + match = self.DISCORD_FMT.fullmatch(argument) + if match is not None: + self.dt = datetime.datetime.fromtimestamp(int(match.group("ts")), tz=UTC) + + if tzinfo not in {datetime.UTC, UTC}: + self.dt = self.dt.astimezone(tzinfo) + return + + msg = "invalid time provided" + raise commands.BadArgument(msg) + + data = {k: int(v) for k, v in match.groupdict(default=0).items()} + now = now or datetime.datetime.now(UTC) + self.dt = now + relativedelta(**data) # type: ignore # None of the regex groups currently fill the date fields. + if tzinfo not in {datetime.UTC, UTC}: + self.dt = self.dt.astimezone(tzinfo) + + @classmethod + async def convert(cls, ctx: Context, argument: str) -> Self: + tzinfo = await ctx.bot.get_user_tzinfo(ctx.author.id) + return cls(argument, now=ctx.message.created_at, tzinfo=tzinfo) + + +class RelativeDelta(discord.app_commands.Transformer, commands.Converter[relativedelta]): + @classmethod + def __do_conversion(cls, argument: str) -> relativedelta: + match = ShortTime.COMPILED.fullmatch(argument) + if match is None or not match.group(0): + msg = "invalid time provided" + raise ValueError(msg) + + data = {k: int(v) for k, v in match.groupdict(default=0).items()} + return relativedelta(**data) # type: ignore # None of the regex groups currently fill the date fields. + + async def convert(self, ctx: Context, argument: str) -> relativedelta: # type: ignore # Custom context. + try: + return self.__do_conversion(argument) + except ValueError as e: + raise commands.BadArgument(str(e)) from None + + async def transform(self, interaction: discord.Interaction, value: str) -> relativedelta: + try: + return self.__do_conversion(value) + except ValueError as e: + raise discord.app_commands.AppCommandError(str(e)) from None + + +class HumanTime: + calendar = pdt.Calendar(version=pdt.VERSION_CONTEXT_STYLE) + + def __init__( + self, + argument: str, + *, + now: datetime.datetime | None = None, + tzinfo: datetime.tzinfo = UTC, + ): + now = now or datetime.datetime.now(tzinfo) + dt, status = self.calendar.parseDT(argument, sourceTime=now, tzinfo=None) + + assert isinstance(status, pdt.pdtContext) + + if not status.hasDateOrTime: + msg = 'invalid time provided, try e.g. "tomorrow" or "3 days"' + raise commands.BadArgument(msg) + + if not status.hasTime: + # replace it with the current time + dt = dt.replace(hour=now.hour, minute=now.minute, second=now.second, microsecond=now.microsecond) + + self.dt: datetime.datetime = dt.replace(tzinfo=tzinfo) + if now.tzinfo is None: + now = now.replace(tzinfo=UTC) + self._past: bool = self.dt < now + + @classmethod + async def convert(cls, ctx: Context, argument: str) -> Self: + tzinfo = await ctx.bot.get_user_tzinfo(ctx.author.id) + return cls(argument, now=ctx.message.created_at, tzinfo=tzinfo) + + +class Time(HumanTime): + def __init__( + self, + argument: str, + *, + now: datetime.datetime | None = None, + tzinfo: datetime.tzinfo = UTC, + ): + try: + o = ShortTime(argument, now=now, tzinfo=tzinfo) + except Exception: # noqa: BLE001 + super().__init__(argument, now=now, tzinfo=tzinfo) + else: + self.dt = o.dt + self._past = False + + +class FutureTime(Time): + def __init__( + self, + argument: str, + *, + now: datetime.datetime | None = None, + tzinfo: datetime.tzinfo = UTC, + ): + super().__init__(argument, now=now, tzinfo=tzinfo) + + if self._past: + msg = "this time is in the past" + raise commands.BadArgument(msg) + + +class BadTimeTransform(discord.app_commands.AppCommandError): + pass + + +class TimeTransformer(discord.app_commands.Transformer): + async def transform(self, interaction: discord.Interaction[beira.Beira], value: str) -> datetime.datetime: + tzinfo = await interaction.client.get_user_tzinfo(interaction.user.id) + + now = interaction.created_at.astimezone(tzinfo) + try: + short = ShortTime(value, now=now, tzinfo=tzinfo) + except commands.BadArgument: + try: + human = FutureTime(value, now=now, tzinfo=tzinfo) + except commands.BadArgument as e: + raise BadTimeTransform(str(e)) from None + else: + return human.dt + else: + return short.dt + + +class FriendlyTimeResult: + __slots__ = ("dt", "arg") + + def __init__(self, dt: datetime.datetime): + self.dt: datetime.datetime = dt + self.arg: str = "" + + async def ensure_constraints( + self, + ctx: Context, + uft: UserFriendlyTime, + now: datetime.datetime, + remaining: str, + ) -> None: + if self.dt < now: + msg = "This time is in the past." + raise commands.BadArgument(msg) + + if not remaining: + if uft.default is None: + msg = "Missing argument after the time." + raise commands.BadArgument(msg) + remaining = uft.default + + if uft.converter is not None: + self.arg = await uft.converter.convert(ctx, remaining) + else: + self.arg = remaining + + +class UserFriendlyTime(commands.Converter[FriendlyTimeResult]): + """That way quotes aren't absolutely necessary.""" + + def __init__( + self, + converter: type[commands.Converter[str]] | commands.Converter[str] | None = None, + *, + default: Any = None, + ): + if issubclass(converter, commands.Converter): # type: ignore [reportUnnecessaryIsInstance] + converter = converter() + + if converter is not None and not isinstance(converter, commands.Converter): # type: ignore [reportUnnecessaryIsInstance] + msg = "commands.Converter subclass necessary." + raise TypeError(msg) + + self.converter: commands.Converter[str] | None = converter + self.default: Any = default + + async def convert(self, ctx: Context, argument: str) -> FriendlyTimeResult: # type: ignore # Custom context. # noqa: PLR0915 + calendar = HumanTime.calendar + regex = ShortTime.COMPILED + now = ctx.message.created_at + + tzinfo = await ctx.bot.get_user_tzinfo(ctx.author.id) + + assert isinstance(tzinfo, datetime.tzinfo) + + match = regex.match(argument) + if match is not None and match.group(0): + data = {k: int(v) for k, v in match.groupdict(default=0).items()} + remaining = argument[match.end() :].strip() + dt = now + relativedelta(**data) # type: ignore # None of the regex groups currently fill the date fields. + result = FriendlyTimeResult(dt.astimezone(tzinfo)) + await result.ensure_constraints(ctx, self, now, remaining) + return result + + if match is None or not match.group(0): + match = ShortTime.DISCORD_FMT.match(argument) + if match is not None: + result = FriendlyTimeResult( + datetime.datetime.fromtimestamp(int(match.group("ts")), tz=UTC).astimezone(tzinfo) + ) + remaining = argument[match.end() :].strip() + await result.ensure_constraints(ctx, self, now, remaining) + return result + + # apparently nlp does not like "from now" + # it likes "from x" in other cases though so let me handle the 'now' case + if argument.endswith("from now"): + argument = argument[:-8].strip() + + if argument[0:2] == "me" and argument[0:6] in ("me to ", "me in ", "me at "): + argument = argument[6:] + + # Have to adjust the timezone so pdt knows how to handle things like "tomorrow at 6pm" in an aware way + now = now.astimezone(tzinfo) + elements = calendar.nlp(argument, sourceTime=now) + if elements is None or len(elements) == 0: + msg = 'Invalid time provided, try e.g. "tomorrow" or "3 days".' + raise commands.BadArgument(msg) + + # handle the following cases: + # "date time" foo + # date time foo + # foo date time + + # first the first two cases: + dt, status, begin, end, _dt_string = elements[0] + assert isinstance(status, pdt.pdtContext) + + if not status.hasDateOrTime: + msg = 'Invalid time provided, try e.g. "tomorrow" or "3 days".' + raise commands.BadArgument(msg) + + if begin not in (0, 1) and end != len(argument): + msg = ( + "Time is either in an inappropriate location, which " + "must be either at the end or beginning of your input, " + "or I just flat out did not understand what you meant. Sorry." + ) + raise commands.BadArgument(msg) + + dt = dt.replace(tzinfo=tzinfo) + if not status.hasTime: + # replace it with the current time + dt = dt.replace(hour=now.hour, minute=now.minute, second=now.second, microsecond=now.microsecond) + + if status.hasTime and not status.hasDate and dt < now: + # if it's in the past, and it has a time but no date, + # assume it's for the next occurrence of that time + dt = dt + datetime.timedelta(days=1) + + # if midnight is provided, just default to next day + if status.accuracy == pdt.pdtContext.ACU_HALFDAY: + dt = dt + datetime.timedelta(days=1) + + result = FriendlyTimeResult(dt) + remaining = "" + + if begin in (0, 1): + if begin == 1: + # check if it's quoted: + if argument[0] != '"': + msg = "Expected quote before time input..." + raise commands.BadArgument(msg) + + if not (end < len(argument) and argument[end] == '"'): + msg = "If the time is quoted, you must unquote it." + raise commands.BadArgument(msg) + + remaining = argument[end + 1 :].lstrip(" ,.!") + else: + remaining = argument[end:].lstrip(" ,.!") + elif len(argument) == end: + remaining = argument[:begin].strip() + + await result.ensure_constraints(ctx, self, now, remaining) + return result + + +def human_timedelta( + dt: datetime.datetime, + *, + source: datetime.datetime | None = None, + accuracy: int | None = 3, + brief: bool = False, + suffix: bool = True, +) -> str: + now = source or datetime.datetime.now(UTC) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + + if now.tzinfo is None: + now = now.replace(tzinfo=UTC) + + # Microsecond free zone + now = now.replace(microsecond=0) + dt = dt.replace(microsecond=0) + + # Make sure they're both in the timezone + now = now.astimezone(UTC) + dt = dt.astimezone(UTC) + + # This implementation uses relativedelta instead of the much more obvious + # divmod approach with seconds because the seconds approach is not entirely + # accurate once you go over 1 week in terms of accuracy since you have to + # hardcode a month as 30 or 31 days. + # A query like "11 months" can be interpreted as "!1 months and 6 days" + if dt > now: + delta = relativedelta(dt, now) + output_suffix = "" + else: + delta = relativedelta(now, dt) + output_suffix = " ago" if suffix else "" + + attrs = [ + ("year", "y"), + ("month", "mo"), + ("day", "d"), + ("hour", "h"), + ("minute", "m"), + ("second", "s"), + ] + + output: list[str] = [] + for attr, brief_attr in attrs: + elem = getattr(delta, attr + "s") + if not elem: + continue + + if attr == "day": + weeks = delta.weeks + if weeks: + elem -= weeks * 7 + if not brief: + output.append(format(plural(weeks), "week")) + else: + output.append(f"{weeks}w") + + if elem <= 0: + continue + + if brief: + output.append(f"{elem}{brief_attr}") + else: + output.append(format(plural(elem), attr)) + + if accuracy is not None: + output = output[:accuracy] + + if len(output) == 0: + return "now" + + if not brief: + return human_join(output, final="and") + output_suffix + + return " ".join(output) + output_suffix + + +def format_relative(dt: datetime.datetime) -> str: + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + return discord.utils.format_dt(dt, "R") diff --git a/core/utils/log.py b/src/beira/utils/log.py similarity index 66% rename from core/utils/log.py rename to src/beira/utils/log.py index 146296a..17733ed 100644 --- a/core/utils/log.py +++ b/src/beira/utils/log.py @@ -1,14 +1,11 @@ -"""custom_logging.py: Based on the work of Umbra, this is Beira's logging system. +"""Beira's logging setup. -References ----------- -https://github.com/AbstractUmbra/Mipha/blob/main/bot.py#L91 +Based on Umbra's work: https://github.com/AbstractUmbra/Mipha/blob/main/bot.py#L91 """ import asyncio -import copy import logging -from logging.handlers import RotatingFileHandler +from logging.handlers import QueueHandler, RotatingFileHandler from pathlib import Path from typing import Self @@ -18,42 +15,8 @@ __all__ = ("LoggingManager",) -class AsyncQueueHandler(logging.Handler): - # Copied api and implementation of stdlib QueueHandler. - def __init__(self, queue: asyncio.Queue[logging.LogRecord]) -> None: - logging.Handler.__init__(self) - self.queue = queue - - def enqueue(self, record: logging.LogRecord) -> None: - self.queue.put_nowait(record) - - def prepare(self, record: logging.LogRecord) -> logging.LogRecord: - msg = self.format(record) - record = copy.copy(record) - record.message = msg - record.msg = msg - record.args = None - record.exc_info = None - record.exc_text = None - record.stack_info = None - return record - - def emit(self, record: logging.LogRecord) -> None: - try: - self.enqueue(self.prepare(record)) - except Exception: # noqa: BLE001 - self.handleError(record) - - class RemoveNoise(logging.Filter): - """Filter for custom logging system. - - Copied from Umbra. - - References - ---------- - https://github.com/AbstractUmbra/Mipha/blob/main/bot.py#L91 - """ + """Filter for discord.state warnings about "referencing an unknown".""" def __init__(self) -> None: super().__init__(name="discord.state") @@ -62,12 +25,9 @@ def filter(self, record: logging.LogRecord) -> bool: return not (record.levelname == "WARNING" and "referencing an unknown" in record.msg) -# TODO: Personalize logging beyond Umbra's work. class LoggingManager: """Custom logging system. - Copied from Umbra with minimal customization so far: https://github.com/AbstractUmbra/Mipha/blob/main/bot.py#L109 - Parameters ---------- stream: `bool`, default=True @@ -83,6 +43,8 @@ class LoggingManager: A path to the directory for all log files. stream: `bool` A boolean indicating whether the logs should be output to a stream. + log_queue: `asyncio.Queue[logging.LogRecord]` + An asyncio queue with logs to send to a logging webhook. """ def __init__(self, *, stream: bool = True) -> None: @@ -126,7 +88,7 @@ def __enter__(self) -> Self: self.log.addHandler(stream_handler) # Add a queue handler. - queue_handler = AsyncQueueHandler(self.log_queue) + queue_handler = QueueHandler(self.log_queue) self.log.addHandler(queue_handler) return self diff --git a/core/utils/misc.py b/src/beira/utils/misc.py similarity index 91% rename from core/utils/misc.py rename to src/beira/utils/misc.py index abe0ac2..fe9d393 100644 --- a/core/utils/misc.py +++ b/src/beira/utils/misc.py @@ -1,4 +1,4 @@ -"""misc.py: Miscellaneous utility functions that might come in handy.""" +"""Miscellaneous utility functions that might come in handy.""" import logging import re @@ -25,13 +25,13 @@ def __init__(self, logger: logging.Logger | None = None): self.logger = logger def __enter__(self): - self.total_time = time.perf_counter() + self.elapsed = time.perf_counter() return self def __exit__(self, *exc_info: object) -> None: - self.total_time = time.perf_counter() - self.total_time + self.elapsed = time.perf_counter() - self.elapsed if self.logger: - self.logger.info("Time: %.3f seconds", self.total_time) + self.logger.info("Time: %.3f seconds", self.elapsed) _BEFORE_WS = re.compile(r"^([\s]+)") diff --git a/core/utils/pagination.py b/src/beira/utils/pagination.py similarity index 96% rename from core/utils/pagination.py rename to src/beira/utils/pagination.py index c45c6a2..1bc005d 100644 --- a/core/utils/pagination.py +++ b/src/beira/utils/pagination.py @@ -1,12 +1,7 @@ -""" -pagination.py: A collection of views that together create a view that uses embeds, is paginated, and allows -easy navigation. -""" - -from __future__ import annotations +"""A collection of views that together create a view that uses embeds, is paginated, and allows easy navigation.""" +import abc import asyncio -from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any, Self @@ -96,7 +91,7 @@ class PageSeekModal(discord.ui.Modal, title="Page Jump"): required=True, ) - def __init__(self, *, parent: PaginatedEmbedView[Any], **kwargs: Any) -> None: + def __init__(self, *, parent: "PaginatedEmbedView[Any]", **kwargs: Any) -> None: super().__init__(**kwargs) self.parent = parent self.interaction: discord.Interaction | None = None @@ -115,7 +110,7 @@ async def on_submit(self, interaction: discord.Interaction, /) -> None: self.stop() -class PaginatedEmbedView[_LT](ABC, OwnedView): +class PaginatedEmbedView[_LT](abc.ABC, OwnedView): """A view that handles paginated embeds and page buttons. Parameters @@ -171,7 +166,7 @@ async def on_timeout(self) -> None: await self.message.edit(view=self) self.stop() - @abstractmethod + @abc.abstractmethod def format_page(self) -> discord.Embed: """|maybecoro| @@ -291,7 +286,7 @@ async def turn_to_last(self, interaction: discord.Interaction, _: discord.ui.But await self.update_page(interaction) -class PaginatedSelectView[_LT](ABC, OwnedView): +class PaginatedSelectView[_LT](abc.ABC, OwnedView): """A view that handles paginated embeds and page buttons. Parameters @@ -340,7 +335,7 @@ async def on_timeout(self) -> None: await self.message.edit(view=self) self.stop() - @abstractmethod + @abc.abstractmethod def format_page(self) -> discord.Embed: """|maybecoro| @@ -350,7 +345,7 @@ def format_page(self) -> discord.Embed: msg = "Page formatting must be set up in a subclass." raise NotImplementedError(msg) - @abstractmethod + @abc.abstractmethod def populate_select(self) -> None: """Populates the select with relevant options."""