Skip to content

Commit

Permalink
Merge pull request #19 from gamultong/enhancement/improve-mock-usage
Browse files Browse the repository at this point in the history
test 리팩터링
  • Loading branch information
byundojin authored Nov 28, 2024
2 parents 0327173 + 682226d commit 72670db
Show file tree
Hide file tree
Showing 18 changed files with 310 additions and 385 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,5 @@ jobs:
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test
run: |
python -m tests.utils_test
python -m tests.__init__
python -m server_test
9 changes: 0 additions & 9 deletions TODO.txt

This file was deleted.

2 changes: 1 addition & 1 deletion board/handler/internal/board_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def receive_try_pointing(message: Message[TryPointingPayload]):
pointable = tiles.find("O") != -1

await EventBroker.publish(
message=Message(
Message(
event=PointEvent.POINTING_RESULT,
header={"receiver": sender},
payload=PointingResultPayload(
Expand Down
2 changes: 1 addition & 1 deletion board/handler/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from .board_handler_test import BoardHandlerTestCase
from .board_handler_test import BoardHandler_FetchTilesReceiver_TestCase

if __name__ == "__main__":
unittest.main()
169 changes: 93 additions & 76 deletions board/handler/test/board_handler_test.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,97 @@
import unittest
from unittest.mock import AsyncMock

from unittest.mock import AsyncMock, patch
from board.handler import BoardHandler
from event import EventBroker
from message import Message
from message.payload import FetchTilesPayload, TilesEvent, TilesPayload, NewConnEvent, NewConnPayload, TryPointingPayload, PointingResultPayload, PointEvent, ClickType

from message.payload import \
FetchTilesPayload, TilesEvent, TilesPayload, NewConnEvent, NewConnPayload, TryPointingPayload, PointingResultPayload, PointEvent, ClickType
from board.test.fixtures import setup_board
from board import Point

from cursor import Color
from event.internal.event_broker import Receiver


class BoardHandlerTestCase(unittest.IsolatedAsyncioTestCase):
def setUp(self):
setup_board()

# 기존 tiles 리시버 비우기 및 mock으로 대체
self.multi_receivers = []
if "multicast" in EventBroker.event_dict:
self.multi_receivers = EventBroker.event_dict["multicast"].copy()

EventBroker.event_dict["multicast"] = []

self.mock_multicast_func = AsyncMock()
self.mock_multicast_receiver = EventBroker.add_receiver("multicast")(func=self.mock_multicast_func)
"""
BoardHandler Test
----------------------------
Test
❌ ✅
- fetch-tiles-receiver
- ✅| normal-case
- ❌| invaild-message
- ❌| invaild-message-payload
- ❌| no-sender
- ❌| invaild-header
- new-conn-receiver
- ✅| normal-case
- try-pointing-receiver
- ✅| normal-case
"""

self.pointing_result_receivers = []
if PointEvent.POINTING_RESULT in EventBroker.event_dict:
self.pointing_result_receivers = EventBroker.event_dict[PointEvent.POINTING_RESULT].copy()

EventBroker.event_dict[PointEvent.POINTING_RESULT] = []
# fetch-tiles-receiver Test
class BoardHandler_FetchTilesReceiver_TestCase(unittest.IsolatedAsyncioTestCase):

self.mock_pointing_result_func = AsyncMock()
self.mock_pointing_result_receiver = EventBroker.add_receiver(PointEvent.POINTING_RESULT)(func=self.mock_pointing_result_func)

def tearDown(self):
# 리시버 정상화
EventBroker.remove_receiver(self.mock_multicast_receiver)
EventBroker.event_dict["multicast"] = self.multi_receivers

EventBroker.remove_receiver(self.mock_pointing_result_receiver)
EventBroker.event_dict[PointEvent.POINTING_RESULT] = self.pointing_result_receivers
def setUp(self):
setup_board()

async def test_receive_fetch_tiles(self):
@patch("event.EventBroker.publish")
async def test_fetch_tiles_receiver_normal_case(self, mock: AsyncMock):
"""
fetch-tiles-receiver
normal-case
----------------------------
trigger event ->
- fetch-tiles : message[FetchTilesPayload]
- header :
- sender : conn_id
- descrption :
econn_id의 tiles 정보 요청
----------------------------
publish event ->
- multicast : message[TilesPayload]
- header :
- target_conns : [conn_id]
- origin_event : tiles
- descrption :
fetch-tiles의 대한 응답
----------------------------
"""

# trigger message 생성
message = Message(
event=TilesEvent.FETCH_TILES,
payload=FetchTilesPayload(Point(-2, 1), Point(1, -2)),
header={"sender": "ayo"},

)

# trigger event
await BoardHandler.receive_fetch_tiles(message)

self.assertEqual(len(self.mock_multicast_func.mock_calls), 1)
got = self.mock_multicast_func.mock_calls[0].args[0]

assert type(got) == Message
assert got.event == "multicast"

assert "target_conns" in got.header
assert len(got.header["target_conns"]) == 1
assert got.header["target_conns"][0] == message.header["sender"]

# 호출 여부
mock.assert_called_once()
got: Message[TilesPayload] = mock.mock_calls[0].args[0]

# message 확인
self.assertEqual(type(got), Message)
# message.event
self.assertEqual(got.event, "multicast")
# message.header
self.assertIn("target_conns", got.header)
self.assertEqual(len(got.header["target_conns"]), 1)
self.assertEqual(got.header["target_conns"][0], message.header["sender"])
self.assertIn("origin_event", got.header)
self.assertEqual(got.header["origin_event"], TilesEvent.TILES)

assert type(got.payload) == TilesPayload
assert got.payload.start_p.x == -2
assert got.payload.start_p.y == 1
assert got.payload.end_p.x == 1
assert got.payload.end_p.y == -2
assert got.payload.tiles == "df12df12er56er56"
# message.payload
self.assertEqual(type(got.payload), TilesPayload)
self.assertEqual(got.payload.start_p.x, -2)
self.assertEqual(got.payload.start_p.y, 1)
self.assertEqual(got.payload.end_p.x, 1)
self.assertEqual(got.payload.end_p.y, -2)
self.assertEqual(got.payload.tiles, "df12df12er56er56")

async def test_receive_new_conn(self):
@patch("event.EventBroker.publish")
async def test_receive_new_conn(self, mock: AsyncMock):
message = Message(
event=NewConnEvent.NEW_CONN,
header={"sender": "ayo"},
Expand All @@ -83,24 +100,25 @@ async def test_receive_new_conn(self):

await BoardHandler.receive_new_conn(message)

self.assertEqual(len(self.mock_multicast_func.mock_calls), 1)
got = self.mock_multicast_func.mock_calls[0].args[0]
mock.assert_called_once()
got: Message[TilesPayload] = mock.mock_calls[0].args[0]

assert type(got) == Message
assert got.event == "multicast"
self.assertEqual(type(got), Message)
self.assertEqual(got.event, "multicast")

assert "target_conns" in got.header
assert len(got.header["target_conns"]) == 1
assert got.header["target_conns"][0] == message.header["sender"]
self.assertIn("target_conns", got.header)
self.assertEqual(len(got.header["target_conns"]), 1)
self.assertEqual(got.header["target_conns"][0], message.header["sender"])

assert type(got.payload) == TilesPayload
assert got.payload.start_p.x == -2
assert got.payload.start_p.y == 2
assert got.payload.end_p.x == 2
assert got.payload.end_p.y == -2
assert got.payload.tiles == "df123df123df123er567er567"
self.assertEqual(type(got.payload), TilesPayload)
self.assertEqual(got.payload.start_p.x, -2)
self.assertEqual(got.payload.start_p.y, 2)
self.assertEqual(got.payload.end_p.x, 2)
self.assertEqual(got.payload.end_p.y, -2)
self.assertEqual(got.payload.tiles, "df123df123df123er567er567")

async def test_try_pointing(self):
@ patch("event.EventBroker.publish")
async def test_try_pointing(self, mock: AsyncMock):
message = Message(
event=PointEvent.TRY_POINTING,
header={"sender": "ayo"},
Expand All @@ -114,19 +132,18 @@ async def test_try_pointing(self):

await BoardHandler.receive_try_pointing(message)

mock = self.mock_pointing_result_func
self.assertEqual(len(mock.mock_calls), 1)
self.assertEqual(len(mock.mock_calls[0].args), 1)
mock.assert_called_once()
got: Message[PointingResultPayload] = mock.mock_calls[0].args[0]

message: Message[PointingResultPayload] = mock.mock_calls[0].args[0]
self.assertEqual(message.event, PointEvent.POINTING_RESULT)
self.assertEqual(type(got), Message)
self.assertEqual(got.event, PointEvent.POINTING_RESULT)

self.assertEqual(len(message.header), 1)
self.assertIn("receiver", message.header)
self.assertEqual(message.header["receiver"], "ayo")
self.assertEqual(len(got.header), 1)
self.assertIn("receiver", got.header)
self.assertEqual(got.header["receiver"], "ayo")

self.assertEqual(type(message.payload), PointingResultPayload)
self.assertEqual(message.payload.pointable, False)
self.assertEqual(type(got.payload), PointingResultPayload)
self.assertEqual(got.payload.pointable, False)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion board/test/board_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_fetch(self, data, expect):

data = Board.fetch(start_p, end_p)

assert data == expect, f"{data} {expect}"
self.assertEqual(data, expect)


if __name__ == "__main__":
Expand Down
10 changes: 5 additions & 5 deletions board/test/section_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def test_create(self):
data=EXAPMLE_SECTION_DATA
)

assert self.section.p == EXAMPLE_POINT
self.assertEqual(self.section.p, EXAMPLE_POINT)

data_length = sum(len(row) for row in self.section.data)
assert data_length == Section.LENGTH ** 2, data_length
self.assertEqual(data_length, Section.LENGTH ** 2)

@cases(FETCH_TEST_CASES)
def test_fetch(self, desc, range, expect):
Expand All @@ -71,7 +71,7 @@ def test_fetch(self, desc, range, expect):
if end is not None:
data = bytearray().join(data)

assert data.decode('ascii') == expect, f"desc: {desc}, {data}, {expect}"
self.assertEqual(data.decode('ascii'), expect, desc)

def test_fetch_out_of_range(self):
pass
Expand All @@ -93,7 +93,7 @@ def test_update_one(self):
self.section.update(data=value, start=EXAMPLE_POINT)

got = self.section.fetch(start=EXAMPLE_POINT)
assert got == value, f"{type(got)} {got} {value}"
self.assertEqual(got, value)

def test_update_range(self):
rows = 3
Expand All @@ -107,7 +107,7 @@ def test_update_range(self):
self.section.update(data=value, start=start, end=end)

got = self.section.fetch(start=start, end=end)
assert got == value, f"{got} {value}"
self.assertEqual(got, value)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 72670db

Please sign in to comment.