From ea9bed42424a79a55a77fc2c3f0280acbef0d5e6 Mon Sep 17 00:00:00 2001 From: hypergonial <46067571+hypergonial@users.noreply.github.com> Date: Sun, 7 Jan 2024 13:25:05 +0100 Subject: [PATCH] Fix limiters --- arc/utils/hooks/limiters.py | 175 +++++++++++++++++++++++------------- 1 file changed, 112 insertions(+), 63 deletions(-) diff --git a/arc/utils/hooks/limiters.py b/arc/utils/hooks/limiters.py index 5930343..c763aa5 100644 --- a/arc/utils/hooks/limiters.py +++ b/arc/utils/hooks/limiters.py @@ -17,6 +17,7 @@ if t.TYPE_CHECKING: from arc.context.base import Context + __all__ = ( "RateLimiter", "global_limiter", @@ -28,55 +29,81 @@ ) -@attr.define(slots=True) -class _Quota(t.Generic[ClientT]): +@attr.define(slots=True, kw_only=True) +class _Bucket(t.Generic[ClientT]): """Handles the ratelimiting of a single item. (E.g. a single user or a channel).""" + key: str + """The key of the bucket.""" + reset_at: float - """The time at which the quota resets.""" - remaining: int - """The amount of requests remaining until the quota is exhausted.""" - bucket: RateLimiter[ClientT] - """The limiter this quota belongs to.""" - queue: deque[asyncio.Event] = attr.field(factory=deque) + """The time at which the bucket resets.""" + + limiter: RateLimiter[ClientT] + """The limiter this bucket belongs to.""" + + _remaining: int = attr.field(alias="remaining") + """The amount of requests remaining until the bucket is exhausted.""" + + _queue: deque[asyncio.Event] = attr.field(factory=deque, init=False) """A list of events to set as the iter task proceeds.""" - task: asyncio.Task[None] | None = attr.field(default=None) + + _task: asyncio.Task[None] | None = attr.field(default=None, init=False) """The task that is currently iterating over the queue.""" @classmethod - def for_limiter(cls, limiter: RateLimiter[ClientT]) -> _Quota[ClientT]: - """Create a new Quota for a RateLimiter.""" - return cls(bucket=limiter, reset_at=time.monotonic() + limiter.period, remaining=limiter.limit) + def for_limiter(cls, key: str, limiter: RateLimiter[ClientT]) -> _Bucket[ClientT]: + """Create a new bucket for a RateLimiter.""" + return cls(key=key, limiter=limiter, reset_at=time.monotonic() + limiter.period, remaining=limiter.limit) + + @property + def remaining(self) -> int: + """The amount of requests remaining until the bucket is exhausted.""" + if self.reset_at <= time.monotonic(): + self.reset() + return self._remaining + + @remaining.setter + def remaining(self, value: int) -> None: + self._remaining = value + + @property + def is_exhausted(self) -> bool: + """Return a boolean determining if the bucket is exhausted.""" + return self.remaining <= 0 and self.reset_at > time.monotonic() + + @property + def is_stale(self) -> bool: + """Return a boolean determining if the bucket is stale. + If a bucket is stale, it is no longer in use and can be purged. + """ + return not self._queue and self.remaining == self.limiter.limit and (self._task is None or self._task.done()) def start_queue(self) -> None: - """Start the queue of a Quota. + """Start the queue of a bucket. This will start setting events in the queue until the bucket is ratelimited. """ - if self.task is None or self.task.done(): - self.task = asyncio.create_task(self._iter_queue()) + if self._task is None or self._task.done(): + self._task = asyncio.create_task(self._iter_queue()) def reset(self) -> None: - """Reset the ratelimit.""" - self.remaining = self.bucket.limit - self.reset_at = time.monotonic() + self.bucket.period + """Reset the bucket.""" + self.reset_at = time.monotonic() + self.limiter.period + self._remaining = self.limiter.limit async def _iter_queue(self) -> None: """Iterate over the queue and set events until exhausted.""" try: - if self.remaining <= 0 and self.reset_at > time.monotonic(): - # Sleep until ratelimit expires - sleep_time = self.reset_at - time.monotonic() - await asyncio.sleep(sleep_time) - self.reset() - elif self.reset_at <= time.monotonic(): - self.reset() + while self._queue: + if self.remaining <= 0 and self.reset_at > time.monotonic(): + # Sleep until ratelimit expires + await asyncio.sleep(self.reset_at - time.monotonic()) + self.reset() - # Set events while not ratelimited - while self.remaining > 0 and self.queue: - self.remaining -= 1 - self.queue.popleft().set() - - self.task = None + # Set events while not ratelimited + while self.remaining > 0 and self._queue: + self._queue.popleft().set() + self._remaining -= 1 except Exception as e: print(f"Task Exception was never retrieved: {e}", file=sys.stderr) @@ -89,20 +116,21 @@ class RateLimiter(LimiterProto[ClientT]): Parameters ---------- period : float - The period, in seconds, after which the quota resets. + The period, in seconds, after which the bucket resets. limit : int - The amount of requests allowed in a quota. + The amount of requests allowed in a bucket. get_key_with : Callable[[Context[t.Any]], str] A callable that returns a key for the ratelimiter bucket. """ - __slots__ = ("period", "limit", "_quotas", "_get_key") + __slots__ = ("period", "limit", "_buckets", "_get_key") def __init__(self, period: float, limit: int, *, get_key_with: t.Callable[[Context[t.Any]], str]) -> None: self.period: float = period self.limit: int = limit - self._quotas: t.Dict[str, _Quota[ClientT]] = {} + self._buckets: t.Dict[str, _Bucket[ClientT]] = {} self._get_key: t.Callable[[Context[t.Any]], str] = get_key_with + self._gc_task: asyncio.Task[None] | None = None def get_key(self, ctx: Context[t.Any]) -> str: """Get key for ratelimiter bucket.""" @@ -123,14 +151,27 @@ def is_rate_limited(self, ctx: Context[t.Any]) -> bool: """ now = time.monotonic() - if data := self._quotas.get(self.get_key(ctx)): + if data := self._buckets.get(self.get_key(ctx)): if data.reset_at <= now: return False - return data.remaining <= 0 + return data._remaining <= 0 return False + def _start_gc(self) -> None: + """Start the garbage collector task if one is not running.""" + if self._gc_task is None or self._gc_task.done(): + self._gc_task = asyncio.create_task(self._gc()) + + async def _gc(self) -> None: + """Purge stale buckets.""" + while self._buckets: + await asyncio.sleep(self.period + 1.0) + for bucket in list(self._buckets.values()): + if bucket.is_stale: + del self._buckets[bucket.key] + async def acquire(self, ctx: Context[t.Any], *, wait: bool = True) -> None: - """Acquire a quota, block execution if ratelimited and wait is True. + """Acquire a bucket, block execution if ratelimited and wait is True. Parameters ---------- @@ -147,16 +188,23 @@ async def acquire(self, ctx: Context[t.Any], *, wait: bool = True) -> None: """ event = asyncio.Event() - # Get existing or insert new quota - quota = self._quotas.setdefault(self.get_key(ctx), _Quota.for_limiter(self)) - quota.queue.append(event) - quota.start_queue() + key = self.get_key(ctx) + # Get existing or insert new bucket + bucket = self._buckets.setdefault(key, _Bucket.for_limiter(key, self)) + + if bucket.is_exhausted and not wait: + raise UnderCooldownError( + self, + bucket.reset_at - time.monotonic(), + f"Ratelimited for {bucket.reset_at - time.monotonic()} seconds.", + ) + + bucket._queue.append(event) + bucket.start_queue() + self._start_gc() if wait: await event.wait() - elif self.is_rate_limited(ctx): - retry_after = quota.reset_at - time.monotonic() - raise UnderCooldownError(self, retry_after, f"Ratelimited for {retry_after} seconds.") async def __call__(self, ctx: Context[t.Any]) -> HookResult: """Acquire a ratelimit, fail if ratelimited. @@ -181,8 +229,8 @@ async def __call__(self, ctx: Context[t.Any]) -> HookResult: def reset(self, ctx: Context[t.Any]) -> None: """Reset the ratelimit for a given context.""" - if quota := self._quotas.get(self.get_key(ctx)): - quota.reset() + if bucket := self._buckets.get(self.get_key(ctx)): + bucket.reset() def global_limiter(period: float, limit: int) -> RateLimiter[t.Any]: @@ -193,9 +241,9 @@ def global_limiter(period: float, limit: int) -> RateLimiter[t.Any]: Parameters ---------- period : float - The period, in seconds, after which the quota resets. + The period, in seconds, after which the bucket resets. limit : int - The amount of requests allowed in a quota. + The amount of requests allowed in a bucket. """ return RateLimiter(period, limit, get_key_with=lambda _: "amongus") @@ -208,9 +256,9 @@ def guild_limiter(period: float, limit: int) -> RateLimiter[t.Any]: Parameters ---------- period : float - The period, in seconds, after which the quota resets. + The period, in seconds, after which the bucket resets. limit : int - The amount of requests allowed in a quota. + The amount of requests allowed in a bucket. """ return RateLimiter(period, limit, get_key_with=lambda ctx: str(ctx.guild_id)) @@ -223,9 +271,9 @@ def channel_limiter(period: float, limit: int) -> RateLimiter[t.Any]: Parameters ---------- period : float - The period, in seconds, after which the quota resets. + The period, in seconds, after which the bucket resets. limit : int - The amount of requests allowed in a quota. + The amount of requests allowed in a bucket. """ return RateLimiter(period, limit, get_key_with=lambda ctx: str(ctx.channel_id)) @@ -238,9 +286,9 @@ def user_limiter(period: float, limit: int) -> RateLimiter[t.Any]: Parameters ---------- period : float - The period, in seconds, after which the quota resets. + The period, in seconds, after which the bucket resets. limit : int - The amount of requests allowed in a quota. + The amount of requests allowed in a bucket. """ return RateLimiter(period, limit, get_key_with=lambda ctx: str(ctx.author.id)) @@ -249,30 +297,31 @@ def member_limiter(period: float, limit: int) -> RateLimiter[t.Any]: """Create a member ratelimiter. This ratelimiter is shared across all contexts by a member in a guild. - The same user in a different guild will be assigned a different quota. + The same user in a different guild will be assigned a different bucket. Parameters ---------- period : float - The period, in seconds, after which the quota resets. + The period, in seconds, after which the bucket resets. limit : int - The amount of requests allowed in a quota. + The amount of requests allowed in a bucket. """ return RateLimiter(period, limit, get_key_with=lambda ctx: f"{ctx.author.id}:{ctx.guild_id}") def custom_limiter(period: float, limit: int, get_key_with: t.Callable[[Context[t.Any]], str]) -> RateLimiter[t.Any]: - """Create a custom ratelimiter. + """Create a ratelimiter with a custom key extraction function. Parameters ---------- period : float - The period, in seconds, after which the quota resets. + The period, in seconds, after which the bucket resets. limit : int - The amount of requests allowed in a quota. + The amount of requests allowed in a bucket. get_key_with : Callable[[Context[t.Any]], str] - A callable that returns a key for the ratelimiter bucket. - This key is used to identify the bucket. + A callable that returns a key for the ratelimiter bucket. This key is used to identify the bucket. + For instance, to create a ratelimiter that is shared across all contexts in a guild, + you would use `lambda ctx: str(ctx.guild_id)`. """ return RateLimiter(period, limit, get_key_with=get_key_with)