Skip to content

Commit

Permalink
Massive refactor with focus on pagiation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sachaa-Thanasius committed Sep 19, 2023
1 parent 31f67e5 commit c6262cf
Show file tree
Hide file tree
Showing 40 changed files with 1,290 additions and 1,442 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/coverage_and_lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: Type Coverage and Linting

on:
push:
pull_request:
types: [opened, reopened, synchronize]

jobs:
check:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]

name: "Type Coverage and Linting @ ${{ matrix.python-version }}"
steps:
- name: "Checkout Repository"
uses: actions/checkout@v3
with:
fetch-depth: 0

- name: "Setup CPython @ ${{ matrix.python-version }}"
id: setup-python
uses: actions/setup-python@v4
with:
python-version: "${{ matrix.python-version }}"

- name: "Install Python deps @ ${{ matrix.python-version }}"
id: install-deps
run: |
python -m pip install --upgrade pip setuptools wheel
python -m pip install -U -r requirements.txt
- name: Setup node.js
uses: actions/setup-node@v3
with:
node-version: '16'

- name: "Run Pyright @ ${{ matrix.python-version }}"
uses: jakebailey/pyright-action@v1

- name: Run Ruff
uses: chartboost/ruff-action@v1

- name: Run Black
uses: psf/black@stable
8 changes: 2 additions & 6 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ fabric.properties
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser


# VisualStudioCode.gitignore
.vscode/*
!.vscode/settings.json
Expand All @@ -255,14 +254,11 @@ fabric.properties
# Ruff
.ruff_cache

# Bot data
# Project specific
data/dunk/
logs/

# Bot config
config.json
database/table_insertions.sql

# Bot misc
database/migration_script.py
misc/
test_bot.py
4 changes: 2 additions & 2 deletions core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__version__ = "0.0.1"

from . import utils as utils, wave as wave
from .bot import *
from . import tree as tree, utils as utils, wave as wave
from .bot import Beira as Beira
from .checks import *
from .config import *
from .context import *
Expand Down
43 changes: 10 additions & 33 deletions core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import aiohttp
import asyncpg
import discord
import wavelink
from discord.ext import commands, tasks
from wavelink.ext import spotify

from exts import EXTENSIONS

Expand All @@ -21,8 +23,6 @@
from .context import Context


__all__ = ("Beira",)

LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -83,11 +83,10 @@ async def on_ready(self) -> None:
LOGGER.info(f"Logged in as {self.user} (ID: {self.user.id})")

async def setup_hook(self) -> None:
"""Loads variables from the database and local files before the bot connects to the Discord Gateway."""

await self._load_guild_prefixes()
await self._load_blocked_entities()
await self._load_extensions()
await self._connect_lavalink_nodes()
self.app_info = await self.application_info()
self.owner_id = self.app_info.owner.id
self.loop.create_task(self._load_special_friends())
Expand Down Expand Up @@ -149,6 +148,11 @@ async def _load_guild_prefixes(self, guild_id: int | None = None) -> None:
except OSError:
LOGGER.error("Couldn't load guild prefixes from the database. Ignoring for sake of defaults.")

async def _connect_lavalink_nodes(self) -> None:
sc = spotify.SpotifyClient(**self.config["spotify"])
node = wavelink.Node(**self.config["lavalink"])
await wavelink.NodePool.connect(client=self, nodes=[node], spotify=sc)

async def _load_extensions(self) -> None:
"""Loads extensions/cogs.
Expand Down Expand Up @@ -190,42 +194,15 @@ async def set_custom_presence_before(self) -> None:
await self.wait_until_ready()

def is_special_friend(self, user: discord.abc.User, /) -> bool:
"""Checks if a :class:`discord.User` or :class:`discord.Member` is a "special friend" of
this bot's owner.
If a :attr:`special_friends` dict is not set, it is assumed to be on purpose.
Parameters
-----------
user : :class:`discord.abc.User`
The user to check for.
Returns
--------
:class:`bool`
Whether the user is a special friend of the owner.
"""
"""Checks if a :class:`discord.User` or :class:`discord.Member` is a "special friend" of this bot's owner."""

if len(self.special_friends) > 0:
return user.id in self.special_friends.values()

return False

def is_ali(self, user: discord.abc.User, /) -> bool:
"""Checks if a :class:`discord.User` or :class:`discord.Member` is Ali.
If a :attr:`special_friends` dict is not set, it is assumed to be on purpose.
Parameters
-----------
user : :class:`discord.abc.User`
The user to check for.
Returns
--------
:class:`bool`
Whether the user is Ali.
"""
"""Checks if a :class:`discord.User` or :class:`discord.Member` is Ali."""

if len(self.special_friends) > 0:
return user.id == self.special_friends["aeroali"]
Expand Down
77 changes: 46 additions & 31 deletions core/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,29 @@
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Protocol, TypeVar

import discord
from discord import app_commands
from discord.ext import commands
from discord.utils import maybe_coroutine

from .errors import GuildIsBlocked, NotAdmin, NotInBotVoiceChannel, NotOwnerOrFriend, UserIsBlocked
from .errors import CheckAnyFailure, GuildIsBlocked, NotAdmin, NotInBotVoiceChannel, NotOwnerOrFriend, UserIsBlocked


if TYPE_CHECKING:
from discord.app_commands.commands import Check as app_Check
from discord.ext.commands._types import Check # type: ignore # For the sake of type-checking?
from discord.app_commands.commands import Check as AppCheckFunc
from discord.ext.commands._types import Check

from core import Context, GuildContext
from .context import Context, GuildContext

T = TypeVar("T")


class AppCheck(Protocol):
predicate: AppCheckFunc

def __call__(self, coro_or_commands: T) -> T:
...


__all__ = ("is_owner_or_friend", "is_admin", "in_bot_vc", "in_aci100_guild", "is_blocked", "check_any")
Expand All @@ -37,8 +45,7 @@ def is_owner_or_friend() -> Check[Any]:

async def predicate(ctx: Context) -> bool:
if not (await ctx.bot.is_owner(ctx.author) or ctx.bot.is_special_friend(ctx.author)):
msg = "You do not own this bot, nor are you a friend of the owner."
raise NotOwnerOrFriend(msg)
raise NotOwnerOrFriend
return True

return commands.check(predicate)
Expand All @@ -53,11 +60,8 @@ def is_admin() -> Check[Any]:
"""

async def predicate(ctx: GuildContext) -> bool:
assert ctx.guild is not None

if not ctx.author.guild_permissions.administrator:
msg = "Only someone with administrator permissions can do this."
raise NotAdmin(msg)
raise NotAdmin
return True

return commands.check(predicate)
Expand All @@ -72,14 +76,13 @@ def in_bot_vc() -> Check[Any]:
"""

async def predicate(ctx: GuildContext) -> bool:
vc: discord.VoiceProtocol | None = ctx.voice_client
vc = ctx.voice_client

if not (
ctx.author.guild_permissions.administrator
or (vc and ctx.author.voice and (ctx.author.voice.channel == vc.channel))
):
msg = "You are not connected to the same voice channel as the bot."
raise NotInBotVoiceChannel(msg)
raise NotInBotVoiceChannel
return True

return commands.check(predicate)
Expand Down Expand Up @@ -110,42 +113,54 @@ def is_blocked() -> Check[Any]:
async def predicate(ctx: Context) -> bool:
if not (await ctx.bot.is_owner(ctx.author)):
if ctx.author.id in ctx.bot.blocked_entities_cache["users"]:
msg = "This user is prohibited from using bot commands."
raise UserIsBlocked(msg)
raise UserIsBlocked
if ctx.guild and (ctx.guild.id in ctx.bot.blocked_entities_cache["guilds"]):
msg = "This guild is prohibited from using bot commands."
raise GuildIsBlocked(msg)
raise GuildIsBlocked
return True

return commands.check(predicate)


def check_any(*checks: app_Check) -> Callable[..., Any]:
"""An attempt at making a :func:`check_any` decorator for application commands.
# TODO: Actually check if this works.
def check_any(*checks: AppCheck) -> Callable[[T], T]:
"""An attempt at making a `check_any` decorator for application commands that checks if any of the checks passed
will pass, i.e. using logical OR.
If all checks fail then :exc:`CheckAnyFailure` is raised to signal the failure.
It inherits from :exc:`app_commands.CheckFailure`.
Parameters
----------
checks: :class:`app_Check`
checks : :class:`AppCheckProtocol`
An argument list of checks that have been decorated with :func:`app_commands.check` decorator.
Returns
-------
:class:`app_Check`
A predicate that condenses all given checks with logical OR.
Raises
------
TypeError
A check passed has not been decorated with the :func:`app_commands.check` decorator.
"""

# TODO: Actually check if this works.
unwrapped: list[AppCheckFunc] = []
for wrapped in checks:
try:
pred = wrapped.predicate
except AttributeError:
msg = f"{wrapped!r} must be wrapped by app_commands.check decorator"
raise TypeError(msg) from None
else:
unwrapped.append(pred)

async def predicate(interaction: discord.Interaction) -> bool:
errors: list[Exception] = []
for check in checks:
errors: list[app_commands.CheckFailure] = []
for func in unwrapped:
try:
value = await maybe_coroutine(check, interaction)
value = await discord.utils.maybe_coroutine(func, interaction)
except app_commands.CheckFailure as err:
errors.append(err)
else:
if value:
return True
# If we're here, all checks failed.
raise app_commands.CheckFailure(checks, errors)
raise CheckAnyFailure(unwrapped, errors)

return app_commands.check(predicate)
7 changes: 3 additions & 4 deletions core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@

from typing import TYPE_CHECKING, Any, TypeAlias

import aiohttp
import asyncpg
import discord
from aiohttp import ClientSession
from discord.ext import commands

from .wave import SkippablePlayer


if TYPE_CHECKING:
from .bot import Beira
from .wave import SkippablePlayer


__all__ = ("Context", "GuildContext", "Interaction")
Expand All @@ -39,7 +38,7 @@ def __init__(self, **kwargs: Any) -> None:
self.error_handled = False

@property
def session(self) -> ClientSession:
def session(self) -> aiohttp.ClientSession:
""":class:`ClientSession`: Returns the asynchronous HTTP session used by the bot for HTTP requests."""

return self.bot.web_session
Expand Down
Loading

0 comments on commit c6262cf

Please sign in to comment.