Skip to content

Commit

Permalink
커넥션에 rate limit 적용
Browse files Browse the repository at this point in the history
  • Loading branch information
onee-only committed Dec 29, 2024
1 parent 154327a commit ea4aeaf
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 16 deletions.
3 changes: 2 additions & 1 deletion .dev.env
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
MINE_KILL_DURATION_SECONDS=60
BOARD_DATABASE_PATH="/tmp/gamulpung-board-db"
VIEW_SIZE_LIMIT=200
VIEW_SIZE_LIMIT=200
MESSAGE_RATE_LIMIT="4/second" # [count] [per|/] [n (optional)] [second|minute|hour|day|month|year]
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
MINE_KILL_DURATION_SECONDS: int = int(os.environ.get("MINE_KILL_DURATION_SECONDS"))
BOARD_DATABASE_PATH: str = os.environ.get("BOARD_DATABASE_PATH")
VIEW_SIZE_LIMIT: int = int(os.environ.get("VIEW_SIZE_LIMIT"))
MESSAGE_RATE_LIMIT = os.environ.get("MESSAGE_RATE_LIMIT")
35 changes: 32 additions & 3 deletions conn/manager/internal/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
from fastapi.websockets import WebSocket
from conn import Conn
from message import Message
from message.payload import NewConnEvent, NewConnPayload, ConnClosedPayload, DumbHumanException
from message.payload import (
NewConnEvent, NewConnPayload, ConnClosedPayload, DumbHumanException, ErrorEvent, ErrorPayload
)
from event import EventBroker
from uuid import uuid4

from config import MESSAGE_RATE_LIMIT
from limits import storage, strategies, parse


def overwrite_event(msg: Message):
if "origin_event" not in msg.header:
Expand All @@ -17,6 +22,8 @@ def overwrite_event(msg: Message):

class ConnectionManager:
conns: dict[str, Conn] = {}
limiter = strategies.FixedWindowRateLimiter(storage.MemoryStorage())
rate_limit = parse(MESSAGE_RATE_LIMIT)

@staticmethod
def get_conn(id: str):
Expand Down Expand Up @@ -94,5 +101,27 @@ async def receive_multicast_event(message: Message):
await asyncio.gather(*coroutines)

@staticmethod
async def handle_message(message: Message):
await EventBroker.publish(message)
async def publish_client_event(conn_id: str, msg: Message):
# 커넥션 rate limit 확인
ok = ConnectionManager._check_rate_limit(conn_id)

if not ok:
conn = ConnectionManager.get_conn(conn_id)
await conn.send(msg=create_rate_limit_exceeded_message())
return

msg.header = {"sender": conn_id}
await EventBroker.publish(msg)

@staticmethod
def _check_rate_limit(conn_id: str) -> bool:
limit = ConnectionManager.rate_limit
ok = ConnectionManager.limiter.hit(limit, conn_id)
return ok


def create_rate_limit_exceeded_message() -> Message:
return Message(
event=ErrorEvent.ERROR,
payload=ErrorPayload(msg=f"rate limit exceeded. limit: {MESSAGE_RATE_LIMIT}")
)
46 changes: 38 additions & 8 deletions conn/manager/test/conn_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
from conn import Conn
from conn.manager import ConnectionManager
from message import Message
from message.payload import TilesPayload, NewConnEvent, NewConnPayload, ConnClosedPayload
from message.payload import (
TilesPayload, NewConnEvent, NewConnPayload, ConnClosedPayload
)
from event import EventBroker
from conn.test.fixtures import create_connection_mock
from board.data import Point

import asyncio


class ConnectionManagerTestCase(unittest.IsolatedAsyncioTestCase):
def setUp(self):
Expand All @@ -20,6 +24,7 @@ def setUp(self):

def tearDown(self):
ConnectionManager.conns = {}
ConnectionManager.limiter.storage.reset()

@patch("event.EventBroker.publish")
async def test_add(self, mock: AsyncMock):
Expand Down Expand Up @@ -145,18 +150,17 @@ async def test_receive_multicast_event(self, mock: AsyncMock):
self.assertEqual(expected.to_str(), got1)
self.assertEqual(expected.to_str(), got2)

async def test_handle_message(self):
async def test_publish_client_event(self):
mock = AsyncMock()
EventBroker.add_receiver("example")(mock)

conn_id = "haha this is some random conn id"
message = Message(event="example",
header={"sender": conn_id},
payload=TilesPayload(
Point(0, 0), Point(0, 0), "abcdefg"
))
message = Message(
event="example",
payload=TilesPayload(Point(0, 0), Point(0, 0), "abcdefg")
)

await ConnectionManager.handle_message(message=message)
await ConnectionManager.publish_client_event(conn_id=conn_id, msg=message)

mock.assert_called_once()

Expand All @@ -165,6 +169,32 @@ async def test_handle_message(self):
self.assertEqual(got.header["sender"], conn_id)
self.assertEqual(got.to_str(), message.to_str())

@patch("event.EventBroker.publish")
async def test_publish_client_event_rate_limit_exceeded(self, mock: AsyncMock):
conn = await ConnectionManager.add(conn=self.con1, width=1, height=1)

limit = ConnectionManager.rate_limit.amount
wait_seconds = ConnectionManager.rate_limit.get_expiry()

async def send_msg():
msg = Message(event="example", payload=None)
await ConnectionManager.publish_client_event(conn_id=conn.id, msg=msg)

# limit 꽉 채우기
await asyncio.gather(*[send_msg() for _ in range(limit)])
self.con1.send_text.assert_not_called()

# 꽉 찬 후에는 불가능
await send_msg()
self.con1.send_text.assert_called_once() # 에러 이벤트 발행
self.con1.send_text.reset_mock()

# 시간이 지난 후에는 다시 가능
self.con1.send_text.reset_mock()
await asyncio.sleep(wait_seconds)
await send_msg()
self.con1.send_text.assert_not_called()


if __name__ == "__main__":
unittest.main()
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ anyio==4.6.2.post1
certifi==2024.8.30
click==8.1.7
colorama==0.4.6
Deprecated==1.2.15
fastapi==0.115.5
gitdb==4.0.11
GitPython==3.1.41
Expand All @@ -12,6 +13,8 @@ httptools==0.6.4
httpx==0.27.2
idna==3.10
lmdb==1.5.1
limits==3.14.1
packaging==24.2
pydantic==2.9.2
pydantic_core==2.23.4
python-dotenv==1.0.1
Expand All @@ -24,3 +27,4 @@ typing_extensions==4.12.2
uvicorn==0.32.0
watchfiles==0.24.0
websockets==14.1
wrapt==1.17.0
7 changes: 3 additions & 4 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ async def session(ws: WebSocket):

while True:
try:
message = await conn.receive()
message.header = {"sender": conn.id}
await ConnectionManager.handle_message(message)
msg = await conn.receive()
await ConnectionManager.publish_client_event(conn_id=conn.id, msg=msg)
except (WebSocketDisconnect, ConnectionClosed) as e:
# 연결 종료됨
break
Expand All @@ -50,7 +49,7 @@ async def session(ws: WebSocket):
payload=ErrorPayload(msg=e)
))

print(f"Unhandled error while handling message: \n{message.__dict__}\n{type(e)}: '{e}'")
print(f"Unhandled error while handling message: \n{msg.__dict__}\n{type(e)}: '{e}'")
break

await ConnectionManager.close(conn)
Expand Down

0 comments on commit ea4aeaf

Please sign in to comment.