Skip to content

Commit

Permalink
connM Test 개선
Browse files Browse the repository at this point in the history
Co-authored-by: onee-only <kimww0306@gmail.com>
  • Loading branch information
byundojin and onee-only committed Nov 28, 2024
1 parent 7cf94ab commit e492c72
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 55 deletions.
80 changes: 38 additions & 42 deletions conn/manager/test/conn_manager_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import uuid

from unittest.mock import MagicMock, AsyncMock
from unittest.mock import MagicMock, AsyncMock, patch
from conn import Conn
from conn.manager import ConnectionManager
from message import Message
Expand All @@ -18,24 +18,8 @@ def setUp(self):
self.con3 = create_connection_mock()
self.con4 = create_connection_mock()

# 기존 new-conn 리시버 비우기 및 mock으로 대체
self.new_conn_receivers = []
if NewConnEvent.NEW_CONN in EventBroker.event_dict:
self.new_conn_receivers = EventBroker.event_dict[NewConnEvent.NEW_CONN].copy()

EventBroker.event_dict[NewConnEvent.NEW_CONN] = []

self.mock_new_conn_func = AsyncMock()
self.mock_new_conn_receiver = EventBroker.add_receiver(NewConnEvent.NEW_CONN)(func=self.mock_new_conn_func)

def tearDown(self):
ConnectionManager.conns = {}

# 리시버 정상화
EventBroker.remove_receiver(self.mock_new_conn_receiver)
EventBroker.event_dict[NewConnEvent.NEW_CONN] = self.new_conn_receivers

async def test_add(self):
@patch("event.EventBroker.publish")
async def test_add(self, mock: AsyncMock):

width = 1
height = 1
Expand All @@ -45,9 +29,9 @@ async def test_add(self):

self.assertEqual(ConnectionManager.get_conn(con_obj.id).id, con_obj.id)

self.assertEqual(len(self.mock_new_conn_func.mock_calls), 1)
mock.assert_called_once()
got: Message[NewConnPayload] = mock.mock_calls[0].args[0]

got = self.mock_new_conn_func.mock_calls[0].args[0]
self.assertEqual(type(got), Message)
self.assertEqual(type(got.payload), NewConnPayload)
self.assertEqual(got.payload.conn_id, con_obj.id)
Expand All @@ -62,7 +46,8 @@ def test_get_conn(self):
self.assertIsNotNone(ConnectionManager.get_conn(valid_id))
self.assertIsNone(ConnectionManager.get_conn(invalid_id))

async def test_generate_conn_id(self):
@patch("event.EventBroker.publish")
async def test_generate_conn_id(self, mock: AsyncMock):
n_conns = 5

conns = [create_connection_mock() for _ in range(n_conns)]
Expand All @@ -77,7 +62,8 @@ async def test_generate_conn_id(self):
# UUID 포맷인지 확인. 아니면 ValueError
uuid.UUID(id)

async def test_receive_broadcast_event(self):
@patch("event.EventBroker.publish")
async def test_receive_broadcast_event(self, mock: AsyncMock):
_ = await ConnectionManager.add(self.con1, 1, 1)
_ = await ConnectionManager.add(self.con2, 1, 1)
_ = await ConnectionManager.add(self.con3, 1, 1)
Expand All @@ -87,21 +73,27 @@ async def test_receive_broadcast_event(self):

message = Message(event="broadcast", header={"origin_event": origin_event}, payload=None)

await EventBroker.publish(message)
await ConnectionManager.receive_broadcast_event(message)

self.con1.send_text.assert_called_once()
self.con2.send_text.assert_called_once()
self.con3.send_text.assert_called_once()
self.con4.send_text.assert_called_once()

expected = Message(event=origin_event, payload=None)

self.assertEqual(len(self.con1.send_text.mock_calls), 1)
self.assertEqual(len(self.con2.send_text.mock_calls), 1)
self.assertEqual(len(self.con3.send_text.mock_calls), 1)
self.assertEqual(len(self.con4.send_text.mock_calls), 1)
got1: str = self.con1.send_text.mock_calls[0].args[0]
got2: str = self.con2.send_text.mock_calls[0].args[0]
got3: str = self.con3.send_text.mock_calls[0].args[0]
got4: str = self.con4.send_text.mock_calls[0].args[0]

self.assertEqual(expected.to_str(), self.con1.send_text.mock_calls[0].args[0])
self.assertEqual(expected.to_str(), self.con2.send_text.mock_calls[0].args[0])
self.assertEqual(expected.to_str(), self.con3.send_text.mock_calls[0].args[0])
self.assertEqual(expected.to_str(), self.con4.send_text.mock_calls[0].args[0])
self.assertEqual(expected.to_str(), got1)
self.assertEqual(expected.to_str(), got2)
self.assertEqual(expected.to_str(), got3)
self.assertEqual(expected.to_str(), got4)

async def test_receive_multicast_event(self):
@patch("event.EventBroker.publish")
async def test_receive_multicast_event(self, mock: AsyncMock):
con1 = await ConnectionManager.add(self.con1, 1, 1)
con2 = await ConnectionManager.add(self.con2, 1, 1)
_ = await ConnectionManager.add(self.con3, 1, 1)
Expand All @@ -120,15 +112,18 @@ async def test_receive_multicast_event(self):

expected = Message(event=origin_event, payload=None)

await EventBroker.publish(message)
await ConnectionManager.receive_multicast_event(message)

self.assertEqual(len(self.con1.send_text.mock_calls), 1)
self.assertEqual(len(self.con2.send_text.mock_calls), 1)
self.assertEqual(len(self.con3.send_text.mock_calls), 0)
self.assertEqual(len(self.con4.send_text.mock_calls), 0)
self.con1.send_text.assert_called_once()
self.con2.send_text.assert_called_once()
self.con3.send_text.assert_not_called()
self.con4.send_text.assert_not_called()

self.assertEqual(expected.to_str(), self.con1.send_text.mock_calls[0].args[0])
self.assertEqual(expected.to_str(), self.con2.send_text.mock_calls[0].args[0])
got1: str = self.con1.send_text.mock_calls[0].args[0]
got2: str = self.con1.send_text.mock_calls[0].args[0]

self.assertEqual(expected.to_str(), got1)
self.assertEqual(expected.to_str(), got2)

async def test_handle_message(self):
mock = AsyncMock()
Expand All @@ -143,9 +138,10 @@ async def test_handle_message(self):

await ConnectionManager.handle_message(message=message)

self.assertEqual(len(mock.mock_calls), 1)
mock.assert_called_once()

got: Message[TilesPayload] = mock.mock_calls[0].args[0]

got = mock.mock_calls[0].args[0]
self.assertEqual(got.header["sender"], conn_id)
self.assertEqual(got.to_str(), message.to_str())

Expand Down
8 changes: 4 additions & 4 deletions conn/test/conn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ def test_create(self):
async def test_accept(self):
await self.conn_obj.accept()

self.assertEqual(len(self.conn.accept.mock_calls), 1)
self.conn.accept.assert_called_once()

async def test_close(self):
await self.conn_obj.close()
self.assertEqual(len(self.conn.close.mock_calls), 1)
self.conn.close.assert_called_once()

async def test_send(self):
msg = Message("example", payload=ExamplePayload(a=0))
await self.conn_obj.send(msg)

self.assertEqual(len(self.conn.send_text.mock_calls), 1)
self.conn.send_text.assert_called_once()
self.assertEqual(self.conn.send_text.mock_calls[0].args[0], msg.to_str())

async def test_receive(self):
Expand All @@ -53,6 +53,6 @@ async def test_receive(self):

got = await self.conn_obj.receive()

self.assertEqual(len(self.conn.receive_text.mock_calls), 1)
self.conn.receive_text.assert_called_once()
self.assertEqual(msg.event, got.event)
self.assertEqual(msg.payload.a, got.payload.a)
18 changes: 10 additions & 8 deletions conn/test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from fastapi.websockets import WebSocket


def create_connection_mock() -> WebSocket:
con = MagicMock()
con.accept = AsyncMock()
con.receive_text = AsyncMock()
con.send_text = AsyncMock()
con.close = AsyncMock()

return con
class ConnMock():
def __init__(self):
self.accept = AsyncMock()
self.receive_text = AsyncMock()
self.send_text = AsyncMock()
self.close = AsyncMock()


def create_connection_mock() -> ConnMock:
return ConnMock()
2 changes: 1 addition & 1 deletion event/test/event_broker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def test_publish(self):

await EventBroker.publish(message=message)

self.assertEqual(len(self.handler.receive_a.func.mock_calls), 1)
self.handler.receive_a.func.assert_called_once()
mock_message = self.handler.receive_a.func.mock_calls[0].args[0]
self.assertEqual(mock_message.event, message.event)

Expand Down

0 comments on commit e492c72

Please sign in to comment.