Skip to content

Commit

Permalink
Minor cleanup and music adjustments after installing wavelink-stubs
Browse files Browse the repository at this point in the history
  • Loading branch information
Sachaa-Thanasius committed Nov 7, 2023
1 parent 1e52de9 commit 2558d98
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 24 deletions.
4 changes: 2 additions & 2 deletions core/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import discord
from discord import Client, Interaction
from discord.app_commands import AppCommandError, Command, CommandTree, Group, Namespace
from discord.ext.commands import Cog
from discord.ext import commands


if TYPE_CHECKING:
Expand All @@ -25,7 +25,7 @@
T = TypeVar("T")
Coro: TypeAlias = Coroutine[Any, Any, T]
CoroFunc: TypeAlias = Callable[..., Coro[Any]]
GroupT = TypeVar("GroupT", bound=Group | Cog)
GroupT = TypeVar("GroupT", bound=Group | commands.Cog)
AppHook: TypeAlias = Callable[[GroupT, Interaction[Any]], Coro[Any]] | Callable[[Interaction[Any]], Coro[Any]]

__all__ = ("before_app_invoke", "after_app_invoke", "HookableTree")
Expand Down
71 changes: 71 additions & 0 deletions core/utils/app_help_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

from typing import Any

import discord
from discord import app_commands
from discord._types import ClientT


def get_nested_command(
tree: app_commands.CommandTree[ClientT],
name: str,
*,
guild: discord.Guild | None,
) -> app_commands.Command[Any, ..., Any] | app_commands.Group | None:
key, *keys = name.split(" ")
cmd = tree.get_command(key, guild=guild) or tree.get_command(key)

for key in keys:
if cmd is None:
return None
if isinstance(cmd, app_commands.Command):
break

cmd = cmd.get_command(key)

return cmd


@app_commands.command(name="help2")
async def _help(itx: discord.Interaction[ClientT], command: str) -> None:
tree: app_commands.CommandTree | None = getattr(itx.client, "tree", None)
if tree is None:
await itx.response.send_message("Could not find a command tree", ephemeral=True)
return

cmd = get_nested_command(tree, command, guild=itx.guild)
if cmd is None:
await itx.response.send_message(f"Could not find a command named {command}", ephemeral=True)
return

if isinstance(cmd, app_commands.Command):
description = cmd.callback.__doc__ or cmd.description
else:
description = cmd.__doc__ or cmd.description

embed = discord.Embed(title=cmd.qualified_name, description=description)

# whatever other fancy thing you want
await itx.response.send_message(embed=embed, ephemeral=True)


@_help.autocomplete("command")
async def help_autocomplete(itx: discord.Interaction[ClientT], current: str) -> list[app_commands.Choice[str]]:
# Known to exist at runtime, else autocomplete would not trigger.
tree: app_commands.CommandTree = getattr(itx.client, "tree") # noqa: B009

commands = list(tree.walk_commands(guild=None, type=discord.AppCommandType.chat_input))

if itx.guild is not None:
commands.extend(tree.walk_commands(guild=itx.guild, type=discord.AppCommandType.chat_input))

choices: list[app_commands.Choice[str]] = []
for command in commands:
name = command.qualified_name
if current in name:
choices.append(app_commands.Choice(name=name, value=name))

# Only show unique commands
choices = sorted(set(choices), key=lambda c: c.name)
return choices[:25]
13 changes: 6 additions & 7 deletions core/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from __future__ import annotations

from collections.abc import AsyncIterator, Iterable
from typing import cast
from collections.abc import AsyncIterable, Iterable

import discord
import wavelink
Expand All @@ -15,7 +14,8 @@
__all__ = ("SkippableQueue", "SkippablePlayer")

AnyTrack = wavelink.Playable | spotify.SpotifyTrack
AnyTrackIterable = list[wavelink.Playable] | list[spotify.SpotifyTrack] | spotify.SpotifyAsyncIterator
AnyTrackIterator = list[wavelink.Playable] | list[spotify.SpotifyTrack] | spotify.SpotifyAsyncIterator
AnyTrackIterable = Iterable[wavelink.Playable] | Iterable[spotify.SpotifyTrack] | AsyncIterable[spotify.SpotifyTrack]


class SkippableQueue(wavelink.Queue):
Expand Down Expand Up @@ -50,9 +50,8 @@ async def put_all_wait(self, item: AnyTrack | AnyTrackIterable, requester: str |
for sub_item in item:
sub_item.requester = requester # type: ignore # Runtime attribute assignment.
await self.put_wait(sub_item)
elif isinstance(item, AsyncIterator):
# Awkward casting to satisfy pyright since wavelink isn't fully typed.
async for sub_item in cast(AsyncIterator[spotify.SpotifyTrack], item):
elif isinstance(item, AsyncIterable):
async for sub_item in item:
sub_item.requester = requester # type: ignore # Runtime attribute assignment.
await self.put_wait(sub_item)
else:
Expand All @@ -66,7 +65,7 @@ class SkippablePlayer(wavelink.Player):
Attributes
----------
queue: :class:`SkippableQueue`
A version of :class:`wavelink.Queue` that can be skipped into.
A subclass of :class:`wavelink.Queue` that can be skipped into.
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion exts/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ async def command_autocomplete(self, interaction: core.Interaction, current: str
return [
app_commands.Choice(name=command.qualified_name, value=command.qualified_name)
for command in await help_command.filter_commands(self.bot.walk_commands(), sort=True)
if current.casefold() in command.qualified_name
if current.casefold() in command.qualified_name.casefold()
][:25]


Expand Down
6 changes: 3 additions & 3 deletions exts/music/music.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from discord.ext import commands

import core
from core.wave import AnyTrack, AnyTrackIterable, SkippablePlayer
from core.wave import AnyTrack, AnyTrackIterator, SkippablePlayer

from .utils import MusicQueueView, WavelinkSearchConverter, format_track_embed, generate_tracks_add_notification

Expand Down Expand Up @@ -139,7 +139,7 @@ async def play(
self,
ctx: core.GuildContext,
*,
search: app_commands.Transform[AnyTrack | AnyTrackIterable, WavelinkSearchConverter],
search: app_commands.Transform[AnyTrack | AnyTrackIterator, WavelinkSearchConverter],
) -> None:
"""Play audio from a YouTube url or search term.
Expand Down Expand Up @@ -196,7 +196,7 @@ async def stop(self, ctx: core.GuildContext) -> None:
"""Stop playback and disconnect the bot from voice."""

if vc := ctx.voice_client:
await vc.disconnect() # type: ignore # Incomplete wavelink typing
await vc.disconnect()
await ctx.send("Disconnected from voice channel.")
else:
await ctx.send("No player to perform this on.")
Expand Down
22 changes: 14 additions & 8 deletions exts/music/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from wavelink.ext import spotify

from core.utils import EMOJI_STOCK, PaginatedEmbedView
from core.wave import AnyTrack, AnyTrackIterable
from core.wave import AnyTrack, AnyTrackIterator


escape_markdown = functools.partial(discord.utils.escape_markdown, as_needed=True)
Expand All @@ -42,7 +42,7 @@ def format_page(self) -> discord.Embed:
return embed_page


class WavelinkSearchConverter(commands.Converter[AnyTrack | AnyTrackIterable], discord.app_commands.Transformer):
class WavelinkSearchConverter(commands.Converter[AnyTrack | AnyTrackIterator], discord.app_commands.Transformer):
"""Converts to what Wavelink considers a playable track (:class:`AnyPlayable` or :class:`AnyTrackIterable`).
The lookup strategy is as follows (in order):
Expand Down Expand Up @@ -81,7 +81,7 @@ def _get_search_type(argument: str) -> type[AnyTrack]:

return search_type

async def _convert(self, argument: str) -> AnyTrack | AnyTrackIterable:
async def _convert(self, argument: str) -> AnyTrack | AnyTrackIterator:
"""Attempt to convert a string into a Wavelink track or list of tracks."""

search_type = self._get_search_type(argument)
Expand All @@ -104,10 +104,10 @@ async def _convert(self, argument: str) -> AnyTrack | AnyTrackIterable:
return tracks

# Who needs narrowing anyway?
async def convert(self, ctx: commands.Context[Any], argument: str) -> AnyTrack | AnyTrackIterable:
async def convert(self, ctx: commands.Context[Any], argument: str) -> AnyTrack | AnyTrackIterator:
return await self._convert(argument)

async def transform(self, _: discord.Interaction, value: str, /) -> AnyTrack | AnyTrackIterable:
async def transform(self, _: discord.Interaction, value: str, /) -> AnyTrack | AnyTrackIterator:
return await self._convert(value)

async def autocomplete( # type: ignore # Narrowing the types of the input value and return value, I guess.
Expand All @@ -117,7 +117,13 @@ async def autocomplete( # type: ignore # Narrowing the types of the input value
) -> list[discord.app_commands.Choice[str]]:
search_type = self._get_search_type(value)
tracks = await search_type.search(value)
return [discord.app_commands.Choice(name=track.title, value=track.uri or track.title) for track in tracks][:25]
if isinstance(tracks, list):
choices = [
discord.app_commands.Choice(name=track.title, value=track.uri or track.title) for track in tracks
][:25]
else:
choices = [discord.app_commands.Choice(name=tracks.name, value=tracks.uri or tracks.name)]
return choices


async def format_track_embed(title: str, track: AnyTrack) -> discord.Embed:
Expand Down Expand Up @@ -154,7 +160,7 @@ async def format_track_embed(title: str, track: AnyTrack) -> discord.Embed:
return embed


async def generate_tracks_add_notification(tracks: AnyTrack | AnyTrackIterable) -> str:
async def generate_tracks_add_notification(tracks: AnyTrack | AnyTrackIterator) -> str:
"""Returns the appropriate notification string for tracks or a collection of tracks being added to a queue."""

if isinstance(tracks, wavelink.YouTubePlaylist | wavelink.SoundCloudPlaylist):
Expand All @@ -164,6 +170,6 @@ async def generate_tracks_add_notification(tracks: AnyTrack | AnyTrackIterable)
if isinstance(tracks, list):
return f"Added `{tracks[0].title}` to the queue."
if isinstance(tracks, spotify.SpotifyAsyncIterator):
return f"Added `{tracks._count}` tracks to the queue." # type: ignore # This avoids iterating through it again.
return f"Added `{tracks._count}` tracks to the queue." # This avoids iterating through it again.

return f"Added `{tracks.title}` to the queue."
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ target-version = ["py310"]

[tool.ruff]
# Credit to @mikeshardmind for most of this setup.
include = ["main.py", "core/*.py", "exts/*.pyi", "**/pyproject.toml"]
include = ["main.py", "core/*.py", "exts/*.py", "**/pyproject.toml"]
line-length = 120
target-version = "py310"
select = [
Expand Down Expand Up @@ -85,9 +85,9 @@ combine-as-imports = true
include = ["main.py", "core", "exts"]
pythonVersion = "3.10"
typeCheckingMode = "strict"
useLibraryCodeForTypes = true
# useLibraryCodeForTypes = true

reportMissingTypeStubs = "none"
reportImportCycles = "warning"
# reportImportCycles = "warning"
reportPropertyTypeMismatch = "warning"
reportUnnecessaryTypeIgnoreComment = "warning"
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ Pillow>=10.0.0
tatsu @ git+https://github.com/Sachaa-Thanasius/Tatsu.git
typing_extensions>=4.5.0,<5
wavelink>=2.6.1,<3
wavelink-stubs
yarl>=1.8.2,<2

0 comments on commit 2558d98

Please sign in to comment.