Push remaining changes.
Sachaa-Thanasius committed Jul 24, 2024
1 parent 606b41b commit 65a9176
Showing 7 changed files with 92 additions and 114 deletions.
89 changes: 36 additions & 53 deletions src/beira/bot.py
@@ -24,7 +24,7 @@
from .exts import EXTENSIONS
from .scheduler import Scheduler
from .tree import HookableTree
from .utils import LoggingManager, Pool_alias, conn_init, copy_annotations
from .utils import LoggingManager, Pool_alias, catchtime, conn_init, copy_annotations


LOGGER = logging.getLogger(__name__)
@@ -117,7 +117,7 @@ def __init__(
self.ao3_client = ao3.Client(session=self.web_session)

# Things to load before connecting to the Gateway.
self.prefix_cache: dict[int, list[str]] = {}
self.prefixes: dict[int, list[str]] = {}
self.blocked_guilds: set[int] = set()
self.blocked_users: set[int] = set()

@@ -140,10 +140,39 @@ async def on_ready(self) -> None:
LOGGER.info("Logged in as %s (ID: %s)", self.user, self.user.id)

async def setup_hook(self) -> None:
# Start up the scheduler.
self.scheduler.start_discord_dispatch(self)
await self._load_guild_prefixes()
await self._load_blocked_entities()
await self._load_extensions()

# Load guild prefixes.
prefix_records = await self.db_pool.fetch("SELECT guild_id, prefix FROM guild_prefixes;")
for entry in prefix_records:
self.prefixes.setdefault(entry["guild_id"], []).append(entry["prefix"])

LOGGER.info("Loaded all guild prefixes.")

# Load all blocked users and guilds.
blocked_user_records = await self.db_pool.fetch("SELECT user_id FROM users WHERE is_blocked;")
self.blocked_users |= {record["user_id"] for record in blocked_user_records}

blocked_guild_records = await self.db_pool.fetch("SELECT guild_id FROM guilds WHERE is_blocked;")
self.blocked_guilds |= {record["guild_id"] for record in blocked_guild_records}

# Load extensions/cogs. If a list of initial ones isn't provided, all extensions are loaded by default.
await self.load_extension("jishaku")

exts_to_load = self.initial_extensions or EXTENSIONS
with catchtime() as all_exts_time:
for extension in exts_to_load:
start_time = time.perf_counter()
try:
await self.load_extension(extension)
except commands.ExtensionError as err:
LOGGER.exception("Failed to load extension: %s", extension, exc_info=err)
else:
end_time = time.perf_counter()
LOGGER.info("Loaded extension: %s -- Time: %.5f", extension, end_time - start_time)

LOGGER.info("Total extension loading time: Time: %.5f", all_exts_time.time)

# Connect to lavalink node(s).
node = wavelink.Node(
@@ -165,7 +194,7 @@ async def close(self) -> None:
await super().close()

async def get_prefix(self, message: discord.Message, /) -> list[str] | str:
return self.prefix_cache.get(message.guild.id, "$") if message.guild else "$"
return self.prefixes.get(message.guild.id, "$") if message.guild else "$"

@overload
async def get_context(self, origin: discord.Message | discord.Interaction, /) -> Context: ...
@@ -251,57 +280,11 @@ async def on_command_error(self, context: Context, exception: commands.CommandEr
embed.add_field(name="Channel", value=f"{context.channel}", inline=False)
LOGGER.error("Exception in command %s", context.command, exc_info=exception, extra={"embed": embed})

async def _load_blocked_entities(self) -> None:
"""Load all blocked users and guilds from the bot database."""

user_records = await self.db_pool.fetch("SELECT user_id FROM users WHERE is_blocked;")
self.blocked_users.update([record["user_id"] for record in user_records])

guild_records = await self.db_pool.fetch("SELECT guild_id FROM guilds WHERE is_blocked;")
self.blocked_guilds.update([record["guild_id"] for record in guild_records])

async def _load_guild_prefixes(self) -> None:
"""Load all prefixes from the bot database."""

try:
db_prefixes = await self.db_pool.fetch("SELECT guild_id, prefix FROM guild_prefixes;")
except OSError:
LOGGER.exception("Couldn't load guild prefixes from the database. Using default(s) instead.")
else:
for entry in db_prefixes:
self.prefix_cache.setdefault(entry["guild_id"], []).append(entry["prefix"])

LOGGER.info("(Re)loaded all guild prefixes.")

async def _load_extensions(self) -> None:
"""Loads extensions/cogs.
If a list of initial ones isn't provided, all extensions are loaded by default.
"""

await self.load_extension("jishaku")

exts_to_load = self.initial_extensions or EXTENSIONS
all_exts_start_time = time.perf_counter()
for extension in exts_to_load:
start_time = time.perf_counter()
try:
await self.load_extension(extension)
except commands.ExtensionError as err:
LOGGER.exception("Failed to load extension: %s", extension, exc_info=err)
else:
end_time = time.perf_counter()
LOGGER.info("Loaded extension: %s -- Time: %.5f", extension, end_time - start_time)
all_exts_end_time = time.perf_counter()
LOGGER.info("Total extension loading time: Time: %.5f", all_exts_end_time - all_exts_start_time)

async def _load_special_friends(self) -> None:
await self.wait_until_ready()

friends_ids: list[int] = self.config.discord.friend_ids
for user_id in friends_ids:
if user_obj := self.get_user(user_id):
self.special_friends[user_obj.name] = user_id
self.special_friends |= {user.name: user_id for user_id in friends_ids if (user := self.get_user(user_id))}

@async_lru.alru_cache()
async def get_user_timezone(self, user_id: int) -> str | None:
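A minimal sketch (not part of this commit) of the per-guild prefix cache that setup_hook now builds and get_prefix reads; the guild IDs and prefixes below are invented for illustration.

prefixes: dict[int, list[str]] = {}

# setup_hook: one database row per (guild_id, prefix); the first row creates the list.
prefixes.setdefault(1234, []).append("!")
prefixes.setdefault(1234, []).append("?")

# get_prefix: known guilds get their list, everything else falls back to "$".
assert prefixes.get(1234, "$") == ["!", "?"]
assert prefixes.get(5678, "$") == "$"
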
2 changes: 1 addition & 1 deletion src/beira/checks.py
@@ -117,7 +117,7 @@ async def predicate(ctx: Context) -> bool:
if not (await ctx.bot.is_owner(ctx.author)):
if ctx.author.id in ctx.bot.blocked_users:
raise UserIsBlocked
if ctx.guild and ctx.guild.id in ctx.bot.blocked_guilds:
if ctx.guild and (ctx.guild.id in ctx.bot.blocked_guilds):
raise GuildIsBlocked
return True

29 changes: 13 additions & 16 deletions src/beira/exts/_dev.py
@@ -2,7 +2,6 @@

import logging
from collections.abc import Generator
from time import perf_counter
from typing import Any, Literal

import discord
@@ -11,6 +10,7 @@
from discord.ext import commands

import beira
from beira.utils import catchtime

from . import EXTENSIONS

@@ -363,23 +363,22 @@ async def reload(self, ctx: beira.Context, extension: str) -> None:
reloaded: list[str] = []
failed: list[str] = []

start_time = perf_counter()
for ext in sorted(self.bot.extensions):
try:
await self.bot.reload_extension(ext)
except commands.ExtensionError as err:
failed.append(ext)
LOGGER.exception("Couldn't reload extension: %s", ext, exc_info=err)
else:
reloaded.append(ext)
end_time = perf_counter()
with catchtime() as reload_time:
for ext in sorted(self.bot.extensions):
try:
await self.bot.reload_extension(ext)
except commands.ExtensionError as err:
failed.append(ext)
LOGGER.exception("Couldn't reload extension: %s", ext, exc_info=err)
else:
reloaded.append(ext)

ratio_succeeded = f"{len(reloaded)}/{len(self.bot.extensions)}"
LOGGER.info("Attempted to reload all extensions. Successful: %s.", ratio_succeeded)

embed.add_field(name="Reloaded", value="\n".join(reloaded))
embed.add_field(name="Failed to reload", value="\n".join(failed))
embed.set_footer(text=f"Time taken: {end_time - start_time:.3f}s")
embed.set_footer(text=f"Time taken: {reload_time:.3f}s")

await ctx.send(embed=embed, ephemeral=True)

@@ -538,18 +537,16 @@ async def cmd_tree(self, ctx: beira.Context) -> None:
"""Display all bot commands in a pretty tree-like format."""

def walk_commands_with_indent(group: commands.GroupMixin[Any], indent_level: int = 0) -> Generator[str]:
indent_level += 4

for cmd in group.commands:
if indent_level > 4: # noqa: SIM108
if indent_level != 0: # noqa: SIM108
indent = (indent_level - 1) * "─"
else:
indent = ""

yield f"└{indent}{cmd.qualified_name}"

if isinstance(cmd, commands.GroupMixin):
yield from walk_commands_with_indent(cmd, indent_level)
yield from walk_commands_with_indent(cmd, indent_level + 4)

await ctx.send("\n".join(("```", "Beira", *walk_commands_with_indent(ctx.bot), "```")))

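A simplified sketch of the corrected recursion in walk_commands_with_indent, using a plain nested dict instead of the bot's real command groups (the command names are invented); with the increment moved into the recursive call, top-level commands get no indent and each nesting level adds four characters.

from collections.abc import Generator

def walk(tree: dict[str, dict], indent_level: int = 0) -> Generator[str, None, None]:
    for name, children in tree.items():
        # Top-level entries stay flush with the branch character.
        indent = (indent_level - 1) * "─" if indent_level != 0 else ""
        yield f"└{indent}{name}"
        if children:
            yield from walk(children, indent_level + 4)

print("\n".join(walk({"config": {"prefixes": {"add": {}, "remove": {}}}})))
# └config
# └───prefixes
# └───────add
# └───────remove
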
6 changes: 3 additions & 3 deletions src/beira/exts/admin.py
@@ -76,7 +76,7 @@ async def prefixes_add(self, ctx: beira.GuildContext, *, new_prefix: str) -> Non
await conn.execute(guild_stmt, ctx.guild.id)
await conn.execute(prefix_stmt, ctx.guild.id, new_prefix)
# Update it in the cache.
self.bot.prefix_cache.setdefault(ctx.guild.id, []).append(new_prefix)
self.bot.prefixes.setdefault(ctx.guild.id, []).append(new_prefix)

await ctx.send(f"'{new_prefix}' has been registered as a prefix in this guild.")

@@ -103,7 +103,7 @@ async def prefixes_remove(self, ctx: beira.GuildContext, *, old_prefix: str) ->
# Update it in the database and the cache.
prefix_stmt = "DELETE FROM guild_prefixes WHERE guild_id = $1 AND prefix = $2;"
await self.bot.db_pool.execute(prefix_stmt, ctx.guild.id, old_prefix)
self.bot.prefix_cache.setdefault(ctx.guild.id, [old_prefix]).remove(old_prefix)
self.bot.prefixes.setdefault(ctx.guild.id, [old_prefix]).remove(old_prefix)

await ctx.send(f"'{old_prefix}' has been unregistered as a prefix in this guild.")

@@ -117,7 +117,7 @@ async def prefixes_reset(self, ctx: beira.GuildContext) -> None:
# Update it in the database and the cache.
prefix_stmt = """DELETE FROM guild_prefixes WHERE guild_id = $1;"""
await self.bot.db_pool.execute(prefix_stmt, ctx.guild.id)
self.bot.prefix_cache.setdefault(ctx.guild.id, []).clear()
self.bot.prefixes.setdefault(ctx.guild.id, []).clear()

content = "The prefix(es) for this guild have been reset. Now only accepting the default prefix: `$`."
await ctx.send(content)
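A small illustrative sketch of the cache update in prefixes_remove: seeding setdefault with [old_prefix] means the following remove() cannot raise ValueError when the guild has no cached entry yet (guild ID and prefix made up).

cache: dict[int, list[str]] = {}
old_prefix = "!"

cache.setdefault(1234, [old_prefix]).remove(old_prefix)
assert cache[1234] == []  # the entry now exists and is empty
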
2 changes: 1 addition & 1 deletion src/beira/exts/lol.py
@@ -36,7 +36,7 @@ async def update_op_gg_profiles(urls: list[str]) -> None:

# Create the webdriver.
with GECKODRIVER_LOGS.open(mode="a", encoding="utf-8") as log_file:
service = services.Geckodriver(binary=str(GECKODRIVER), log_file=log_file) # type: ignore # attrs class
service = services.Geckodriver(binary=str(GECKODRIVER), log_file=log_file) # type: ignore # Untyped class
browser = browsers.Firefox(**{"moz:firefoxOptions": {"args": ["-headless"]}})

async with get_session(service, browser) as session:
30 changes: 14 additions & 16 deletions src/beira/exts/other.py
@@ -7,7 +7,6 @@
import random
import re
import tempfile
import time
from io import BytesIO, StringIO

import discord
@@ -16,6 +15,7 @@
from discord.ext import commands

import beira
from beira.utils import catchtime


INSPIROBOT_API_URL = "https://inspirobot.me/api"
@@ -205,29 +205,27 @@ async def quote(self, ctx: beira.Context, *, message: discord.Message) -> None:
async def ping_(self, ctx: beira.Context) -> None:
"""Display the time necessary for the bot to communicate with Discord."""

ws_ping = self.bot.latency * 1000
ws_ping = self.bot.latency

start_time = time.perf_counter()
await ctx.typing()
typing_ping = (time.perf_counter() - start_time) * 1000
with catchtime() as typing_ping:
await ctx.typing()

start_time = time.perf_counter()
await self.bot.db_pool.fetch("SELECT * FROM guilds;")
db_ping = (time.perf_counter() - start_time) * 1000
with catchtime() as db_ping:
await self.bot.db_pool.fetch("SELECT * FROM guilds;")

start_time = time.perf_counter()
message = await ctx.send(embed=discord.Embed(title="Ping..."))
msg_ping = (time.perf_counter() - start_time) * 1000
with catchtime() as msg_ping:
message = await ctx.send(embed=discord.Embed(title="Ping..."))

total_time = sum((ws_ping, *(catch.time for catch in (typing_ping, db_ping, msg_ping))))
pong_embed = (
discord.Embed(title="Pong! \N{TABLE TENNIS PADDLE AND BALL}")
.add_field(name="Websocket", value=f"```json\n{ws_ping:.2f} ms\n```")
.add_field(name="Typing", value=f"```json\n{typing_ping:.2f} ms\n```")
.add_field(name="Websocket", value=f"```json\n{ws_ping * 1000:.2f} ms\n```")
.add_field(name="Typing", value=f"```json\n{typing_ping.time * 1000:.2f} ms\n```")
.add_field(name="\u200b", value="\u200b")
.add_field(name="Database", value=f"```json\n{db_ping:.2f} ms\n```")
.add_field(name="Message", value=f"```json\n{msg_ping:.2f} ms\n```")
.add_field(name="Database", value=f"```json\n{db_ping.time * 1000:.2f} ms\n```")
.add_field(name="Message", value=f"```json\n{msg_ping.time * 1000:.2f} ms\n```")
.add_field(name="\u200b", value="\u200b")
.add_field(name="Average", value=f"```json\n{(ws_ping + typing_ping + db_ping + msg_ping) / 4:.2f} ms\n```")
.add_field(name="Average", value=f"```json\n{total_time * 1000 / 4:.2f} ms\n```")
)

await message.edit(embed=pong_embed)
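Illustrative arithmetic (made-up numbers) for the reworked ping embed: every measurement is kept in seconds — self.bot.latency plus the three catchtime results — and only converted to milliseconds when formatting, so the average field is the sum of the seconds times 1000, divided by 4.

ws_ping, typing_s, db_s, msg_s = 0.052, 0.110, 0.004, 0.180  # seconds
total_time = sum((ws_ping, typing_s, db_s, msg_s))
print(f"{total_time * 1000 / 4:.2f} ms")  # 86.50 ms
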
48 changes: 24 additions & 24 deletions src/beira/utils/misc.py
@@ -24,16 +24,16 @@ class catchtime:

def __init__(self, logger: logging.Logger | None = None):
self.logger = logger
self.elapsed = 0.0
self.time = 0.0

def __enter__(self):
self.elapsed = time.perf_counter()
self.time = time.perf_counter()
return self

def __exit__(self, *exc_info: object) -> None:
self.elapsed = time.perf_counter() - self.elapsed
self.time = time.perf_counter() - self.time
if self.logger:
self.logger.info("Time: %.3f seconds", self.elapsed)
self.logger.info("Time: %.3f seconds", self.time)


_BEFORE_WS = re.compile(r"^([\s]+)")
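
A minimal usage sketch of the renamed attribute — catchtime now exposes the elapsed seconds as .time rather than .elapsed (the sleep and the printed values are illustrative only).

import time

from beira.utils import catchtime

with catchtime() as timer:
    time.sleep(0.25)

print(f"{timer.time:.3f} s")          # ~0.250 s
print(f"{timer.time * 1000:.2f} ms")  # ~250.00 ms, as formatted in the ping command
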
@@ -48,7 +48,7 @@ def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False
italics_marker: str = "_"

if base_url is not None:
node.make_links_absolute("".join(base_url.partition(".com/wiki/")[0:-1]), resolve_base_href=True)
node.make_links_absolute("".join(base_url.partition(".com/wiki/")[0:2]), resolve_base_href=True)

for child in node.iter():
if child.text:
Expand All @@ -59,27 +59,27 @@ def html_to_markdown(node: lxml.html.HtmlElement, *, include_spans: bool = False
else:
before_ws = after_ws = child_text = ""

if child.tag in {"i", "em"}:
text.append(f"{before_ws}{italics_marker}{child_text}{italics_marker}{after_ws}")
if italics_marker == "*": # type: ignore # Pyright bug?
italics_marker = "_"
elif child.tag in {"b", "strong"}:
if text and text[-1].endswith("*"):
text.append("\u200b")
text.append(f"{before_ws}**{child_text}**{after_ws}")
elif child.tag == "a":
# No markup for links
if base_url is None:
match child.tag:
case "i" | "em":
text.append(f"{before_ws}{italics_marker}{child_text}{italics_marker}{after_ws}")
italics_marker = "_" if italics_marker == "*" else "*"
case "b" | "strong":
if text and text[-1].endswith("*"):
text.append("\u200b")
text.append(f"{before_ws}**{child_text}**{after_ws}")
case "a" if base_url is None: # No markup for incomplete links
text.append(f"{before_ws}{child_text}{after_ws}")
else:
case "a":
text.append(f"{before_ws}[{child.text}]({child.attrib['href']}){after_ws}")
elif child.tag == "p":
text.append(f"\n{child_text}\n")
elif include_spans and child.tag == "span":
if len(child) > 1:
text.append(f"{html_to_markdown(child, include_spans=True)}")
else:
text.append(f"{before_ws}{child_text}{after_ws}")
case "p":
text.append(f"\n{child_text}\n")
case "span" if include_spans:
if len(child) > 1:
text.append(html_to_markdown(child, include_spans=True))
else:
text.append(f"{before_ws}{child_text}{after_ws}")
case _:
pass

if child.tail:
text.append(child.tail)
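An illustrative call (not from the commit) of the function the match statement now drives; lxml.html.fromstring builds the element tree that html_to_markdown walks, and the exact markdown returned depends on parts of the function not shown in this hunk. The import path is assumed from the file location.

import lxml.html

from beira.utils.misc import html_to_markdown  # path assumed from src/beira/utils/misc.py

node = lxml.html.fromstring("<p>See <b>bold</b>, <i>italic</i>, and <a href='https://example.com'>a link</a>.</p>")
print(html_to_markdown(node))
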
