Skip to content

Commit

Permalink
Fix limiters
Browse files Browse the repository at this point in the history
  • Loading branch information
hypergonial committed Jan 7, 2024
1 parent 4e021af commit ea9bed4
Showing 1 changed file with 112 additions and 63 deletions.
175 changes: 112 additions & 63 deletions arc/utils/hooks/limiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
if t.TYPE_CHECKING:
from arc.context.base import Context


__all__ = (
"RateLimiter",
"global_limiter",
Expand All @@ -28,55 +29,81 @@
)


@attr.define(slots=True)
class _Quota(t.Generic[ClientT]):
@attr.define(slots=True, kw_only=True)
class _Bucket(t.Generic[ClientT]):
"""Handles the ratelimiting of a single item. (E.g. a single user or a channel)."""

key: str
"""The key of the bucket."""

reset_at: float
"""The time at which the quota resets."""
remaining: int
"""The amount of requests remaining until the quota is exhausted."""
bucket: RateLimiter[ClientT]
"""The limiter this quota belongs to."""
queue: deque[asyncio.Event] = attr.field(factory=deque)
"""The time at which the bucket resets."""

limiter: RateLimiter[ClientT]
"""The limiter this bucket belongs to."""

_remaining: int = attr.field(alias="remaining")
"""The amount of requests remaining until the bucket is exhausted."""

_queue: deque[asyncio.Event] = attr.field(factory=deque, init=False)
"""A list of events to set as the iter task proceeds."""
task: asyncio.Task[None] | None = attr.field(default=None)

_task: asyncio.Task[None] | None = attr.field(default=None, init=False)
"""The task that is currently iterating over the queue."""

@classmethod
def for_limiter(cls, limiter: RateLimiter[ClientT]) -> _Quota[ClientT]:
"""Create a new Quota for a RateLimiter."""
return cls(bucket=limiter, reset_at=time.monotonic() + limiter.period, remaining=limiter.limit)
def for_limiter(cls, key: str, limiter: RateLimiter[ClientT]) -> _Bucket[ClientT]:
"""Create a new bucket for a RateLimiter."""
return cls(key=key, limiter=limiter, reset_at=time.monotonic() + limiter.period, remaining=limiter.limit)

@property
def remaining(self) -> int:
"""The amount of requests remaining until the bucket is exhausted."""
if self.reset_at <= time.monotonic():
self.reset()
return self._remaining

@remaining.setter
def remaining(self, value: int) -> None:
self._remaining = value

@property
def is_exhausted(self) -> bool:
"""Return a boolean determining if the bucket is exhausted."""
return self.remaining <= 0 and self.reset_at > time.monotonic()

@property
def is_stale(self) -> bool:
"""Return a boolean determining if the bucket is stale.
If a bucket is stale, it is no longer in use and can be purged.
"""
return not self._queue and self.remaining == self.limiter.limit and (self._task is None or self._task.done())

def start_queue(self) -> None:
"""Start the queue of a Quota.
"""Start the queue of a bucket.
This will start setting events in the queue until the bucket is ratelimited.
"""
if self.task is None or self.task.done():
self.task = asyncio.create_task(self._iter_queue())
if self._task is None or self._task.done():
self._task = asyncio.create_task(self._iter_queue())

def reset(self) -> None:
"""Reset the ratelimit."""
self.remaining = self.bucket.limit
self.reset_at = time.monotonic() + self.bucket.period
"""Reset the bucket."""
self.reset_at = time.monotonic() + self.limiter.period
self._remaining = self.limiter.limit

async def _iter_queue(self) -> None:
"""Iterate over the queue and set events until exhausted."""
try:
if self.remaining <= 0 and self.reset_at > time.monotonic():
# Sleep until ratelimit expires
sleep_time = self.reset_at - time.monotonic()
await asyncio.sleep(sleep_time)
self.reset()
elif self.reset_at <= time.monotonic():
self.reset()
while self._queue:
if self.remaining <= 0 and self.reset_at > time.monotonic():
# Sleep until ratelimit expires
await asyncio.sleep(self.reset_at - time.monotonic())
self.reset()

# Set events while not ratelimited
while self.remaining > 0 and self.queue:
self.remaining -= 1
self.queue.popleft().set()

self.task = None
# Set events while not ratelimited
while self.remaining > 0 and self._queue:
self._queue.popleft().set()
self._remaining -= 1

except Exception as e:
print(f"Task Exception was never retrieved: {e}", file=sys.stderr)
Expand All @@ -89,20 +116,21 @@ class RateLimiter(LimiterProto[ClientT]):
Parameters
----------
period : float
The period, in seconds, after which the quota resets.
The period, in seconds, after which the bucket resets.
limit : int
The amount of requests allowed in a quota.
The amount of requests allowed in a bucket.
get_key_with : Callable[[Context[t.Any]], str]
A callable that returns a key for the ratelimiter bucket.
"""

__slots__ = ("period", "limit", "_quotas", "_get_key")
__slots__ = ("period", "limit", "_buckets", "_get_key")

def __init__(self, period: float, limit: int, *, get_key_with: t.Callable[[Context[t.Any]], str]) -> None:
self.period: float = period
self.limit: int = limit
self._quotas: t.Dict[str, _Quota[ClientT]] = {}
self._buckets: t.Dict[str, _Bucket[ClientT]] = {}
self._get_key: t.Callable[[Context[t.Any]], str] = get_key_with
self._gc_task: asyncio.Task[None] | None = None

def get_key(self, ctx: Context[t.Any]) -> str:
"""Get key for ratelimiter bucket."""
Expand All @@ -123,14 +151,27 @@ def is_rate_limited(self, ctx: Context[t.Any]) -> bool:
"""
now = time.monotonic()

if data := self._quotas.get(self.get_key(ctx)):
if data := self._buckets.get(self.get_key(ctx)):
if data.reset_at <= now:
return False
return data.remaining <= 0
return data._remaining <= 0
return False

def _start_gc(self) -> None:
"""Start the garbage collector task if one is not running."""
if self._gc_task is None or self._gc_task.done():
self._gc_task = asyncio.create_task(self._gc())

async def _gc(self) -> None:
"""Purge stale buckets."""
while self._buckets:
await asyncio.sleep(self.period + 1.0)
for bucket in list(self._buckets.values()):
if bucket.is_stale:
del self._buckets[bucket.key]

async def acquire(self, ctx: Context[t.Any], *, wait: bool = True) -> None:
"""Acquire a quota, block execution if ratelimited and wait is True.
"""Acquire a bucket, block execution if ratelimited and wait is True.
Parameters
----------
Expand All @@ -147,16 +188,23 @@ async def acquire(self, ctx: Context[t.Any], *, wait: bool = True) -> None:
"""
event = asyncio.Event()

# Get existing or insert new quota
quota = self._quotas.setdefault(self.get_key(ctx), _Quota.for_limiter(self))
quota.queue.append(event)
quota.start_queue()
key = self.get_key(ctx)
# Get existing or insert new bucket
bucket = self._buckets.setdefault(key, _Bucket.for_limiter(key, self))

if bucket.is_exhausted and not wait:
raise UnderCooldownError(
self,
bucket.reset_at - time.monotonic(),
f"Ratelimited for {bucket.reset_at - time.monotonic()} seconds.",
)

bucket._queue.append(event)
bucket.start_queue()
self._start_gc()

if wait:
await event.wait()
elif self.is_rate_limited(ctx):
retry_after = quota.reset_at - time.monotonic()
raise UnderCooldownError(self, retry_after, f"Ratelimited for {retry_after} seconds.")

async def __call__(self, ctx: Context[t.Any]) -> HookResult:
"""Acquire a ratelimit, fail if ratelimited.
Expand All @@ -181,8 +229,8 @@ async def __call__(self, ctx: Context[t.Any]) -> HookResult:

def reset(self, ctx: Context[t.Any]) -> None:
"""Reset the ratelimit for a given context."""
if quota := self._quotas.get(self.get_key(ctx)):
quota.reset()
if bucket := self._buckets.get(self.get_key(ctx)):
bucket.reset()


def global_limiter(period: float, limit: int) -> RateLimiter[t.Any]:
Expand All @@ -193,9 +241,9 @@ def global_limiter(period: float, limit: int) -> RateLimiter[t.Any]:
Parameters
----------
period : float
The period, in seconds, after which the quota resets.
The period, in seconds, after which the bucket resets.
limit : int
The amount of requests allowed in a quota.
The amount of requests allowed in a bucket.
"""
return RateLimiter(period, limit, get_key_with=lambda _: "amongus")

Expand All @@ -208,9 +256,9 @@ def guild_limiter(period: float, limit: int) -> RateLimiter[t.Any]:
Parameters
----------
period : float
The period, in seconds, after which the quota resets.
The period, in seconds, after which the bucket resets.
limit : int
The amount of requests allowed in a quota.
The amount of requests allowed in a bucket.
"""
return RateLimiter(period, limit, get_key_with=lambda ctx: str(ctx.guild_id))

Expand All @@ -223,9 +271,9 @@ def channel_limiter(period: float, limit: int) -> RateLimiter[t.Any]:
Parameters
----------
period : float
The period, in seconds, after which the quota resets.
The period, in seconds, after which the bucket resets.
limit : int
The amount of requests allowed in a quota.
The amount of requests allowed in a bucket.
"""
return RateLimiter(period, limit, get_key_with=lambda ctx: str(ctx.channel_id))

Expand All @@ -238,9 +286,9 @@ def user_limiter(period: float, limit: int) -> RateLimiter[t.Any]:
Parameters
----------
period : float
The period, in seconds, after which the quota resets.
The period, in seconds, after which the bucket resets.
limit : int
The amount of requests allowed in a quota.
The amount of requests allowed in a bucket.
"""
return RateLimiter(period, limit, get_key_with=lambda ctx: str(ctx.author.id))

Expand All @@ -249,30 +297,31 @@ def member_limiter(period: float, limit: int) -> RateLimiter[t.Any]:
"""Create a member ratelimiter.
This ratelimiter is shared across all contexts by a member in a guild.
The same user in a different guild will be assigned a different quota.
The same user in a different guild will be assigned a different bucket.
Parameters
----------
period : float
The period, in seconds, after which the quota resets.
The period, in seconds, after which the bucket resets.
limit : int
The amount of requests allowed in a quota.
The amount of requests allowed in a bucket.
"""
return RateLimiter(period, limit, get_key_with=lambda ctx: f"{ctx.author.id}:{ctx.guild_id}")


def custom_limiter(period: float, limit: int, get_key_with: t.Callable[[Context[t.Any]], str]) -> RateLimiter[t.Any]:
"""Create a custom ratelimiter.
"""Create a ratelimiter with a custom key extraction function.
Parameters
----------
period : float
The period, in seconds, after which the quota resets.
The period, in seconds, after which the bucket resets.
limit : int
The amount of requests allowed in a quota.
The amount of requests allowed in a bucket.
get_key_with : Callable[[Context[t.Any]], str]
A callable that returns a key for the ratelimiter bucket.
This key is used to identify the bucket.
A callable that returns a key for the ratelimiter bucket. This key is used to identify the bucket.
For instance, to create a ratelimiter that is shared across all contexts in a guild,
you would use `lambda ctx: str(ctx.guild_id)`.
"""
return RateLimiter(period, limit, get_key_with=get_key_with)

Expand Down

0 comments on commit ea9bed4

Please sign in to comment.