From 65a9176d46d1c131abab8052e1462ab6975287e7 Mon Sep 17 00:00:00 2001 From: Sachaa-Thanasius <111999343+Sachaa-Thanasius@users.noreply.github.com> Date: Wed, 24 Jul 2024 13:57:51 -0400 Subject: [PATCH] Push remaining changes. --- src/beira/bot.py | 89 +++++++++++++++++------------------------ src/beira/checks.py | 2 +- src/beira/exts/_dev.py | 29 ++++++-------- src/beira/exts/admin.py | 6 +-- src/beira/exts/lol.py | 2 +- src/beira/exts/other.py | 30 +++++++------- src/beira/utils/misc.py | 48 +++++++++++----------- 7 files changed, 92 insertions(+), 114 deletions(-) diff --git a/src/beira/bot.py b/src/beira/bot.py index bbb55a8..198c811 100644 --- a/src/beira/bot.py +++ b/src/beira/bot.py @@ -24,7 +24,7 @@ from .exts import EXTENSIONS from .scheduler import Scheduler from .tree import HookableTree -from .utils import LoggingManager, Pool_alias, conn_init, copy_annotations +from .utils import LoggingManager, Pool_alias, catchtime, conn_init, copy_annotations LOGGER = logging.getLogger(__name__) @@ -117,7 +117,7 @@ def __init__( self.ao3_client = ao3.Client(session=self.web_session) # Things to load before connecting to the Gateway. - self.prefix_cache: dict[int, list[str]] = {} + self.prefixes: dict[int, list[str]] = {} self.blocked_guilds: set[int] = set() self.blocked_users: set[int] = set() @@ -140,10 +140,39 @@ async def on_ready(self) -> None: LOGGER.info("Logged in as %s (ID: %s)", self.user, self.user.id) async def setup_hook(self) -> None: + # Start up the scheduler. self.scheduler.start_discord_dispatch(self) - await self._load_guild_prefixes() - await self._load_blocked_entities() - await self._load_extensions() + + # Load guild prefixes. + prefix_records = await self.db_pool.fetch("SELECT guild_id, prefix FROM guild_prefixes;") + for entry in prefix_records: + self.prefixes.setdefault(entry["guild_id"], []).append(entry["prefix"]) + + LOGGER.info("Loaded all guild prefixes.") + + # Load all blocked users and guilds. + blocked_user_records = await self.db_pool.fetch("SELECT user_id FROM users WHERE is_blocked;") + self.blocked_users |= {record["user_id"] for record in blocked_user_records} + + blocked_guild_records = await self.db_pool.fetch("SELECT guild_id FROM guilds WHERE is_blocked;") + self.blocked_guilds |= {record["guild_id"] for record in blocked_guild_records} + + # Load extensions/cogs. If a list of initial ones isn't provided, all extensions are loaded by default. + await self.load_extension("jishaku") + + exts_to_load = self.initial_extensions or EXTENSIONS + with catchtime() as all_exts_time: + for extension in exts_to_load: + start_time = time.perf_counter() + try: + await self.load_extension(extension) + except commands.ExtensionError as err: + LOGGER.exception("Failed to load extension: %s", extension, exc_info=err) + else: + end_time = time.perf_counter() + LOGGER.info("Loaded extension: %s -- Time: %.5f", extension, end_time - start_time) + + LOGGER.info("Total extension loading time: Time: %.5f", all_exts_time.time) # Connect to lavalink node(s). node = wavelink.Node( @@ -165,7 +194,7 @@ async def close(self) -> None: await super().close() async def get_prefix(self, message: discord.Message, /) -> list[str] | str: - return self.prefix_cache.get(message.guild.id, "$") if message.guild else "$" + return self.prefixes.get(message.guild.id, "$") if message.guild else "$" @overload async def get_context(self, origin: discord.Message | discord.Interaction, /) -> Context: ... @@ -251,57 +280,11 @@ async def on_command_error(self, context: Context, exception: commands.CommandEr embed.add_field(name="Channel", value=f"{context.channel}", inline=False) LOGGER.error("Exception in command %s", context.command, exc_info=exception, extra={"embed": embed}) - async def _load_blocked_entities(self) -> None: - """Load all blocked users and guilds from the bot database.""" - - user_records = await self.db_pool.fetch("SELECT user_id FROM users WHERE is_blocked;") - self.blocked_users.update([record["user_id"] for record in user_records]) - - guild_records = await self.db_pool.fetch("SELECT guild_id FROM guilds WHERE is_blocked;") - self.blocked_guilds.update([record["guild_id"] for record in guild_records]) - - async def _load_guild_prefixes(self) -> None: - """Load all prefixes from the bot database.""" - - try: - db_prefixes = await self.db_pool.fetch("SELECT guild_id, prefix FROM guild_prefixes;") - except OSError: - LOGGER.exception("Couldn't load guild prefixes from the database. Using default(s) instead.") - else: - for entry in db_prefixes: - self.prefix_cache.setdefault(entry["guild_id"], []).append(entry["prefix"]) - - LOGGER.info("(Re)loaded all guild prefixes.") - - async def _load_extensions(self) -> None: - """Loads extensions/cogs. - - If a list of initial ones isn't provided, all extensions are loaded by default. - """ - - await self.load_extension("jishaku") - - exts_to_load = self.initial_extensions or EXTENSIONS - all_exts_start_time = time.perf_counter() - for extension in exts_to_load: - start_time = time.perf_counter() - try: - await self.load_extension(extension) - except commands.ExtensionError as err: - LOGGER.exception("Failed to load extension: %s", extension, exc_info=err) - else: - end_time = time.perf_counter() - LOGGER.info("Loaded extension: %s -- Time: %.5f", extension, end_time - start_time) - all_exts_end_time = time.perf_counter() - LOGGER.info("Total extension loading time: Time: %.5f", all_exts_end_time - all_exts_start_time) - async def _load_special_friends(self) -> None: await self.wait_until_ready() friends_ids: list[int] = self.config.discord.friend_ids - for user_id in friends_ids: - if user_obj := self.get_user(user_id): - self.special_friends[user_obj.name] = user_id + self.special_friends |= {user.name: user_id for user_id in friends_ids if (user := self.get_user(user_id))} @async_lru.alru_cache() async def get_user_timezone(self, user_id: int) -> str | None: diff --git a/src/beira/checks.py b/src/beira/checks.py index bd0240f..23b28aa 100644 --- a/src/beira/checks.py +++ b/src/beira/checks.py @@ -117,7 +117,7 @@ async def predicate(ctx: Context) -> bool: if not (await ctx.bot.is_owner(ctx.author)): if ctx.author.id in ctx.bot.blocked_users: raise UserIsBlocked - if ctx.guild and ctx.guild.id in ctx.bot.blocked_guilds: + if ctx.guild and (ctx.guild.id in ctx.bot.blocked_guilds): raise GuildIsBlocked return True diff --git a/src/beira/exts/_dev.py b/src/beira/exts/_dev.py index f745695..112bce9 100644 --- a/src/beira/exts/_dev.py +++ b/src/beira/exts/_dev.py @@ -2,7 +2,6 @@ import logging from collections.abc import Generator -from time import perf_counter from typing import Any, Literal import discord @@ -11,6 +10,7 @@ from discord.ext import commands import beira +from beira.utils import catchtime from . import EXTENSIONS @@ -363,23 +363,22 @@ async def reload(self, ctx: beira.Context, extension: str) -> None: reloaded: list[str] = [] failed: list[str] = [] - start_time = perf_counter() - for ext in sorted(self.bot.extensions): - try: - await self.bot.reload_extension(ext) - except commands.ExtensionError as err: - failed.append(ext) - LOGGER.exception("Couldn't reload extension: %s", ext, exc_info=err) - else: - reloaded.append(ext) - end_time = perf_counter() + with catchtime() as reload_time: + for ext in sorted(self.bot.extensions): + try: + await self.bot.reload_extension(ext) + except commands.ExtensionError as err: + failed.append(ext) + LOGGER.exception("Couldn't reload extension: %s", ext, exc_info=err) + else: + reloaded.append(ext) ratio_succeeded = f"{len(reloaded)}/{len(self.bot.extensions)}" LOGGER.info("Attempted to reload all extensions. Successful: %s.", ratio_succeeded) embed.add_field(name="Reloaded", value="\n".join(reloaded)) embed.add_field(name="Failed to reload", value="\n".join(failed)) - embed.set_footer(text=f"Time taken: {end_time - start_time:.3f}s") + embed.set_footer(text=f"Time taken: {reload_time:.3f}s") await ctx.send(embed=embed, ephemeral=True) @@ -538,10 +537,8 @@ async def cmd_tree(self, ctx: beira.Context) -> None: """Display all bot commands in a pretty tree-like format.""" def walk_commands_with_indent(group: commands.GroupMixin[Any], indent_level: int = 0) -> Generator[str]: - indent_level += 4 - for cmd in group.commands: - if indent_level > 4: # noqa: SIM108 + if indent_level != 0: # noqa: SIM108 indent = (indent_level - 1) * "─" else: indent = "" @@ -549,7 +546,7 @@ def walk_commands_with_indent(group: commands.GroupMixin[Any], indent_level: int yield f"└{indent}{cmd.qualified_name}" if isinstance(cmd, commands.GroupMixin): - yield from walk_commands_with_indent(cmd, indent_level) + yield from walk_commands_with_indent(cmd, indent_level + 4) await ctx.send("\n".join(("```", "Beira", *walk_commands_with_indent(ctx.bot), "```"))) diff --git a/src/beira/exts/admin.py b/src/beira/exts/admin.py index 9e5bf1b..cab6fd3 100644 --- a/src/beira/exts/admin.py +++ b/src/beira/exts/admin.py @@ -76,7 +76,7 @@ async def prefixes_add(self, ctx: beira.GuildContext, *, new_prefix: str) -> Non await conn.execute(guild_stmt, ctx.guild.id) await conn.execute(prefix_stmt, ctx.guild.id, new_prefix) # Update it in the cache. - self.bot.prefix_cache.setdefault(ctx.guild.id, []).append(new_prefix) + self.bot.prefixes.setdefault(ctx.guild.id, []).append(new_prefix) await ctx.send(f"'{new_prefix}' has been registered as a prefix in this guild.") @@ -103,7 +103,7 @@ async def prefixes_remove(self, ctx: beira.GuildContext, *, old_prefix: str) -> # Update it in the database and the cache. prefix_stmt = "DELETE FROM guild_prefixes WHERE guild_id = $1 AND prefix = $2;" await self.bot.db_pool.execute(prefix_stmt, ctx.guild.id, old_prefix) - self.bot.prefix_cache.setdefault(ctx.guild.id, [old_prefix]).remove(old_prefix) + self.bot.prefixes.setdefault(ctx.guild.id, [old_prefix]).remove(old_prefix) await ctx.send(f"'{old_prefix}' has been unregistered as a prefix in this guild.") @@ -117,7 +117,7 @@ async def prefixes_reset(self, ctx: beira.GuildContext) -> None: # Update it in the database and the cache. prefix_stmt = """DELETE FROM guild_prefixes WHERE guild_id = $1;""" await self.bot.db_pool.execute(prefix_stmt, ctx.guild.id) - self.bot.prefix_cache.setdefault(ctx.guild.id, []).clear() + self.bot.prefixes.setdefault(ctx.guild.id, []).clear() content = "The prefix(es) for this guild have been reset. Now only accepting the default prefix: `$`." await ctx.send(content) diff --git a/src/beira/exts/lol.py b/src/beira/exts/lol.py index 09f3484..182b2fc 100644 --- a/src/beira/exts/lol.py +++ b/src/beira/exts/lol.py @@ -36,7 +36,7 @@ async def update_op_gg_profiles(urls: list[str]) -> None: # Create the webdriver. with GECKODRIVER_LOGS.open(mode="a", encoding="utf-8") as log_file: - service = services.Geckodriver(binary=str(GECKODRIVER), log_file=log_file) # type: ignore # attrs class + service = services.Geckodriver(binary=str(GECKODRIVER), log_file=log_file) # type: ignore # Untyped class browser = browsers.Firefox(**{"moz:firefoxOptions": {"args": ["-headless"]}}) async with get_session(service, browser) as session: diff --git a/src/beira/exts/other.py b/src/beira/exts/other.py index a072580..58ad369 100644 --- a/src/beira/exts/other.py +++ b/src/beira/exts/other.py @@ -7,7 +7,6 @@ import random import re import tempfile -import time from io import BytesIO, StringIO import discord @@ -16,6 +15,7 @@ from discord.ext import commands import beira +from beira.utils import catchtime INSPIROBOT_API_URL = "https://inspirobot.me/api" @@ -205,29 +205,27 @@ async def quote(self, ctx: beira.Context, *, message: discord.Message) -> None: async def ping_(self, ctx: beira.Context) -> None: """Display the time necessary for the bot to communicate with Discord.""" - ws_ping = self.bot.latency * 1000 + ws_ping = self.bot.latency - start_time = time.perf_counter() - await ctx.typing() - typing_ping = (time.perf_counter() - start_time) * 1000 + with catchtime() as typing_ping: + await ctx.typing() - start_time = time.perf_counter() - await self.bot.db_pool.fetch("SELECT * FROM guilds;") - db_ping = (time.perf_counter() - start_time) * 1000 + with catchtime() as db_ping: + await self.bot.db_pool.fetch("SELECT * FROM guilds;") - start_time = time.perf_counter() - message = await ctx.send(embed=discord.Embed(title="Ping...")) - msg_ping = (time.perf_counter() - start_time) * 1000 + with catchtime() as msg_ping: + message = await ctx.send(embed=discord.Embed(title="Ping...")) + total_time = sum((ws_ping, *(catch.time for catch in (typing_ping, db_ping, msg_ping)))) pong_embed = ( discord.Embed(title="Pong! \N{TABLE TENNIS PADDLE AND BALL}") - .add_field(name="Websocket", value=f"```json\n{ws_ping:.2f} ms\n```") - .add_field(name="Typing", value=f"```json\n{typing_ping:.2f} ms\n```") + .add_field(name="Websocket", value=f"```json\n{ws_ping * 1000:.2f} ms\n```") + .add_field(name="Typing", value=f"```json\n{typing_ping.time * 1000:.2f} ms\n```") .add_field(name="\u200b", value="\u200b") - .add_field(name="Database", value=f"```json\n{db_ping:.2f} ms\n```") - .add_field(name="Message", value=f"```json\n{msg_ping:.2f} ms\n```") + .add_field(name="Database", value=f"```json\n{db_ping.time * 1000:.2f} ms\n```") + .add_field(name="Message", value=f"```json\n{msg_ping.time * 1000:.2f} ms\n```") .add_field(name="\u200b", value="\u200b") - .add_field(name="Average", value=f"```json\n{(ws_ping + typing_ping + db_ping + msg_ping) / 4:.2f} ms\n```") + .add_field(name="Average", value=f"```json\n{total_time * 1000 / 4:.2f} ms\n```") ) await message.edit(embed=pong_embed) diff --git a/src/beira/utils/misc.py b/src/beira/utils/misc.py index 824bc2e..45e1c78 100644 --- a/src/beira/utils/misc.py +++ b/src/beira/utils/misc.py @@ -24,16 +24,16 @@ class catchtime: def __init__(self, logger: logging.Logger | None = None): self.logger = logger - self.elapsed = 0.0 + self.time = 0.0 def __enter__(self): - self.elapsed = time.perf_counter() + self.time = time.perf_counter() return self def __exit__(self, *exc_info: object) -> None: - self.elapsed = time.perf_counter() - self.elapsed + self.time = time.perf_counter() - self.time if self.logger: - self.logger.info("Time: %.3f seconds", self.elapsed) + self.logger.info("Time: %.3f seconds", self.time) _BEFORE_WS = re.compile(r"^([\s]+)") @@ -48,7 +48,7 @@ def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False italics_marker: str = "_" if base_url is not None: - node.make_links_absolute("".join(base_url.partition(".com/wiki/")[0:-1]), resolve_base_href=True) + node.make_links_absolute("".join(base_url.partition(".com/wiki/")[0:2]), resolve_base_href=True) for child in node.iter(): if child.text: @@ -59,27 +59,27 @@ def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False else: before_ws = after_ws = child_text = "" - if child.tag in {"i", "em"}: - text.append(f"{before_ws}{italics_marker}{child_text}{italics_marker}{after_ws}") - if italics_marker == "*": # type: ignore # Pyright bug? - italics_marker = "_" - elif child.tag in {"b", "strong"}: - if text and text[-1].endswith("*"): - text.append("\u200b") - text.append(f"{before_ws}**{child_text}**{after_ws}") - elif child.tag == "a": - # No markup for links - if base_url is None: + match child.tag: + case "i" | "em": + text.append(f"{before_ws}{italics_marker}{child_text}{italics_marker}{after_ws}") + italics_marker = "_" if italics_marker == "*" else "*" + case "b" | "strong": + if text and text[-1].endswith("*"): + text.append("\u200b") + text.append(f"{before_ws}**{child_text}**{after_ws}") + case "a" if base_url is None: # No markup for incomplete links text.append(f"{before_ws}{child_text}{after_ws}") - else: + case "a": text.append(f"{before_ws}[{child.text}]({child.attrib['href']}){after_ws}") - elif child.tag == "p": - text.append(f"\n{child_text}\n") - elif include_spans and child.tag == "span": - if len(child) > 1: - text.append(f"{html_to_markdown(child, include_spans=True)}") - else: - text.append(f"{before_ws}{child_text}{after_ws}") + case "p": + text.append(f"\n{child_text}\n") + case "span" if include_spans: + if len(child) > 1: + text.append(html_to_markdown(child, include_spans=True)) + else: + text.append(f"{before_ws}{child_text}{after_ws}") + case _: + pass if child.tail: text.append(child.tail)