Skip to content

Commit

Permalink
Update to wavelink v3.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sachaa-Thanasius committed Nov 28, 2023
1 parent 36c6a3b commit 224d24b
Show file tree
Hide file tree
Showing 9 changed files with 365 additions and 271 deletions.
8 changes: 3 additions & 5 deletions core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import openai
import wavelink
from discord.ext import commands
from wavelink.ext import spotify # type: ignore [reportMissingTypeStubs]

from exts import EXTENSIONS

Expand Down Expand Up @@ -98,9 +97,8 @@ async def setup_hook(self) -> None:
await self._load_extensions()

# Connection lavalink nodes.
sc = spotify.SpotifyClient(**CONFIG.spotify.to_dict())
node = wavelink.Node(**CONFIG.lavalink.to_dict())
await wavelink.NodePool.connect(client=self, nodes=[node], spotify=sc)
node = wavelink.Node(uri=CONFIG.lavalink.uri, password=CONFIG.lavalink.password)
await wavelink.Pool.connect(client=self, nodes=[node])

Check failure on line 101 in core/bot.py

View workflow job for this annotation

GitHub Actions / Type Coverage and Linting @ 3.10

"Pool" is not a known member of module "wavelink" (reportGeneralTypeIssues)

Check failure on line 101 in core/bot.py

View workflow job for this annotation

GitHub Actions / Type Coverage and Linting @ 3.10

Type of "Pool" is unknown (reportUnknownMemberType)

Check failure on line 101 in core/bot.py

View workflow job for this annotation

GitHub Actions / Type Coverage and Linting @ 3.10

Type of "connect" is unknown (reportUnknownMemberType)

# Get information about owner.
self.app_info = await self.application_info()
Expand Down Expand Up @@ -148,7 +146,7 @@ async def on_error(self, event_method: str, /, *args: object, **kwargs: object)
)
LOGGER.error("Exception in event %s", event_method, exc_info=exception, extra={"embed": embed})

async def on_command_error(self, context: Context, exception: commands.CommandError) -> None: # type: ignore
async def on_command_error(self, context: Context, exception: commands.CommandError) -> None: # type: ignore # Narrowing
assert context.command # Pre-condition for being here.

if isinstance(exception, commands.CommandNotFound):
Expand Down
4 changes: 2 additions & 2 deletions core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

if TYPE_CHECKING:
from .bot import Beira
from .wave import SkippablePlayer
from .wave import ExtraPlayer


__all__ = ("Context", "GuildContext", "Interaction")
Expand All @@ -32,7 +32,7 @@ class Context(commands.Context["Beira"]):
db
"""

voice_client: SkippablePlayer | None # type: ignore # Type lie for narrowing
voice_client: ExtraPlayer | None # type: ignore # Type lie for narrowing

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand Down
35 changes: 34 additions & 1 deletion core/tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import logging
import traceback
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, ParamSpec, TypeAlias

Expand All @@ -15,7 +17,6 @@
from typing_extensions import TypeVar

ClientT_co = TypeVar("ClientT_co", bound=Client, covariant=True, default=Client)

else:
from typing import TypeVar

Expand All @@ -28,6 +29,8 @@
GroupT = TypeVar("GroupT", bound=Group | commands.Cog)
AppHook: TypeAlias = Callable[[GroupT, Interaction[Any]], Coro[Any]] | Callable[[Interaction[Any]], Coro[Any]]

LOGGER = logging.getLogger(__name__)

__all__ = ("before_app_invoke", "after_app_invoke", "HookableTree")


Expand Down Expand Up @@ -98,6 +101,36 @@ def decorator(inner: Command[GroupT, P, T]) -> Command[GroupT, P, T]:


class HookableTree(CommandTree):
async def on_error(self, interaction: Interaction[Client], error: AppCommandError, /) -> None:
command = interaction.command

error = getattr(error, "original", error)

tb_text = "".join(traceback.format_exception(type(error), error, error.__traceback__, chain=False))
embed = discord.Embed(
title="App Command Error",
description=f"```py\n{tb_text}\n```",
colour=discord.Colour.dark_magenta(),
timestamp=discord.utils.utcnow(),
).set_author(name=str(interaction.user.global_name), icon_url=interaction.user.display_avatar.url)

if command is not None:
embed.add_field(name="Name", value=command.qualified_name, inline=False)

if interaction.namespace:
embed.add_field(
name="Args",
value="```py\n" + "\n".join(f"{name}: {arg!r}" for name, arg in iter(interaction.namespace)) + "\n```",
inline=False,
)
embed.add_field(name="Guild", value=f"{interaction.guild.name if interaction.guild else '-----'}", inline=False)
embed.add_field(name="Channel", value=f"{interaction.channel}", inline=False)

if command is not None:
LOGGER.error("Exception in command %r", command.name, exc_info=error, extra={"embed": embed})
else:
LOGGER.error("Exception in command tree", exc_info=error, extra={"embed": embed})

async def _call(self, interaction: Interaction[ClientT_co]) -> None:
###### Copy the original logic but add hook checks/calls near the end.

Expand Down
94 changes: 91 additions & 3 deletions core/utils/app_help_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,25 @@
from discord._types import ClientT


class SampleTree(app_commands.CommandTree):
def get_nested_command(
self,
name: str,
*,
guild: discord.abc.Snowflake | None = None,
) -> app_commands.Command[Any, ..., Any] | app_commands.Group | None:
...


class SampleClient(discord.Client):
tree: SampleTree


def get_nested_command(
tree: app_commands.CommandTree[ClientT],
name: str,
*,
guild: discord.Guild | None,
guild: discord.Guild | None = None,
) -> app_commands.Command[Any, ..., Any] | app_commands.Group | None:
key, *keys = name.split(" ")
cmd = tree.get_command(key, guild=guild) or tree.get_command(key)
Expand Down Expand Up @@ -51,9 +65,9 @@ async def _help(itx: discord.Interaction[ClientT], command: str) -> None:


@_help.autocomplete("command")
async def help_autocomplete(itx: discord.Interaction[ClientT], current: str) -> list[app_commands.Choice[str]]:
async def help_autocomplete(itx: discord.Interaction[SampleClient], current: str) -> list[app_commands.Choice[str]]:
# Known to exist at runtime, else autocomplete would not trigger.
tree: app_commands.CommandTree = getattr(itx.client, "tree") # noqa: B009
tree = itx.client.tree

commands = list(tree.walk_commands(guild=None, type=discord.AppCommandType.chat_input))

Expand All @@ -69,3 +83,77 @@ async def help_autocomplete(itx: discord.Interaction[ClientT], current: str) ->
# Only show unique commands
choices = sorted(set(choices), key=lambda c: c.name)
return choices[:25]


class CommandTransformer(app_commands.Transformer):
async def autocomplete( # type: ignore # Narrowing interaction and choice
self,
itx: discord.Interaction[SampleClient],
current: str,
/,
) -> list[app_commands.Choice[str]]:
# Known to exist at runtime, else autocomplete would not trigger.
tree = itx.client.tree

return [
app_commands.Choice(name=command.qualified_name, value=command.qualified_name)
for command in tree.walk_commands()
if command.qualified_name.casefold() in current.casefold()
][:25]

async def transform( # type: ignore # Narrowing interaction
self,
itx: discord.Interaction[SampleClient],
value: str,
/,
) -> app_commands.Command[Any, ..., Any] | app_commands.Group:
# Known to exist at runtime, else transform would never be invoked.
tree = itx.client.tree
command = tree.get_command(value)
if command is None:
msg = f"Command {value} not found."
raise ValueError(msg)

return command


class CommandTransformer2(app_commands.Transformer):
async def autocomplete( # type: ignore # Narrowing interaction and choice.
self,
itx: discord.Interaction[SampleClient],
current: str,
/,
) -> list[app_commands.Choice[str]]:
commands = list(itx.client.tree.walk_commands(guild=None, type=discord.AppCommandType.chat_input))

if itx.guild is not None:
commands.extend(itx.client.tree.walk_commands(guild=itx.guild, type=discord.AppCommandType.chat_input))

choices = [
app_commands.Choice(name=name, value=name)
for cmd in commands
if current.casefold() in (name := cmd.qualified_name.casefold())
]

# Only show unique commands
choices = sorted(set(choices), key=lambda c: c.name)
return choices[:25]

async def transform( # type: ignore # Narrowing interaction.
self,
itx: discord.Interaction[SampleClient],
value: str,
/,
) -> app_commands.Command[Any, ..., Any] | app_commands.Group:
command = itx.client.tree.get_nested_command(value)
if command is None:
msg = f"Command {value} not found."
raise ValueError(msg)

return command


CommandTransform2 = app_commands.Transform[
app_commands.Command[Any, ..., Any] | app_commands.Group | None,
CommandTransformer2,
]
8 changes: 4 additions & 4 deletions core/utils/emojis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
"fof": "<:FoF:856969711241396254>",
"pop": "<:PoP:856969710486814730>",
"mr_jare": "<:Mr_Jare:1061029880059400262>",
"YouTubeTrack": "<:youtube:1108460195270631537>",
"YouTubeMusicTrack": "<:youtubemusic:954046930713985074>",
"SoundCloudTrack": "<:soundcloud:1147265178505846804>",
"SpotifyTrack": "<:spotify:1108458132826501140>",
"youtube": "<:youtube:1108460195270631537>",
"youtubemusic": "<:youtubemusic:954046930713985074>",
"soundcloud": "<:soundcloud:1147265178505846804>",
"spotify": "<:spotify:1108458132826501140>",
"d04": "<a:d04:1109234548727885884>",
"d06": "<a:d06:1109234547389907017>",
"d08": "<a:d08:1109234533041197196>",
Expand Down
120 changes: 59 additions & 61 deletions core/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,80 +4,78 @@

from __future__ import annotations

from collections.abc import AsyncIterable, Iterable
from typing import TypeAlias

import discord
import wavelink
from wavelink.ext import spotify # type: ignore [reportMissingTypeStubs]


__all__ = ("AnyTrack", "AnyTrackIterator", "AnyTrackIterable", "SkippableQueue", "SkippablePlayer")

AnyTrack: TypeAlias = wavelink.Playable | spotify.SpotifyTrack
AnyTrackIterator: TypeAlias = list[wavelink.Playable] | list[spotify.SpotifyTrack] | spotify.SpotifyAsyncIterator
AnyTrackIterable: TypeAlias = (
Iterable[wavelink.Playable] | Iterable[spotify.SpotifyTrack] | AsyncIterable[spotify.SpotifyTrack]
)


class SkippableQueue(wavelink.Queue):
"""A version of :class:`wavelink.Queue` that can skip to a specific index."""

def remove_before_index(self, index: int) -> None:
"""Remove all members from the queue before a certain index.
Credit to Chillymosh for the implementation.
"""

for _ in range(index):
try:
del self[0]
except IndexError:
break

async def put_all_wait(self, item: AnyTrack | AnyTrackIterable, requester: str | None = None) -> None:
"""Put items individually or from an iterable into the queue asynchronously using await.
This can include some playlist subclasses.
Parameters
----------
item: :class:`AnyPlayable` | :class:`AnyTrackIterable`
The track or collection of tracks to add to the queue.
requester: :class:`str`, optional
A string representing the user who queued this up. Optional.
"""

if isinstance(item, Iterable):
for sub_item in item:
sub_item.requester = requester # type: ignore # Runtime attribute assignment.
await self.put_wait(sub_item)
elif isinstance(item, AsyncIterable):
async for sub_item in item:
sub_item.requester = requester # type: ignore # Runtime attribute assignment.
await self.put_wait(sub_item)
else:
item.requester = requester # type: ignore # Runtime attribute assignment.
await self.put_wait(item)


class SkippablePlayer(wavelink.Player):
__all__ = ("ExtraQueue", "ExtraPlayer")


class ExtraQueue(wavelink.Queue):
"""A version of :class:`wavelink.Queue` with extra operations."""

def put_at(self, index: int, item: wavelink.Playable, /) -> None:
if index >= len(self._queue) or index < 0:
msg = "The index is out of range."
raise IndexError(msg)
self._queue.rotate(-index)
self._queue.appendleft(item)
self._queue.rotate(index)

def skip_to(self, index: int, /) -> None:
if index >= len(self._queue) or index < 0:
msg = "The index is out of range."
raise IndexError(msg)
for _ in range(index - 1):
self.get()

def swap(self, first: int, second: int, /) -> None:
if first >= len(self._queue) or second >= len(self._queue):
msg = "One of the given indices is out of range."
raise IndexError(msg)
if first == second:
msg = "These are the same index; swapping will have no effect."
raise IndexError(msg)
self._queue.rotate(-first)
first_item = self._queue[0]
self._queue.rotate(first - second)
second_item = self._queue.popleft()
self._queue.appendleft(first_item)
self._queue.rotate(second - first)
self._queue.popleft()
self._queue.appendleft(second_item)
self._queue.rotate(first)

def move(self, before: int, after: int, /) -> None:
if before >= len(self._queue) or after >= len(self._queue):
msg = "One of the given indices is out of range."
raise IndexError(msg)
if before == after:
msg = "These are the same index; swapping will have no effect."
raise IndexError(msg)
self._queue.rotate(-before)
item = self._queue.popleft()
self._queue.rotate(before - after)
self._queue.appendleft(item)
self._queue.rotate(after)


class ExtraPlayer(wavelink.Player):
"""A version of :class:`wavelink.Player` with a different queue.
Attributes
----------
queue: :class:`SkippableQueue`
A subclass of :class:`wavelink.Queue` that can be skipped into.
queue: :class:`ExtraQueue`
A version of :class:`wavelink.Queue` with extra operations.
"""

def __init__(
self,
client: discord.Client = discord.utils.MISSING,
channel: discord.VoiceChannel | discord.StageChannel = discord.utils.MISSING,
channel: discord.abc.Connectable = discord.utils.MISSING,
*,
nodes: list[wavelink.Node] | None = None,
swap_node_on_disconnect: bool = True,
) -> None:
super().__init__(client, channel, nodes=nodes, swap_node_on_disconnect=swap_node_on_disconnect)
self.queue: SkippableQueue = SkippableQueue() # type: ignore [reportIncompatibleVariableOverride]
super().__init__(client, channel, nodes=nodes)

Check failure on line 79 in core/wave.py

View workflow job for this annotation

GitHub Actions / Type Coverage and Linting @ 3.10

Argument of type "Connectable" cannot be assigned to parameter "channel" of type "_VoiceChannel" in function "__init__"   Type "Connectable" cannot be assigned to type "_VoiceChannel"     "Connectable" is incompatible with "VoiceChannel"     "Connectable" is incompatible with "StageChannel" (reportGeneralTypeIssues)
self.autoplay = wavelink.AutoPlayMode.partial

Check failure on line 80 in core/wave.py

View workflow job for this annotation

GitHub Actions / Type Coverage and Linting @ 3.10

"AutoPlayMode" is not a known member of module "wavelink" (reportGeneralTypeIssues)

Check failure on line 80 in core/wave.py

View workflow job for this annotation

GitHub Actions / Type Coverage and Linting @ 3.10

Type of "AutoPlayMode" is unknown (reportUnknownMemberType)

Check failure on line 80 in core/wave.py

View workflow job for this annotation

GitHub Actions / Type Coverage and Linting @ 3.10

Type of "partial" is unknown (reportUnknownMemberType)
self.queue: ExtraQueue = ExtraQueue() # type: ignore [reportIncompatibleVariableOverride]
Loading

0 comments on commit 224d24b

Please sign in to comment.