Skip to content

Commit

Permalink
Cut a bunch of dead code, combine ff_metadata again into one file, wo…
Browse files Browse the repository at this point in the history
…rk more on the scheduler
  • Loading branch information
Sachaa-Thanasius committed Jul 24, 2024
1 parent 21cdcfe commit 606b41b
Show file tree
Hide file tree
Showing 27 changed files with 1,470 additions and 1,816 deletions.
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ ignore = [
"ISC001",
"ISC002",
# == Project-specific ignores.
"S311", # No need for cryptographically secure number generation in this use case; it's just dice rolls.
# "PLR", # Allow complexity.
"PLR0912", # Allow more branches
"PLR0913", # Allow more parameters
"S311", # No need for cryptographically secure number generation in this use case; it's just dice rolls.

]
unfixable = [
"ERA", # Don't want erroneous deletion of comments.
Expand All @@ -100,6 +102,5 @@ reportCallInDefaultInitializer = "warning"
reportImportCycles = "warning"
reportPropertyTypeMismatch = "warning"
reportShadowedImports = "error"
reportUninitializedInstanceVariable = "warning"
# reportUninitializedInstanceVariable = "warning"
reportUnnecessaryTypeIgnoreComment = "warning"

16 changes: 6 additions & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
aiohttp
ao3.py @ git+https://github.com/Sachaa-Thanasius/ao3.py@main
ao3.py @ git+https://github.com/Sachaa-Thanasius/ao3.py
arsenic
async-lru
asyncpg==0.29.0
asyncpg-stubs==0.29.1
atlas-api @ https://github.com/Sachaa-Thanasius/atlas-api-wrapper
atlas-api @ git+https://github.com/Sachaa-Thanasius/atlas-api-wrapper
discord.py[speed,voice]>=2.4.0
fichub-api @ https://github.com/Sachaa-Thanasius/fichub-api
jishaku @ git+https://github.com/Gorialis/jishaku@a6661e2813124fbfe53326913e54f7c91e5d0dec
fichub-api @ git+https://github.com/Sachaa-Thanasius/fichub-api
jishaku @ git+https://github.com/Gorialis/jishaku
lxml>=4.9.3
markdownify
msgspec[toml]
msgspec
openpyxl
Pillow>=10.0.0
types-lxml
tzdata; platform_system == "Windows"
wavelink>=3.4.0

# To be used later:
# parsedatetime
# parsedatetime-stubs @ git+https://github.com/Sachaa-Thanasius/parsedatetime-stubs
# python-dateutil
67 changes: 30 additions & 37 deletions src/beira/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import time
import traceback
from typing import Any, Self, overload
from typing import Any, overload
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

import aiohttp
Expand All @@ -18,10 +18,11 @@
import wavelink
from discord.ext import commands
from discord.utils import MISSING
from exts import EXTENSIONS

from .checks import is_blocked
from .config import Config, load_config
from .exts import EXTENSIONS
from .scheduler import Scheduler
from .tree import HookableTree
from .utils import LoggingManager, Pool_alias, conn_init, copy_annotations

Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(
db_pool: Pool_alias,
web_session: aiohttp.ClientSession,
logging_manager: LoggingManager,
scheduler: Scheduler,
initial_extensions: list[str] | None = None,
**kwargs: Any,
) -> None:
Expand All @@ -105,6 +107,7 @@ def __init__(
self.db_pool = db_pool
self.web_session = web_session
self.logging_manager = logging_manager
self.scheduler = scheduler
self.initial_extensions: list[str] = initial_extensions or []

# Various webfiction-related clients.
Expand All @@ -115,21 +118,29 @@ def __init__(

# Things to load before connecting to the Gateway.
self.prefix_cache: dict[int, list[str]] = {}
self.blocked_entities_cache: dict[str, set[int]] = {}
self.blocked_guilds: set[int] = set()
self.blocked_users: set[int] = set()

# Things that are more convenient to retrieve when established here or filled after connecting to the Gateway.
self.special_friends: dict[str, int] = {}

# Add a global check for blocked members.
self.add_check(is_blocked().predicate)

@property
def owner(self) -> discord.User:
"""`discord.User`: The user that owns the bot."""

return self.app_info.owner

async def on_ready(self) -> None:
"""Display that the bot is ready."""

assert self.user
LOGGER.info("Logged in as %s (ID: %s)", self.user, self.user.id)

async def setup_hook(self) -> None:
self.scheduler.start_discord_dispatch(self)
await self._load_guild_prefixes()
await self._load_blocked_entities()
await self._load_extensions()
Expand All @@ -149,14 +160,15 @@ async def setup_hook(self) -> None:
# Cache "friends".
self.loop.create_task(self._load_special_friends())

async def get_prefix(self, message: discord.Message, /) -> list[str] | str:
if not self.prefix_cache:
await self._load_guild_prefixes()
async def close(self) -> None:
await self.scheduler.stop_discord_dispatch()
await super().close()

async def get_prefix(self, message: discord.Message, /) -> list[str] | str:
return self.prefix_cache.get(message.guild.id, "$") if message.guild else "$"

@overload
async def get_context(self, origin: discord.Message | discord.Interaction, /) -> commands.Context[Self]: ...
async def get_context(self, origin: discord.Message | discord.Interaction, /) -> Context: ...

@overload
async def get_context[ContextT: commands.Context[Any]](
Expand Down Expand Up @@ -239,39 +251,27 @@ async def on_command_error(self, context: Context, exception: commands.CommandEr
embed.add_field(name="Channel", value=f"{context.channel}", inline=False)
LOGGER.error("Exception in command %s", context.command, exc_info=exception, extra={"embed": embed})

@property
def owner(self) -> discord.User:
"""`discord.User`: The user that owns the bot."""

return self.app_info.owner

async def _load_blocked_entities(self) -> None:
"""Load all blocked users and guilds from the bot database."""

async with self.db_pool.acquire() as conn, conn.transaction():
user_records = await conn.fetch("SELECT user_id FROM users WHERE is_blocked;")
guild_records = await conn.fetch("SELECT guild_id FROM guilds WHERE is_blocked;")
user_records = await self.db_pool.fetch("SELECT user_id FROM users WHERE is_blocked;")
self.blocked_users.update([record["user_id"] for record in user_records])

self.blocked_entities_cache["users"] = {record["user_id"] for record in user_records}
self.blocked_entities_cache["guilds"] = {record["guild_id"] for record in guild_records}
guild_records = await self.db_pool.fetch("SELECT guild_id FROM guilds WHERE is_blocked;")
self.blocked_guilds.update([record["guild_id"] for record in guild_records])

async def _load_guild_prefixes(self, guild_id: int | None = None) -> None:
async def _load_guild_prefixes(self) -> None:
"""Load all prefixes from the bot database."""

query = "SELECT guild_id, prefix FROM guild_prefixes"
if guild_id:
query += " WHERE guild_id = $1"

try:
db_prefixes = await self.db_pool.fetch(query)
db_prefixes = await self.db_pool.fetch("SELECT guild_id, prefix FROM guild_prefixes;")
except OSError:
LOGGER.exception("Couldn't load guild prefixes from the database. Ignoring for sake of defaults.")
LOGGER.exception("Couldn't load guild prefixes from the database. Using default(s) instead.")
else:
for entry in db_prefixes:
self.prefix_cache.setdefault(entry["guild_id"], []).append(entry["prefix"])

msg = f"(Re)loaded guild prefixes for {guild_id}." if guild_id else "(Re)loaded all guild prefixes."
LOGGER.info(msg)
LOGGER.info("(Re)loaded all guild prefixes.")

async def _load_extensions(self) -> None:
"""Loads extensions/cogs.
Expand Down Expand Up @@ -326,25 +326,17 @@ def is_special_friend(self, user: discord.abc.User, /) -> bool:

return False

def is_ali(self, user: discord.abc.User, /) -> bool:
"""Checks if a `discord.User` or `discord.Member` is Ali."""

if len(self.special_friends) > 0:
return user.id == self.special_friends["aeroali"]

return False


async def main() -> None:
"""Starts an instance of the bot."""

config = load_config()

# Initialize a connection to a PostgreSQL database, an asynchronous web session, and a custom logger setup.
async with (
aiohttp.ClientSession() as web_session,
asyncpg.create_pool(dsn=config.database.pg_url, command_timeout=30, init=conn_init) as pool,
LoggingManager() as logging_manager,
Scheduler(pool) as scheduler,
):
# Set the bot's basic starting parameters.
intents = discord.Intents.all()
Expand All @@ -353,11 +345,12 @@ async def main() -> None:

# Initialize and start the bot.
async with Beira(
command_prefix=default_prefix,
default_prefix,
config=config,
db_pool=pool,
web_session=web_session,
logging_manager=logging_manager,
scheduler=scheduler,
intents=intents,
tree_cls=HookableTree,
) as bot:
Expand Down
4 changes: 2 additions & 2 deletions src/beira/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def is_blocked() -> "Check[Any]":

async def predicate(ctx: Context) -> bool:
if not (await ctx.bot.is_owner(ctx.author)):
if ctx.author.id in ctx.bot.blocked_entities_cache["users"]:
if ctx.author.id in ctx.bot.blocked_users:
raise UserIsBlocked
if ctx.guild and (ctx.guild.id in ctx.bot.blocked_entities_cache["guilds"]):
if ctx.guild and ctx.guild.id in ctx.bot.blocked_guilds:
raise GuildIsBlocked
return True

Expand Down
Loading

0 comments on commit 606b41b

Please sign in to comment.