Skip to content

Commit

Permalink
Remove unnecessary backticks. Also, add copy_annotations to utils/mis…
Browse files Browse the repository at this point in the history
…c.py
  • Loading branch information
Sachaa-Thanasius committed Jul 22, 2024
1 parent a4d086a commit 4820311
Show file tree
Hide file tree
Showing 36 changed files with 461 additions and 697 deletions.
5 changes: 1 addition & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,4 @@ config.toml
database/table_insertions.sql
database/migration_script.py
misc/
test_bot.py

# Might be temporary.
core/utils/extras
src/beira/utils/extras
12 changes: 0 additions & 12 deletions Dockerfile

This file was deleted.

61 changes: 0 additions & 61 deletions docker-compose.yml

This file was deleted.

1 change: 0 additions & 1 deletion src/beira/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .bot import *
from .checks import *
from .config import *
from .errors import *
14 changes: 12 additions & 2 deletions src/beira/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .checks import is_blocked
from .config import Config, load_config
from .tree import HookableTree
from .utils import LoggingManager, Pool_alias, conn_init
from .utils import LoggingManager, Pool_alias, conn_init, copy_annotations


LOGGER = logging.getLogger(__name__)
Expand All @@ -40,12 +40,19 @@ class Context(commands.Context["Beira"]):
Attributes
----------
error_handled: bool, default=False
Whether an error handler has already taken care of an error.
session
db
"""

voice_client: wavelink.Player | None # type: ignore # Type lie for narrowing

@copy_annotations(commands.Context["Beira"].__init__)
def __init__(self, *args: object, **kwargs: object):
super().__init__(*args, **kwargs)
self.error_handled: bool = False

@property
def session(self) -> aiohttp.ClientSession:
"""`ClientSession`: Returns the asynchronous HTTP session used by the bot for HTTP requests."""
Expand Down Expand Up @@ -193,12 +200,15 @@ async def on_error(self, event_method: str, /, *args: object, **kwargs: object)
async def on_command_error(self, context: Context, exception: commands.CommandError) -> None: # type: ignore # Narrowing
assert context.command # Pre-condition for being here.

if context.error_handled:
return

if isinstance(exception, commands.CommandNotFound):
return

exception = getattr(exception, "original", exception)

tb_text = "".join(traceback.format_exception(type(exception), exception, exception.__traceback__, chain=False))
tb_text = "".join(traceback.format_exception(exception, chain=False))
embed = (
discord.Embed(
title="Command Error",
Expand Down
5 changes: 1 addition & 4 deletions src/beira/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
import msgspec


__all__ = (
"Config",
"load_config",
)
__all__ = ("Config", "load_config")


class Base(msgspec.Struct):
Expand Down
113 changes: 48 additions & 65 deletions src/beira/exts/_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, bot: beira.Beira, dev_guilds: list[discord.Object]) -> None:

@property
def cog_emoji(self) -> discord.PartialEmoji:
"""`discord.PartialEmoji`: A partial emoji representing this cog."""
"""discord.PartialEmoji: A partial emoji representing this cog."""

return discord.PartialEmoji(name="discord_dev", animated=True, id=1084608963896672256)

Expand All @@ -77,24 +77,6 @@ async def cog_check(self, ctx: beira.Context) -> bool: # type: ignore # Narrowi

return await self.bot.is_owner(ctx.author)

async def cog_command_error(self, ctx: beira.Context, error: Exception) -> None: # type: ignore # Narrowing
assert ctx.command

# Extract the original error.
error = getattr(error, "original", error)
if ctx.interaction:
error = getattr(error, "original", error)

embed = discord.Embed(color=0x5E9A40)

if isinstance(error, commands.ExtensionError):
embed.description = f"Couldn't {ctx.command.name} extension: {error.name}\n{error}"
LOGGER.error("Couldn't %s extension: %s", ctx.command.name, error.name, exc_info=error)
else:
LOGGER.exception("", exc_info=error)

await ctx.send(embed=embed, ephemeral=True)

@commands.hybrid_group(fallback="get")
async def block(self, ctx: beira.Context) -> None:
"""A group of commands for blocking and unblocking users or guilds from using the bot.
Expand Down Expand Up @@ -129,9 +111,9 @@ async def block_add(
----------
ctx: `beira.Context`
The invocation context.
block_type: Literal["user", "guild"], default="user"
block_type: `Literal["user", "guild"]`, default="user"
What type of entity or entities are being blocked. Defaults to "user".
entities: `commands.Greedy`[`discord.Object`]
entities: `commands.Greedy[discord.Object`]
The entities to block.
"""

Expand Down Expand Up @@ -176,9 +158,9 @@ async def block_remove(
----------
ctx: `beira.Context`
The invocation context
block_type: Literal["user", "guild"], default="user"
block_type: `Literal["user", "guild"]`, default="user"
What type of entity or entities are being unblocked. Defaults to "user".
entities: `commands.Greedy`[`discord.Object`]
entities: `commands.Greedy[discord.Object`]
The entities to unblock.
"""

Expand Down Expand Up @@ -212,24 +194,19 @@ async def block_remove(
@block_add.error
@block_remove.error
async def block_change_error(self, ctx: beira.Context, error: commands.CommandError) -> None:
assert ctx.command

# Extract the original error.
error = getattr(error, "original", error)
if ctx.interaction:
error = getattr(error, "original", error)

assert ctx.command

if isinstance(error, PostgresError | PostgresConnectionError):
action = "block" if ctx.command.qualified_name == "block add" else "unblock"
await ctx.send(f"Unable to {action} these users/guilds at this time.", ephemeral=True)
LOGGER.exception("", exc_info=error)

@app_commands.check(lambda interaction: interaction.user.id == interaction.client.owner_id)
async def context_menu_block_add(
self,
interaction: beira.Interaction,
user: discord.User | discord.Member,
) -> None:
async def context_menu_block_add(self, interaction: beira.Interaction, user: discord.User | discord.Member) -> None:
stmt = """
INSERT INTO users (user_id, is_blocked)
VALUES ($1, $2)
Expand Down Expand Up @@ -274,13 +251,7 @@ async def shutdown(self, ctx: beira.Context) -> None:

@commands.hybrid_command()
async def walk(self, ctx: beira.Context) -> None:
"""Walk through all app commands globally and in every guild to see what is synced and where.
Parameters
----------
ctx: `beira.Context`
The invocation context where the command was called.
"""
"""Walk through all app commands globally and in every guild to see what is synced and where."""

all_embeds: list[discord.Embed] = []

Expand Down Expand Up @@ -432,6 +403,24 @@ async def ext_autocomplete(self, _: beira.Interaction, current: str) -> list[app
if current.lower() in ext.lower()
][:25]

@load.error
@unload.error
@reload.error
async def load_error(self, ctx: beira.Context, error: commands.CommandError) -> None:
assert ctx.command

# Extract the original error.
error = getattr(error, "original", error)
if ctx.interaction:
error = getattr(error, "original", error)

if isinstance(error, commands.ExtensionError):
embed = discord.Embed(
color=0x5E9A40,
description=f"Couldn't {ctx.command.name} extension: {error.name}\n{error}",
)
await ctx.send(embed=embed, ephemeral=True)

@commands.hybrid_command("sync")
@app_commands.choices(spec=[app_commands.Choice(name=name, value=value) for name, value in SPEC_CHOICES])
async def sync_(
Expand All @@ -442,16 +431,16 @@ async def sync_(
) -> None:
"""Syncs the command tree in some way based on input.
The `spec` and `guilds` parameters are mutually exclusive.
``spec`` and ``guilds`` are mutually exclusive.
Parameters
----------
ctx: `beira.Context`
The invocation context.
guilds: Greedy[`discord.Object`], optional
guilds: `Greedy[discord.Object`], optional
The guilds to sync the app commands if no specification is entered. Converts guild ids to
`discord.Object`s. Please provide as IDs separated by spaces.
spec: Choice[`str`], optional
``discord.Object``s. Please provide as IDs separated by spaces.
spec: `Choice[str]`, optional
The type of sync to perform if no guilds are entered. No input means global sync.
Notes
Expand All @@ -461,19 +450,13 @@ async def sync_(
Here is some elaboration on what the command would do with different arguments. Irrelevant with slash
activation, but replace '$' with whatever your prefix is for prefix command activation:
`$sync`: Sync globally.
`$sync ~`: Sync with current guild.
`$sync *`: Copy all global app commands to current guild and sync.
`$sync ^`: Clear all commands from the current guild target and sync, thereby removing guild commands.
`$sync -`: (D-N-T!) Clear all global commands and sync, thereby removing all global commands.
`$sync +`: (D-N-T!) Clear all commands from all guilds and sync, thereby removing all guild commands.
`$sync <id_1> <id_2> ...`: Sync with those guilds of id_1, id_2, etc.
- `$sync`: Sync globally.
- `$sync ~`: Sync with current guild.
- `$sync *`: Copy all global app commands to current guild and sync.
- `$sync ^`: Clear all commands from the current guild target and sync, thereby removing guild commands.
- `$sync -`: (D-N-T!) Clear all global commands and sync, thereby removing all global commands.
- `$sync +`: (D-N-T!) Clear all commands from all guilds and sync, thereby removing all guild commands.
- `$sync <id_1> <id_2> ...`: Sync with those guilds of id_1, id_2, etc.
References
----------
Expand Down Expand Up @@ -523,13 +506,13 @@ async def sync_(

@sync_.error
async def sync_error(self, ctx: beira.Context, error: commands.CommandError) -> None:
"""A local error handler for the :meth:`sync_` command.
"""A local error handler for the sync_ command.
Parameters
----------
ctx: `beira.Context`
ctx: beira.Context
The invocation context.
error: `commands.CommandError`
error: commands.CommandError
The error thrown by the command.
"""

Expand Down Expand Up @@ -592,16 +575,16 @@ def walk_commands_with_indent(group: commands.GroupMixin[Any]) -> Generator[str,


async def setup(bot: beira.Beira) -> None:
"""Connects cog to bot."""
dev_guild_ids = list(bot.config.discord.important_guilds["dev"])
dev_guilds = [discord.Object(id=guild_id) for guild_id in bot.config.discord.important_guilds["dev"]]
cog = DevCog(bot, dev_guilds)

# Can't use the guilds kwarg in add_cog, as it doesn't currently work for hybrids.
# Ref: https://github.com/Rapptz/discord.py/pull/9428
dev_guilds_objects = [discord.Object(id=guild_id) for guild_id in bot.config.discord.important_guilds["dev"]]
cog = DevCog(bot, dev_guilds_objects)
for cmd in cog.walk_app_commands():
if cmd._guild_ids is None:
cmd._guild_ids = [g.id for g in dev_guilds_objects]
for cmd in cog.get_app_commands():
if cmd._guild_ids is None: # pyright: ignore [reportPrivateUsage]
cmd._guild_ids = dev_guild_ids # pyright: ignore [reportPrivateUsage]
else:
cmd._guild_ids.extend(g.id for g in dev_guilds_objects)
cmd._guild_ids.extend(dev_guild_ids) # pyright: ignore [reportPrivateUsage]

await bot.add_cog(cog)
Loading

0 comments on commit 4820311

Please sign in to comment.