diff --git a/extensions/automod.py b/extensions/automod.py index 6b3cacf..6731835 100644 --- a/extensions/automod.py +++ b/extensions/automod.py @@ -16,7 +16,7 @@ from models.events import AutoModMessageFlagEvent from models.plugin import SnedPlugin from utils import helpers -from utils.ratelimiter import BucketType +from utils.ratelimiter import MemberBucket INVITE_REGEX = re.compile(r"(?:https?://)?discord(?:app)?\.(?:com/invite|gg)/[a-zA-Z0-9]+/?") """Used to detect and handle Discord invites.""" @@ -25,12 +25,12 @@ DISCORD_FORMATTING_REGEX = re.compile(r"<\S+>") """Remove Discord-specific formatting. Performance is key so some false-positives are acceptable here.""" -SPAM_RATELIMITER = utils.RateLimiter(10, 8, bucket=BucketType.MEMBER, wait=False) -PUNISH_RATELIMITER = utils.RateLimiter(30, 1, bucket=BucketType.MEMBER, wait=False) -ATTACH_SPAM_RATELIMITER = utils.RateLimiter(30, 2, bucket=BucketType.MEMBER, wait=False) -LINK_SPAM_RATELIMITER = utils.RateLimiter(30, 2, bucket=BucketType.MEMBER, wait=False) -ESCALATE_PREWARN_RATELIMITER = utils.RateLimiter(30, 1, bucket=BucketType.MEMBER, wait=False) -ESCALATE_RATELIMITER = utils.RateLimiter(30, 1, bucket=BucketType.MEMBER, wait=False) +SPAM_RATELIMITER = utils.RateLimiter(10, 8, bucket=MemberBucket, wait=False) +PUNISH_RATELIMITER = utils.RateLimiter(30, 1, bucket=MemberBucket, wait=False) +ATTACH_SPAM_RATELIMITER = utils.RateLimiter(30, 2, bucket=MemberBucket, wait=False) +LINK_SPAM_RATELIMITER = utils.RateLimiter(30, 2, bucket=MemberBucket, wait=False) +ESCALATE_PREWARN_RATELIMITER = utils.RateLimiter(30, 1, bucket=MemberBucket, wait=False) +ESCALATE_RATELIMITER = utils.RateLimiter(30, 1, bucket=MemberBucket, wait=False) logger = logging.getLogger(__name__) diff --git a/extensions/fun.py b/extensions/fun.py index db18466..da4337f 100644 --- a/extensions/fun.py +++ b/extensions/fun.py @@ -23,8 +23,9 @@ from models.context import SnedContext, SnedUserContext from models.plugin import SnedPlugin from models.views import AuthorOnlyNavigator, AuthorOnlyView -from utils import BucketType, RateLimiter, helpers +from utils import GlobalBucket, RateLimiter, helpers from utils.dictionaryapi import DictionaryClient, DictionaryEntry, DictionaryException, UrbanEntry +from utils.ratelimiter import UserBucket from utils.rpn import InvalidExpressionError, Solver ANIMAL_EMOJI_MAPPING: dict[str, str] = { @@ -37,7 +38,10 @@ "racoon": "🦝", } -animal_ratelimiter = RateLimiter(60, 45, BucketType.GLOBAL, wait=False) +ANIMAL_RATELIMITER = RateLimiter(60, 45, GlobalBucket, wait=False) +COMF_LIMITER = RateLimiter(60, 5, UserBucket, wait=False) +VESZTETTEM_LIMITER = RateLimiter(1800, 1, GlobalBucket, wait=False) +COMF_PROGRESS_BAR_WIDTH = 20 logger = logging.getLogger(__name__) @@ -903,8 +907,8 @@ async def on_dice_reroll(event: miru.ComponentInteractionCreateEvent) -> None: @lightbulb.command("animal", "Shows a random picture of the selected animal.", pass_options=True) @lightbulb.implements(lightbulb.SlashCommand) async def animal(ctx: SnedSlashContext, animal: str) -> None: - await animal_ratelimiter.acquire(ctx) - if animal_ratelimiter.is_rate_limited(ctx): + await ANIMAL_RATELIMITER.acquire(ctx) + if ANIMAL_RATELIMITER.is_rate_limited(ctx): await ctx.respond( embed=hikari.Embed( title="❌ Ratelimited", @@ -983,27 +987,20 @@ async def wiki(ctx: SnedSlashContext, query: str) -> None: await ctx.respond(embed=embed) -vesztettem_limiter = RateLimiter(1800, 1, BucketType.GLOBAL, wait=False) - - @fun.listener(hikari.GuildMessageCreateEvent) async def lose_autoresponse(event: hikari.GuildMessageCreateEvent) -> None: if event.guild_id not in (Config().DEBUG_GUILDS or (1012448659029381190,)) or not event.is_human: return if event.content and "vesztettem" in event.content.lower(): - await vesztettem_limiter.acquire(event.message) + await VESZTETTEM_LIMITER.acquire(event.message) - if vesztettem_limiter.is_rate_limited(event.message): + if VESZTETTEM_LIMITER.is_rate_limited(event.message): return await event.message.respond("Vesztettem") -comf_ratelimiter = RateLimiter(60, 5, BucketType.USER, wait=False) -COMF_PROGRESS_BAR_WIDTH = 20 - - @fun.command @lightbulb.app_command_permissions(None, dm_enabled=False) @lightbulb.command("comf", "Shows your current and upcoming comfiness.") @@ -1011,8 +1008,8 @@ async def lose_autoresponse(event: hikari.GuildMessageCreateEvent) -> None: async def comf(ctx: SnedSlashContext) -> None: assert ctx.member is not None - await comf_ratelimiter.acquire(ctx) - if comf_ratelimiter.is_rate_limited(ctx): + await COMF_LIMITER.acquire(ctx) + if COMF_LIMITER.is_rate_limited(ctx): await ctx.respond( embed=hikari.Embed( title="❌ Ratelimited", diff --git a/extensions/role_buttons.py b/extensions/role_buttons.py index 0a3f36c..ece9375 100644 --- a/extensions/role_buttons.py +++ b/extensions/role_buttons.py @@ -11,7 +11,7 @@ from models.plugin import SnedPlugin from models.rolebutton import RoleButton, RoleButtonMode from utils import helpers -from utils.ratelimiter import BucketType, RateLimiter +from utils.ratelimiter import MemberBucket, RateLimiter logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ "Remove": RoleButtonMode.REMOVE_ONLY, } -role_button_ratelimiter = RateLimiter(2, 1, BucketType.MEMBER, wait=False) +role_button_ratelimiter = RateLimiter(2, 1, MemberBucket, wait=False) class RoleButtonConfirmType(enum.Enum): diff --git a/utils/ratelimiter.py b/utils/ratelimiter.py index 77ba662..c45e9ea 100644 --- a/utils/ratelimiter.py +++ b/utils/ratelimiter.py @@ -1,31 +1,97 @@ from __future__ import annotations +import abc import asyncio -import enum import sys import time import traceback import typing as t from collections import deque +import attr import hikari -import lightbulb -import miru -class BucketType(enum.IntEnum): - """All possible ratelimiter bucket types.""" +class ContextLike(t.Protocol): + """An object that has common attributes of a context.""" + + @property + def author(self) -> hikari.UndefinedOr[hikari.User]: + ... + + @property + def guild_id(self) -> hikari.Snowflake | None: + ... + + @property + def channel_id(self) -> hikari.Snowflake: + ... + + +@attr.define() +class BucketData: + """Handles the ratelimiting of a single bucket data. (E.g. a single user or a channel.)""" + + reset_at: float + """The time at which the bucket resets.""" + remaining: int + """The amount of requests remaining in the bucket.""" + bucket: Bucket + """The bucket this data belongs to.""" + queue: t.Deque[asyncio.Event] = attr.field(factory=deque) + """A list of events to set as the iter task proceeds.""" + task: asyncio.Task[t.Any] | None = attr.field(default=None) + """The task that is currently iterating over the queue.""" + + @classmethod + def for_bucket(cls, bucket: Bucket) -> BucketData: + """Create a new BucketData for a Bucket.""" + return cls( + bucket=bucket, + reset_at=time.monotonic() + bucket.period, + remaining=bucket.limit, + ) + + def start_queue(self) -> None: + """Start the queue of a BucketData. + This will start setting events in the queue until the bucket is ratelimited. + """ + if self.task is None: + self.task = asyncio.create_task(self._iter_queue()) + + def reset(self) -> None: + """Reset the ratelimit.""" + self.remaining = self.bucket.limit + self.reset_at = time.monotonic() + self.bucket.period + + async def _iter_queue(self) -> None: + """Iterate over the queue of a BucketData and set events.""" + try: + if self.remaining <= 0 and self.reset_at > time.monotonic(): + # Sleep until ratelimit expires + sleep_time = self.reset_at - time.monotonic() + await asyncio.sleep(sleep_time) + self.reset() + elif self.reset_at <= time.monotonic(): + self.reset() - GLOBAL = 0 - GUILD = 1 - CHANNEL = 2 - USER = 3 - MEMBER = 4 + # Set events while not ratelimited + while self.remaining > 0 and self.queue: + self.remaining -= 1 + self.queue.popleft().set() + self.task = None + + except Exception as e: + print(f"Task Exception was never retrieved: {e}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) -class RateLimiter: - def __init__(self, period: float, limit: int, bucket: BucketType, wait: bool = True) -> None: - """Rate Limiter implementation for Sned + +class Bucket(abc.ABC): + """Abstract class for ratelimiter buckets.""" + + def __init__(self, period: float, limit: int, wait: bool = True) -> None: + """Abstract class for ratelimiter buckets. Parameters ---------- @@ -33,93 +99,126 @@ def __init__(self, period: float, limit: int, bucket: BucketType, wait: bool = T The period, in seconds, after which the quota resets. limit : int The amount of requests allowed in a quota. - bucket : BucketType - The bucket to handle this under. wait : bool Determines if the ratelimiter should wait in case of hitting a ratelimit. """ self.period: float = period self.limit: int = limit - self.bucket: BucketType = bucket - self.wait: bool = False - - self._bucket_data = {} + self.wait: bool = wait + self._bucket_data: t.Dict[str, BucketData] = {} - self._queue: t.Deque[asyncio.Event] = deque() - self._task: asyncio.Task[t.Any] | None = None + @abc.abstractmethod + def get_key(self, ctx: ContextLike) -> str: + """Get key for ratelimiter bucket""" - def _get_key(self, ctx_or_message: lightbulb.Context | miru.Context | hikari.PartialMessage) -> str: - """Get key for cooldown bucket""" - - assert ctx_or_message.member and ctx_or_message.author - - keys = { - BucketType.GLOBAL: 0, - BucketType.GUILD: ctx_or_message.guild_id, - BucketType.CHANNEL: ctx_or_message.channel_id, - BucketType.USER: ctx_or_message.author.id, - BucketType.MEMBER: int(str(ctx_or_message.guild_id) + str(ctx_or_message.member.id)), - } - - return keys[self.bucket] - - def is_rate_limited(self, ctx_or_message: lightbulb.Context | miru.Context | hikari.PartialMessage) -> bool: + def is_rate_limited(self, ctx: ContextLike) -> bool: """Returns a boolean determining if the ratelimiter is ratelimited or not.""" now = time.monotonic() - key = self._get_key(ctx_or_message) - if bucket_item := self._bucket_data.get(key): - if bucket_item["reset_at"] <= now: - bucket_item["remaining"] = self.limit - bucket_item["reset_at"] = now + self.period + if data := self._bucket_data.get(self.get_key(ctx)): + if data.reset_at <= now: return False - return bucket_item["remaining"] <= 0 - - self._bucket_data[key] = {"reset_at": now + self.period, "remaining": self.limit} + return data.remaining <= 0 return False - async def acquire(self, ctx_or_message: lightbulb.Context | miru.Context | hikari.PartialMessage) -> None: + async def acquire(self, ctx: ContextLike) -> None: """Acquire a ratelimit, block execution if ratelimited and wait is True.""" event = asyncio.Event() - self._queue.append(event) - - if self._task is None: - self._task = asyncio.create_task(self._iter_queue(ctx_or_message)) + # Get or insert bucket data + data = self._bucket_data.setdefault(self.get_key(ctx), BucketData.for_bucket(self)) + data.queue.append(event) + data.start_queue() if self.wait: await event.wait() - async def _iter_queue(self, ctx: lightbulb.Context | miru.Context | hikari.PartialMessage) -> None: - try: - if not self._queue: - self._task = None - return + def reset(self, ctx: ContextLike) -> None: + """Reset the ratelimit for a given context.""" + if data := self._bucket_data.get(self.get_key(ctx)): + data.reset() - if self.is_rate_limited(ctx): - # Sleep until ratelimit expires - key = self._get_key(ctx) - bucket_item = self._bucket_data[key] - sleep_time = bucket_item["reset_at"] - time.monotonic() - await asyncio.sleep(sleep_time) - # Set events while not ratelimited - while not self.is_rate_limited(ctx) and self._queue: - key = self._get_key(ctx) +class GlobalBucket(Bucket): + """Ratelimiter bucket for global ratelimits.""" - if bucket_item := self._bucket_data.get(key): - bucket_item["remaining"] -= 1 - else: - self._bucket_data[key] = {"reset_at": time.monotonic() + self.period, "remaining": self.limit - 1} + def get_key(self, _: ContextLike) -> str: + return "amongus" - self._queue.popleft().set() - self._task = None +class GuildBucket(Bucket): + """Ratelimiter bucket for guilds. - except Exception as e: - print(f"Task Exception was never retrieved: {e}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + Note that all ContextLike objects must have a guild_id set. + """ + + def get_key(self, ctx: ContextLike) -> str: + if not ctx.guild_id: + raise KeyError("guild_id is not set.") + return str(ctx.guild_id) + + +class ChannelBucket(Bucket): + """Ratelimiter bucket for channels.""" + + def get_key(self, ctx: ContextLike) -> str: + return str(ctx.channel_id) + + +class UserBucket(Bucket): + """Ratelimiter bucket for users. + + Note that all ContextLike objects must have an author set. + """ + + def get_key(self, ctx: ContextLike) -> str: + if not ctx.author: + raise KeyError("author is not set.") + return str(ctx.author.id) + + +class MemberBucket(Bucket): + """Ratelimiter bucket for members. + + Note that all ContextLike objects must have an author and guild_id set. + """ + + def get_key(self, ctx: ContextLike) -> str: + if not ctx.author or not ctx.guild_id: + raise KeyError("author or guild_id is not set.") + return str(ctx.author.id) + str(ctx.guild_id) + + +class RateLimiter: + def __init__(self, period: float, limit: int, bucket: t.Type[Bucket], wait: bool = True) -> None: + """Rate Limiter implementation for Sned. + + Parameters + ---------- + period : float + The period, in seconds, after which the quota resets. + limit : int + The amount of requests allowed in a quota. + bucket : Bucket + The bucket to handle this under. + wait : bool + Determines if the ratelimiter should wait in + case of hitting a ratelimit. + """ + self.bucket: Bucket = bucket(period, limit, wait) + + def is_rate_limited(self, ctx: ContextLike) -> bool: + """Returns a boolean determining if the ratelimiter is ratelimited or not.""" + return self.bucket.is_rate_limited(ctx) + + async def acquire(self, ctx: ContextLike) -> None: + """Acquire a ratelimit, block execution if ratelimited and wait is True.""" + return await self.bucket.acquire(ctx) + + def reset(self, ctx: ContextLike) -> None: + """Reset the ratelimit for a given context.""" + self.bucket.reset(ctx) # Copyright (C) 2022-present hypergonial