Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

범위 관련 예외처리 추가 #88

Merged
merged 4 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .dev.env
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
MINE_KILL_DURATION_SECONDS=60
BOARD_DATABASE_PATH="/tmp/gamulpung-board-db"
BOARD_DATABASE_PATH="/tmp/gamulpung-board-db"
VIEW_SIZE_LIMIT=200
37 changes: 34 additions & 3 deletions board/event/handler/internal/board_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from event import EventBroker
from board.data import Point, Tile, Tiles
from board.data import Point, Tile, Tiles, Section
from board.data.handler import BoardHandler
from cursor.data import Color
from message import Message
Expand All @@ -21,8 +21,11 @@
InteractionEvent,
TilesOpenedPayload,
SingleTileOpenedPayload,
FlagSetPayload
FlagSetPayload,
ErrorEvent,
ErrorPayload
)
from config import VIEW_SIZE_LIMIT


class BoardEventHandler():
Expand All @@ -31,7 +34,35 @@ class BoardEventHandler():
async def receive_fetch_tiles(message: Message[FetchTilesPayload]):
sender = message.header["sender"]

await BoardEventHandler._publish_tiles(message.payload.start_p, message.payload.end_p, [sender])
start_p: Point = message.payload.start_p
end_p: Point = message.payload.end_p

# start_p: 좌상, end_p: 우하 확인
if start_p.x > end_p.x or start_p.y < end_p.y:
await EventBroker.publish(Message(
event="multicast",
header={
"origin_event": ErrorEvent.ERROR,
"target_conns": [sender]
},
payload=ErrorPayload(msg="start_p should be left-top, and end_p should be right-bottom")
))
return

# start_p와 end_p 차이 확인
x_gap, y_gap = (end_p.x - start_p.x + 1), (start_p.y - end_p.y + 1)
if x_gap > VIEW_SIZE_LIMIT or y_gap > VIEW_SIZE_LIMIT:
await EventBroker.publish(Message(
event="multicast",
header={
"origin_event": ErrorEvent.ERROR,
"target_conns": [sender]
},
payload=ErrorPayload(msg=f"fetch gap should not be more than {VIEW_SIZE_LIMIT}")
))
return

await BoardEventHandler._publish_tiles(start_p, end_p, [sender])

@EventBroker.add_receiver(NewConnEvent.NEW_CONN)
@staticmethod
Expand Down
77 changes: 75 additions & 2 deletions board/event/handler/test/board_handler_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from cursor.data import Color
from board.data import Point, Tile, Tiles
from board.data import Point, Tile, Tiles, Section
from board.event.handler import BoardEventHandler
from board.data.handler import BoardHandler
from board.data.storage.test.fixtures import setup_board
Expand All @@ -22,8 +22,11 @@
InteractionEvent,
SingleTileOpenedPayload,
TilesOpenedPayload,
FlagSetPayload
FlagSetPayload,
ErrorEvent,
ErrorPayload
)
from config import VIEW_SIZE_LIMIT

import unittest
from unittest.mock import AsyncMock, patch
Expand Down Expand Up @@ -125,6 +128,76 @@ async def test_fetch_tiles_receiver_normal_case(self, mock: AsyncMock):

self.assertEqual(got.payload.tiles, expected.to_str())

@patch("event.EventBroker.publish")
async def test_fetch_tiles_receiver_malformed_start_end(self, mock: AsyncMock):
start_p = Point(1, 0)
end_p = Point(0, -1)

message = Message(
event=TilesEvent.FETCH_TILES,
header={"sender": "ayo"},
payload=FetchTilesPayload(
start_p=start_p,
end_p=end_p,
)
)

# trigger event
await BoardEventHandler.receive_fetch_tiles(message)

# 호출 여부
mock.assert_called_once()
got: Message[ErrorPayload] = mock.mock_calls[0].args[0]

# message 확인
self.assertEqual(type(got), Message)
# message.event
self.assertEqual(got.event, "multicast")
# message.header
self.assertIn("target_conns", got.header)
self.assertEqual(len(got.header["target_conns"]), 1)
self.assertEqual(got.header["target_conns"][0], message.header["sender"])
self.assertIn("origin_event", got.header)
self.assertEqual(got.header["origin_event"], ErrorEvent.ERROR)

# message.payload
self.assertEqual(type(got.payload), ErrorPayload)

@patch("event.EventBroker.publish")
async def test_fetch_tiles_receiver_range_exceeded(self, mock: AsyncMock):
start_p = Point((-VIEW_SIZE_LIMIT/2) // 1, 0)
end_p = Point((VIEW_SIZE_LIMIT/2) // 1, -1)

message = Message(
event=TilesEvent.FETCH_TILES,
header={"sender": "ayo"},
payload=FetchTilesPayload(
start_p=start_p,
end_p=end_p,
)
)

# trigger event
await BoardEventHandler.receive_fetch_tiles(message)

# 호출 여부
mock.assert_called_once()
got: Message[ErrorPayload] = mock.mock_calls[0].args[0]

# message 확인
self.assertEqual(type(got), Message)
# message.event
self.assertEqual(got.event, "multicast")
# message.header
self.assertIn("target_conns", got.header)
self.assertEqual(len(got.header["target_conns"]), 1)
self.assertEqual(got.header["target_conns"][0], message.header["sender"])
self.assertIn("origin_event", got.header)
self.assertEqual(got.header["origin_event"], ErrorEvent.ERROR)

# message.payload
self.assertEqual(type(got.payload), ErrorPayload)

@patch("event.EventBroker.publish")
async def test_receive_new_conn(self, mock: AsyncMock):
conn_id = "ayo"
Expand Down
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@

MINE_KILL_DURATION_SECONDS: int = int(os.environ.get("MINE_KILL_DURATION_SECONDS"))
BOARD_DATABASE_PATH: str = os.environ.get("BOARD_DATABASE_PATH")
VIEW_SIZE_LIMIT: int = int(os.environ.get("VIEW_SIZE_LIMIT"))
18 changes: 16 additions & 2 deletions cursor/event/handler/internal/cursor_event_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from cursor.data import Cursor
from cursor.data.handler import CursorHandler
from board.data import Point, Tile, Tiles
from board.data import Point, Tile, Tiles, Section
from event import EventBroker
from message import Message
from datetime import datetime, timedelta
Expand Down Expand Up @@ -34,7 +34,7 @@
ErrorPayload,
NewCursorCandidatePayload
)
from config import MINE_KILL_DURATION_SECONDS
from config import MINE_KILL_DURATION_SECONDS, VIEW_SIZE_LIMIT


class CursorEventHandler:
Expand Down Expand Up @@ -471,6 +471,20 @@ async def receive_set_view_size(message: Message[SetViewSizePayload]):
# 변동 없음
return

if \
new_width <= 0 or new_height <= 0 or \
new_width > VIEW_SIZE_LIMIT or new_height > VIEW_SIZE_LIMIT:
# 뷰 범위 한계 넘음
await EventBroker.publish(Message(
event="multicast",
header={
"origin_event": ErrorEvent.ERROR,
"target_conns": [sender]
},
payload=ErrorPayload(msg=f"view width or height should be more than 0 and less than {VIEW_SIZE_LIMIT}")
))
return

cur_watching = CursorHandler.get_watching(cursor_id=cursor.conn_id)

old_width, old_height = cursor.width, cursor.height
Expand Down
60 changes: 60 additions & 0 deletions cursor/event/handler/test/cursor_event_handler_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from board.data import Section
from datetime import datetime
from cursor.data import Cursor, Color
from cursor.data.handler import CursorHandler
Expand Down Expand Up @@ -36,6 +37,7 @@
import unittest
from unittest.mock import AsyncMock, patch
from board.data import Point, Tile, Tiles
from config import VIEW_SIZE_LIMIT

"""
CursorEventHandler Test
Expand Down Expand Up @@ -1175,6 +1177,64 @@ async def test_receive_set_view_size_same(self, mock: AsyncMock):
b_watchers = CursorHandler.get_watchers("B")
self.assertEqual(len(b_watchers), 0)

@patch("event.EventBroker.publish")
async def test_receive_set_view_size_exceed_limit(self, mock: AsyncMock):
message = Message(
event=NewConnEvent.SET_VIEW_SIZE,
header={"sender": self.cur_a.conn_id},
payload=SetViewSizePayload(
width=VIEW_SIZE_LIMIT + 1,
height=self.cur_a.height
)
)

await CursorEventHandler.receive_set_view_size(message)

mock.assert_called_once()

# error
got: Message[ErrorPayload] = mock.mock_calls[0].args[0]
self.assertEqual(type(got), Message)
self.assertEqual(got.event, "multicast")
# origin_event
self.assertIn("origin_event", got.header)
self.assertEqual(got.header["origin_event"], ErrorEvent.ERROR)
# target_conns 확인, [A]
self.assertIn("target_conns", got.header)
self.assertEqual(len(got.header["target_conns"]), 1)
self.assertIn(self.cur_a.conn_id, got.header["target_conns"])
# payload 확인
self.assertEqual(type(got.payload), ErrorPayload)

@patch("event.EventBroker.publish")
async def test_receive_set_view_size_0(self, mock: AsyncMock):
message = Message(
event=NewConnEvent.SET_VIEW_SIZE,
header={"sender": self.cur_a.conn_id},
payload=SetViewSizePayload(
width=0,
height=self.cur_a.height
)
)

await CursorEventHandler.receive_set_view_size(message)

mock.assert_called_once()

# error
got: Message[ErrorPayload] = mock.mock_calls[0].args[0]
self.assertEqual(type(got), Message)
self.assertEqual(got.event, "multicast")
# origin_event
self.assertIn("origin_event", got.header)
self.assertEqual(got.header["origin_event"], ErrorEvent.ERROR)
# target_conns 확인, [A]
self.assertIn("target_conns", got.header)
self.assertEqual(len(got.header["target_conns"]), 1)
self.assertIn(self.cur_a.conn_id, got.header["target_conns"])
# payload 확인
self.assertEqual(type(got.payload), ErrorPayload)

@patch("event.EventBroker.publish")
async def test_receive_set_view_size_shrink(self, mock: AsyncMock):
message = Message(
Expand Down
8 changes: 8 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from fastapi import FastAPI, WebSocket, Response, WebSocketDisconnect
from websockets.exceptions import ConnectionClosed
from conn.manager import ConnectionManager
from board.data import Section
from board.event.handler import BoardEventHandler
from cursor.event.handler import CursorEventHandler
from message import Message
from message.payload import ErrorEvent, ErrorPayload
from config import VIEW_SIZE_LIMIT

app = FastAPI()

Expand All @@ -14,6 +16,12 @@ async def session(ws: WebSocket):
try:
view_width = int(ws.query_params.get("view_width"))
view_height = int(ws.query_params.get("view_height"))

if \
view_width <= 0 or view_height <= 0 or \
view_width > VIEW_SIZE_LIMIT or view_height > VIEW_SIZE_LIMIT:
raise Exception({"msg": "don't play with view size"})

except KeyError as e:
print(f"WebSocket connection closed: {e}")
await ws.close(code=1000, reason="Missing required data")
Expand Down
Loading