Skip to content

Commit

Permalink
A lot, but mainly replacing Any with object where possible.
Browse files Browse the repository at this point in the history
- Eliminate DTEmbed
- Combine notification cogs and listeners into a module
- Optimize ff_metadata slightly.
- Copy implementation of QueueHandler into AsyncQueueHandler.
- Lock TODO pin listening action to testing guild.
  • Loading branch information
Sachaa-Thanasius committed Nov 4, 2023
1 parent d513c26 commit bb24f71
Show file tree
Hide file tree
Showing 24 changed files with 274 additions and 172 deletions.
7 changes: 5 additions & 2 deletions core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
import time
import traceback
from typing import TYPE_CHECKING, Any, TypeAlias
from typing import TYPE_CHECKING, Any

import aiohttp
import ao3
Expand All @@ -30,7 +30,7 @@
if TYPE_CHECKING:
from core.utils import LoggingManager
else:
LoggingManager: TypeAlias = Any
LoggingManager = object

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -222,6 +222,7 @@ async def _load_extensions(self) -> None:
await self.load_extension("jishaku")

exts_to_load = self.initial_extensions or EXTENSIONS
all_exts_start_time = time.perf_counter()
for extension in exts_to_load:
try:
start_time = time.perf_counter()
Expand All @@ -230,6 +231,8 @@ async def _load_extensions(self) -> None:
LOGGER.info("Loaded extension: %s -- Time: %.5f", extension, end_time - start_time)
except commands.ExtensionError as err:
LOGGER.exception("Failed to load extension: %s", extension, exc_info=err)
all_exts_end_time = time.perf_counter()
LOGGER.info("Total extension loading time: Time: %.5f", all_exts_start_time - all_exts_end_time)

async def _load_special_friends(self) -> None:
await self.wait_until_ready()
Expand Down
36 changes: 32 additions & 4 deletions core/utils/custom_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from __future__ import annotations

import asyncio
import copy
import logging
from logging.handlers import QueueHandler, RotatingFileHandler
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar

from discord.utils import _ColourFormatter as ColourFormatter, stream_supports_colour # type: ignore # Because color.

Expand All @@ -22,13 +23,40 @@

from typing_extensions import Self
else:
Self: TypeAlias = Any
TracebackType = Self = object

BE = TypeVar("BE", bound=BaseException)

__all__ = ("LoggingManager",)


class AsyncQueueHandler(logging.Handler):
# Copy implementation of QueueHandler.
def __init__(self, queue: asyncio.Queue[Any]) -> None:
logging.Handler.__init__(self)
self.queue = queue

def enqueue(self, record: logging.LogRecord) -> None:
self.queue.put_nowait(record)

def prepare(self, record: logging.LogRecord) -> logging.LogRecord:
msg = self.format(record)
record = copy.copy(record)
record.message = msg
record.msg = msg
record.args = None
record.exc_info = None
record.exc_text = None
record.stack_info = None
return record

def emit(self, record: logging.LogRecord) -> None:
try:
self.enqueue(self.prepare(record))
except Exception: # noqa: BLE001
self.handleError(record)


class RemoveNoise(logging.Filter):
"""Filter for custom logging system.
Expand Down Expand Up @@ -112,7 +140,7 @@ def __enter__(self) -> Self:
self.log.addHandler(stream_handler)

# Add a queue handler.
queue_handler = QueueHandler(self.log_queue)
queue_handler = AsyncQueueHandler(self.log_queue)
self.log.addHandler(queue_handler)

return self
Expand Down
40 changes: 15 additions & 25 deletions core/utils/embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,49 +15,39 @@
if TYPE_CHECKING:
from typing_extensions import Self
else:
Self: TypeAlias = Any
Self = object

AnyEmoji: TypeAlias = discord.Emoji | discord.PartialEmoji | str

__all__ = ("EMOJI_URL", "DTEmbed", "StatsEmbed")
__all__ = ("StatsEmbed",)

LOGGER = logging.getLogger(__name__)

EMOJI_URL = "https://cdn.discordapp.com/emojis/{0}.webp?size=128&quality=lossless"


class DTEmbed(discord.Embed):
"""Represents a Discord embed, with a preset timestamp attribute.
Inherits from :class:`discord.Embed`.
"""

def __init__(self, **kwargs: Any) -> None:
kwargs["timestamp"] = kwargs.get("timestamp", discord.utils.utcnow())
super().__init__(**kwargs)


class StatsEmbed(DTEmbed):
class StatsEmbed(discord.Embed):
"""A subclass of :class:`DTEmbed` that displays given statistics for a user.
This has a default colour of 0x2f3136 and, due to inheritance, a default timestamp for right now in UTC.
This has a default colour of 0x2f3136 and a default timestamp for right now in UTC.
Parameters
----------
*args
Positional arguments
**kwargs
Keyword arguments for the normal initialization of a discord :class:`Embed`.
"""

def __init__(self, **kwargs: Any) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["colour"] = kwargs.get("colour") or kwargs.get("color") or 0x2F3136
super().__init__(**kwargs)
kwargs["timestamp"] = kwargs.get("timestamp", discord.utils.utcnow())
super().__init__(*args, **kwargs)

def add_stat_fields(
self,
*,
names: Iterable[Any],
names: Iterable[object],
emojis: Iterable[AnyEmoji] = ("",),
values: Iterable[Any],
values: Iterable[object],
inline: bool = False,
emoji_as_header: bool = False,
) -> Self:
Expand All @@ -67,11 +57,11 @@ def add_stat_fields(
Parameters
----------
names: Iterable[Any]
names: Iterable[object]
The names for each field.
emojis: Iterable[AnyEmoji]
The emojis adorning each field. Defaults to a tuple with an empty string so there is at least one "emoji".
values: Iterable[Any], default=("",)
values: Iterable[object], default=("",)
The values for each field.
inline: :class:`bool`, default=False
Whether the fields should be displayed inline. Defaults to False.
Expand All @@ -94,7 +84,7 @@ def add_stat_fields(
def add_leaderboard_fields(
self,
*,
ldbd_content: Iterable[Sequence[Any]],
ldbd_content: Iterable[Sequence[object]],
ldbd_emojis: Iterable[AnyEmoji] = ("",),
name_format: str = "| {}",
value_format: str = "{}",
Expand All @@ -107,7 +97,7 @@ def add_leaderboard_fields(
Parameters
----------
ldbd_content: Iterable[Sequence[Any]]
ldbd_content: Iterable[Sequence[object]]
The content for each leaderboard, including names and values. Assumes they're given in descending order.
ldbd_emojis: Iterable[AnyEmoji], default=("",)
The emojis adorning the names of the leaderboard fields. Defaults to a tuple with an empty string so there
Expand Down
10 changes: 9 additions & 1 deletion core/utils/emojis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
__all__ = ("EMOJI_STOCK",)
from __future__ import annotations


__all__ = (
"EMOJI_STOCK",
"EMOJI_URL",
)

# fmt: off
EMOJI_STOCK: dict[str, str] = {
Expand Down Expand Up @@ -27,3 +33,5 @@
"d100": "<a:d100:1109960365967687841>",
}
# fmt: on

EMOJI_URL = "https://cdn.discordapp.com/emojis/{0}.webp?size=128&quality=lossless"
11 changes: 9 additions & 2 deletions core/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from collections.abc import Awaitable, Callable, Coroutine
from functools import wraps
from time import perf_counter
from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias, TypeGuard, TypeVar, overload
from typing import TYPE_CHECKING, Any, ParamSpec, TypeGuard, TypeVar, overload


if TYPE_CHECKING:
from types import TracebackType

from typing_extensions import Self
else:
Self: TypeAlias = Any
TracebackType = Self = object

T = TypeVar("T")
P = ParamSpec("P")
Expand Down Expand Up @@ -205,3 +205,10 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> T:
return func(*args, **kwargs)

return inner


def take_annotation_from(original: Callable[P, T]) -> Callable[[Callable[P, T]], Callable[P, T]]:
def wrapped(new: Callable[P, T]) -> Callable[P, T]:
return new

return wrapped
12 changes: 6 additions & 6 deletions core/utils/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import discord
from discord.utils import maybe_coroutine
Expand All @@ -18,7 +18,7 @@
if TYPE_CHECKING:
from typing_extensions import Self
else:
Self: TypeAlias = Any
Self = object

_LT = TypeVar("_LT")

Expand Down Expand Up @@ -184,7 +184,7 @@ async def on_timeout(self) -> None:
self.stop()

@abstractmethod
def format_page(self) -> Any:
def format_page(self) -> discord.Embed:
"""|maybecoro|
Makes, or retrieves from the cache, the embed 'page' that the user will see.
Expand Down Expand Up @@ -352,7 +352,7 @@ async def on_timeout(self) -> None:
self.stop()

@abstractmethod
def format_page(self) -> Any:
def format_page(self) -> discord.Embed:
"""|maybecoro|
Makes and returns the embed 'page' that the user will see.
Expand Down Expand Up @@ -390,8 +390,8 @@ async def update_page(self, interaction: discord.Interaction) -> None:
self.disable_page_buttons()
await interaction.response.edit_message(embed=embed_page, view=self)

@discord.ui.select()
async def select_page(self, interaction: discord.Interaction, select: discord.ui.Select[Any]) -> None:
@discord.ui.select(cls=discord.ui.Select[Self])
async def select_page(self, interaction: discord.Interaction, select: discord.ui.Select[Self]) -> None:
"""Dropdown that displays all the Patreon tiers and provides them as choices to navigate to."""

self.page_index = int(select.values[0])
Expand Down
6 changes: 4 additions & 2 deletions exts/bot_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import logging
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Literal

import discord
from discord.app_commands import Choice
Expand All @@ -18,6 +18,8 @@

if TYPE_CHECKING:
from asyncpg import Record
else:
Record = object


LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -184,7 +186,7 @@ async def get_usage(
) -> list[Record]:
"""Queries the database for command usage."""

query_args: list[Any] = [] # Holds the query args as objects.
query_args: list[object] = [] # Holds the query args as objects.
where_params: list[str] = [] # Holds the query param placeholders as formatted strings.

# Create the base queries.
Expand Down
4 changes: 2 additions & 2 deletions exts/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import textwrap
from ast import literal_eval
from io import StringIO
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from typing import TYPE_CHECKING, Any, cast

import discord
import msgspec
Expand All @@ -27,7 +27,7 @@
if TYPE_CHECKING:
from typing_extensions import Self
else:
Self: TypeAlias = Any
Self = object


LOGGER = logging.getLogger(__name__)
Expand Down
2 changes: 2 additions & 0 deletions exts/emoji_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

if TYPE_CHECKING:
from typing_extensions import Self
else:
Self = object

LOGGER = logging.getLogger(__name__)

Expand Down
9 changes: 5 additions & 4 deletions exts/fandom_wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from lxml import etree, html

import core
from core.utils import EMOJI_URL, DTEmbed
from core.utils import EMOJI_URL


if TYPE_CHECKING:
Expand All @@ -34,8 +34,8 @@
AOC_EMOJI_URL, JARE_EMOJI_URL = EMOJI_URL.format(770620658501025812), EMOJI_URL.format(1061029880059400262)


class AoCWikiEmbed(DTEmbed):
"""A subclass of :class:`DTEmbed` that is set up for representing Ashes of Chaos wiki pages.
class AoCWikiEmbed(discord.Embed):
"""A subclass of :class:`discord.Embed` that is set up for representing Ashes of Chaos wiki pages.
Parameters
----------
Expand Down Expand Up @@ -263,7 +263,7 @@ async def search_wiki(self, wiki_name: str, wiki_query: str) -> discord.Embed:

# --------------------------------
# Check if the wiki has the requested query as a page.
final_embed = AoCWikiEmbed() if wiki_name == "Harry Potter and the Ashes of Chaos" else DTEmbed()
final_embed = AoCWikiEmbed() if wiki_name == "Harry Potter and the Ashes of Chaos" else discord.Embed()

specific_wiki_page = wiki_pages.get(wiki_query)

Expand All @@ -286,6 +286,7 @@ async def search_wiki(self, wiki_name: str, wiki_query: str) -> discord.Embed:
# Add the primary embed parameters.
final_embed.title = wiki_query
final_embed.url = specific_wiki_page
final_embed.timestamp = discord.utils.utcnow()

# Fetch information from the character webpage to populate the rest of the embed.
summary, thumbnail = await process_fandom_page(self.bot.web_session, specific_wiki_page)
Expand Down
Loading

0 comments on commit bb24f71

Please sign in to comment.