-
Notifications
You must be signed in to change notification settings - Fork 0
/
rate_limit.py
59 lines (40 loc) · 1.65 KB
/
rate_limit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from starlette.exceptions import HTTPException
import time
# Module-level store shared by every RateLimitMiddleware instance, keyed by
# "rate_limit:<client ip>" and holding {"request_count": int, "start_time": float}.
# Entries are overwritten when a window expires but never evicted, so memory
# grows with the number of distinct client IPs seen.
# NOTE(review): consider periodic pruning or a TTL cache for long-running apps.
cache = {}


class RateLimitMiddleware:
    """ASGI middleware enforcing a fixed-window request limit per client IP.

    Each client may make up to ``max_requests`` requests per window of
    ``seconds`` seconds, measured from the first request of that window.
    Further requests in the same window never reach ``app``:
    - HTTP requests get a 429 response with a JSON ``{"detail": ...}`` body.
    - WebSocket handshakes are closed with code 1008 (policy violation).

    Only ``http`` and ``websocket`` scopes are limited. Other scopes, such as
    ``lifespan``, are forwarded untouched. If ``max_requests`` or ``seconds``
    is ``None``, the middleware is a pass-through.
    """

    def __init__(self, app, max_requests: int = None, seconds: int = None) -> None:
        self.app = app
        self.seconds = seconds
        self.max_requests = max_requests

    async def __call__(self, scope, receive, send):
        # Lifespan (and any other non-request) scopes carry no client address.
        # Limiting them would crash and disable startup/shutdown handlers.
        if (
            scope.get("type") not in ("http", "websocket")
            or self.max_requests is None
            or self.seconds is None
        ):
            await self.app(scope, receive, send)
            return

        key = f"rate_limit:{self._get_ip_client(scope)}"
        limit = self._get_limit(key=key, cache=cache)

        # Check window expiry BEFORE the count. Otherwise a client that hit the
        # limit would be rejected forever, because the reset is never reached.
        if not limit or self._window_expired(limit):
            self._set_counter(key=key)
        elif self._is_over_limit(limit):
            await self._reject(scope, send)
            return
        else:
            self._increase_counter(limit)

        await self.app(scope, receive, send)

    def _get_ip_client(self, scope):
        """Return the client host from the ASGI scope, or "unknown" if absent.

        The ASGI spec allows ``client`` to be missing or ``None``. All such
        requests share a single "unknown" bucket.
        """
        client = scope.get("client")
        return client[0] if client else "unknown"

    def _get_limit(self, key, cache):
        """Return the stored counter dict for ``key``, or ``None``."""
        return cache.get(key)

    def _window_expired(self, limit) -> bool:
        """True when ``seconds`` have elapsed since the window started."""
        start_time = limit.get("start_time", time.time())
        return (time.time() - start_time) >= self.seconds

    def _is_over_limit(self, limit) -> bool:
        """True when the current window has already used ``max_requests``."""
        return limit.get("request_count", 1) >= self.max_requests

    def _increase_counter(self, limit):
        """Count one more request in the current window (mutates ``limit``)."""
        limit["request_count"] += 1
        return limit

    def _set_counter(self, key):
        """Start a new window for ``key`` with this request counted as the first."""
        cache[key] = {"request_count": 1, "start_time": time.time()}
        return cache[key]

    async def _reject(self, scope, send):
        """Send the rejection directly over ASGI.

        Raising HTTPException here would bypass Starlette's ExceptionMiddleware,
        which sits inside user middleware, and surface to the client as a 500.
        """
        if scope.get("type") == "websocket":
            # Closing before accept makes the server refuse the handshake.
            await send({"type": "websocket.close", "code": 1008})
            return
        body = b'{"detail":"Too many requests"}'
        await send(
            {
                "type": "http.response.start",
                "status": 429,
                "headers": [
                    (b"content-type", b"application/json"),
                    (b"content-length", str(len(body)).encode("ascii")),
                ],
            }
        )
        await send({"type": "http.response.body", "body": body})