diff --git a/core/tree.py b/core/tree.py index cc1d039..cc18400 100644 --- a/core/tree.py +++ b/core/tree.py @@ -7,7 +7,7 @@ import discord from discord import Client, Interaction from discord.app_commands import AppCommandError, Command, CommandTree, Group, Namespace -from discord.ext.commands import Cog +from discord.ext import commands if TYPE_CHECKING: @@ -25,7 +25,7 @@ T = TypeVar("T") Coro: TypeAlias = Coroutine[Any, Any, T] CoroFunc: TypeAlias = Callable[..., Coro[Any]] -GroupT = TypeVar("GroupT", bound=Group | Cog) +GroupT = TypeVar("GroupT", bound=Group | commands.Cog) AppHook: TypeAlias = Callable[[GroupT, Interaction[Any]], Coro[Any]] | Callable[[Interaction[Any]], Coro[Any]] __all__ = ("before_app_invoke", "after_app_invoke", "HookableTree") diff --git a/core/utils/app_help_test.py b/core/utils/app_help_test.py new file mode 100644 index 0000000..64ea778 --- /dev/null +++ b/core/utils/app_help_test.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import Any + +import discord +from discord import app_commands +from discord._types import ClientT + + +def get_nested_command( + tree: app_commands.CommandTree[ClientT], + name: str, + *, + guild: discord.Guild | None, +) -> app_commands.Command[Any, ..., Any] | app_commands.Group | None: + key, *keys = name.split(" ") + cmd = tree.get_command(key, guild=guild) or tree.get_command(key) + + for key in keys: + if cmd is None: + return None + if isinstance(cmd, app_commands.Command): + break + + cmd = cmd.get_command(key) + + return cmd + + +@app_commands.command(name="help2") +async def _help(itx: discord.Interaction[ClientT], command: str) -> None: + tree: app_commands.CommandTree | None = getattr(itx.client, "tree", None) + if tree is None: + await itx.response.send_message("Could not find a command tree", ephemeral=True) + return + + cmd = get_nested_command(tree, command, guild=itx.guild) + if cmd is None: + await itx.response.send_message(f"Could not find a command named {command}", ephemeral=True) + return + + if isinstance(cmd, app_commands.Command): + description = cmd.callback.__doc__ or cmd.description + else: + description = cmd.__doc__ or cmd.description + + embed = discord.Embed(title=cmd.qualified_name, description=description) + + # whatever other fancy thing you want + await itx.response.send_message(embed=embed, ephemeral=True) + + +@_help.autocomplete("command") +async def help_autocomplete(itx: discord.Interaction[ClientT], current: str) -> list[app_commands.Choice[str]]: + # Known to exist at runtime, else autocomplete would not trigger. + tree: app_commands.CommandTree = getattr(itx.client, "tree") # noqa: B009 + + commands = list(tree.walk_commands(guild=None, type=discord.AppCommandType.chat_input)) + + if itx.guild is not None: + commands.extend(tree.walk_commands(guild=itx.guild, type=discord.AppCommandType.chat_input)) + + choices: list[app_commands.Choice[str]] = [] + for command in commands: + name = command.qualified_name + if current in name: + choices.append(app_commands.Choice(name=name, value=name)) + + # Only show unique commands + choices = sorted(set(choices), key=lambda c: c.name) + return choices[:25] diff --git a/core/wave.py b/core/wave.py index 6dce7a9..eb739a9 100644 --- a/core/wave.py +++ b/core/wave.py @@ -4,8 +4,7 @@ from __future__ import annotations -from collections.abc import AsyncIterator, Iterable -from typing import cast +from collections.abc import AsyncIterable, Iterable import discord import wavelink @@ -15,7 +14,8 @@ __all__ = ("SkippableQueue", "SkippablePlayer") AnyTrack = wavelink.Playable | spotify.SpotifyTrack -AnyTrackIterable = list[wavelink.Playable] | list[spotify.SpotifyTrack] | spotify.SpotifyAsyncIterator +AnyTrackIterator = list[wavelink.Playable] | list[spotify.SpotifyTrack] | spotify.SpotifyAsyncIterator +AnyTrackIterable = Iterable[wavelink.Playable] | Iterable[spotify.SpotifyTrack] | AsyncIterable[spotify.SpotifyTrack] class SkippableQueue(wavelink.Queue): @@ -50,9 +50,8 @@ async def put_all_wait(self, item: AnyTrack | AnyTrackIterable, requester: str | for sub_item in item: sub_item.requester = requester # type: ignore # Runtime attribute assignment. await self.put_wait(sub_item) - elif isinstance(item, AsyncIterator): - # Awkward casting to satisfy pyright since wavelink isn't fully typed. - async for sub_item in cast(AsyncIterator[spotify.SpotifyTrack], item): + elif isinstance(item, AsyncIterable): + async for sub_item in item: sub_item.requester = requester # type: ignore # Runtime attribute assignment. await self.put_wait(sub_item) else: @@ -66,7 +65,7 @@ class SkippablePlayer(wavelink.Player): Attributes ---------- queue: :class:`SkippableQueue` - A version of :class:`wavelink.Queue` that can be skipped into. + A subclass of :class:`wavelink.Queue` that can be skipped into. """ def __init__( diff --git a/exts/help.py b/exts/help.py index 92b110b..c821ec3 100644 --- a/exts/help.py +++ b/exts/help.py @@ -267,7 +267,7 @@ async def command_autocomplete(self, interaction: core.Interaction, current: str return [ app_commands.Choice(name=command.qualified_name, value=command.qualified_name) for command in await help_command.filter_commands(self.bot.walk_commands(), sort=True) - if current.casefold() in command.qualified_name + if current.casefold() in command.qualified_name.casefold() ][:25] diff --git a/exts/music/music.py b/exts/music/music.py index 03bb0b2..720e737 100644 --- a/exts/music/music.py +++ b/exts/music/music.py @@ -14,7 +14,7 @@ from discord.ext import commands import core -from core.wave import AnyTrack, AnyTrackIterable, SkippablePlayer +from core.wave import AnyTrack, AnyTrackIterator, SkippablePlayer from .utils import MusicQueueView, WavelinkSearchConverter, format_track_embed, generate_tracks_add_notification @@ -139,7 +139,7 @@ async def play( self, ctx: core.GuildContext, *, - search: app_commands.Transform[AnyTrack | AnyTrackIterable, WavelinkSearchConverter], + search: app_commands.Transform[AnyTrack | AnyTrackIterator, WavelinkSearchConverter], ) -> None: """Play audio from a YouTube url or search term. @@ -196,7 +196,7 @@ async def stop(self, ctx: core.GuildContext) -> None: """Stop playback and disconnect the bot from voice.""" if vc := ctx.voice_client: - await vc.disconnect() # type: ignore # Incomplete wavelink typing + await vc.disconnect() await ctx.send("Disconnected from voice channel.") else: await ctx.send("No player to perform this on.") diff --git a/exts/music/utils.py b/exts/music/utils.py index af5008c..c2d42cb 100644 --- a/exts/music/utils.py +++ b/exts/music/utils.py @@ -15,7 +15,7 @@ from wavelink.ext import spotify from core.utils import EMOJI_STOCK, PaginatedEmbedView -from core.wave import AnyTrack, AnyTrackIterable +from core.wave import AnyTrack, AnyTrackIterator escape_markdown = functools.partial(discord.utils.escape_markdown, as_needed=True) @@ -42,7 +42,7 @@ def format_page(self) -> discord.Embed: return embed_page -class WavelinkSearchConverter(commands.Converter[AnyTrack | AnyTrackIterable], discord.app_commands.Transformer): +class WavelinkSearchConverter(commands.Converter[AnyTrack | AnyTrackIterator], discord.app_commands.Transformer): """Converts to what Wavelink considers a playable track (:class:`AnyPlayable` or :class:`AnyTrackIterable`). The lookup strategy is as follows (in order): @@ -81,7 +81,7 @@ def _get_search_type(argument: str) -> type[AnyTrack]: return search_type - async def _convert(self, argument: str) -> AnyTrack | AnyTrackIterable: + async def _convert(self, argument: str) -> AnyTrack | AnyTrackIterator: """Attempt to convert a string into a Wavelink track or list of tracks.""" search_type = self._get_search_type(argument) @@ -104,10 +104,10 @@ async def _convert(self, argument: str) -> AnyTrack | AnyTrackIterable: return tracks # Who needs narrowing anyway? - async def convert(self, ctx: commands.Context[Any], argument: str) -> AnyTrack | AnyTrackIterable: + async def convert(self, ctx: commands.Context[Any], argument: str) -> AnyTrack | AnyTrackIterator: return await self._convert(argument) - async def transform(self, _: discord.Interaction, value: str, /) -> AnyTrack | AnyTrackIterable: + async def transform(self, _: discord.Interaction, value: str, /) -> AnyTrack | AnyTrackIterator: return await self._convert(value) async def autocomplete( # type: ignore # Narrowing the types of the input value and return value, I guess. @@ -117,7 +117,13 @@ async def autocomplete( # type: ignore # Narrowing the types of the input value ) -> list[discord.app_commands.Choice[str]]: search_type = self._get_search_type(value) tracks = await search_type.search(value) - return [discord.app_commands.Choice(name=track.title, value=track.uri or track.title) for track in tracks][:25] + if isinstance(tracks, list): + choices = [ + discord.app_commands.Choice(name=track.title, value=track.uri or track.title) for track in tracks + ][:25] + else: + choices = [discord.app_commands.Choice(name=tracks.name, value=tracks.uri or tracks.name)] + return choices async def format_track_embed(title: str, track: AnyTrack) -> discord.Embed: @@ -154,7 +160,7 @@ async def format_track_embed(title: str, track: AnyTrack) -> discord.Embed: return embed -async def generate_tracks_add_notification(tracks: AnyTrack | AnyTrackIterable) -> str: +async def generate_tracks_add_notification(tracks: AnyTrack | AnyTrackIterator) -> str: """Returns the appropriate notification string for tracks or a collection of tracks being added to a queue.""" if isinstance(tracks, wavelink.YouTubePlaylist | wavelink.SoundCloudPlaylist): @@ -164,6 +170,6 @@ async def generate_tracks_add_notification(tracks: AnyTrack | AnyTrackIterable) if isinstance(tracks, list): return f"Added `{tracks[0].title}` to the queue." if isinstance(tracks, spotify.SpotifyAsyncIterator): - return f"Added `{tracks._count}` tracks to the queue." # type: ignore # This avoids iterating through it again. + return f"Added `{tracks._count}` tracks to the queue." # This avoids iterating through it again. return f"Added `{tracks.title}` to the queue." diff --git a/pyproject.toml b/pyproject.toml index 2841d77..bb99516 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ target-version = ["py310"] [tool.ruff] # Credit to @mikeshardmind for most of this setup. -include = ["main.py", "core/*.py", "exts/*.pyi", "**/pyproject.toml"] +include = ["main.py", "core/*.py", "exts/*.py", "**/pyproject.toml"] line-length = 120 target-version = "py310" select = [ @@ -85,9 +85,9 @@ combine-as-imports = true include = ["main.py", "core", "exts"] pythonVersion = "3.10" typeCheckingMode = "strict" -useLibraryCodeForTypes = true +# useLibraryCodeForTypes = true reportMissingTypeStubs = "none" -reportImportCycles = "warning" +# reportImportCycles = "warning" reportPropertyTypeMismatch = "warning" reportUnnecessaryTypeIgnoreComment = "warning" diff --git a/requirements.txt b/requirements.txt index 3dd39d0..ee54bfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,5 @@ Pillow>=10.0.0 tatsu @ git+https://github.com/Sachaa-Thanasius/Tatsu.git typing_extensions>=4.5.0,<5 wavelink>=2.6.1,<3 +wavelink-stubs yarl>=1.8.2,<2