From e492c725efef80fe3b4db8c76847b06add910c7c Mon Sep 17 00:00:00 2001 From: byundojin Date: Thu, 28 Nov 2024 06:11:09 +0000 Subject: [PATCH] =?UTF-8?q?connM=20Test=20=EA=B0=9C=EC=84=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: onee-only --- conn/manager/test/conn_manager_test.py | 80 ++++++++++++-------------- conn/test/conn_test.py | 8 +-- conn/test/fixtures.py | 18 +++--- event/test/event_broker_test.py | 2 +- 4 files changed, 53 insertions(+), 55 deletions(-) diff --git a/conn/manager/test/conn_manager_test.py b/conn/manager/test/conn_manager_test.py index 5b5089f..9854775 100644 --- a/conn/manager/test/conn_manager_test.py +++ b/conn/manager/test/conn_manager_test.py @@ -1,7 +1,7 @@ import unittest import uuid -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import MagicMock, AsyncMock, patch from conn import Conn from conn.manager import ConnectionManager from message import Message @@ -18,24 +18,8 @@ def setUp(self): self.con3 = create_connection_mock() self.con4 = create_connection_mock() - # 기존 new-conn 리시버 비우기 및 mock으로 대체 - self.new_conn_receivers = [] - if NewConnEvent.NEW_CONN in EventBroker.event_dict: - self.new_conn_receivers = EventBroker.event_dict[NewConnEvent.NEW_CONN].copy() - - EventBroker.event_dict[NewConnEvent.NEW_CONN] = [] - - self.mock_new_conn_func = AsyncMock() - self.mock_new_conn_receiver = EventBroker.add_receiver(NewConnEvent.NEW_CONN)(func=self.mock_new_conn_func) - - def tearDown(self): - ConnectionManager.conns = {} - - # 리시버 정상화 - EventBroker.remove_receiver(self.mock_new_conn_receiver) - EventBroker.event_dict[NewConnEvent.NEW_CONN] = self.new_conn_receivers - - async def test_add(self): + @patch("event.EventBroker.publish") + async def test_add(self, mock: AsyncMock): width = 1 height = 1 @@ -45,9 +29,9 @@ async def test_add(self): self.assertEqual(ConnectionManager.get_conn(con_obj.id).id, con_obj.id) - self.assertEqual(len(self.mock_new_conn_func.mock_calls), 1) + mock.assert_called_once() + got: Message[NewConnPayload] = mock.mock_calls[0].args[0] - got = self.mock_new_conn_func.mock_calls[0].args[0] self.assertEqual(type(got), Message) self.assertEqual(type(got.payload), NewConnPayload) self.assertEqual(got.payload.conn_id, con_obj.id) @@ -62,7 +46,8 @@ def test_get_conn(self): self.assertIsNotNone(ConnectionManager.get_conn(valid_id)) self.assertIsNone(ConnectionManager.get_conn(invalid_id)) - async def test_generate_conn_id(self): + @patch("event.EventBroker.publish") + async def test_generate_conn_id(self, mock: AsyncMock): n_conns = 5 conns = [create_connection_mock() for _ in range(n_conns)] @@ -77,7 +62,8 @@ async def test_generate_conn_id(self): # UUID 포맷인지 확인. 아니면 ValueError uuid.UUID(id) - async def test_receive_broadcast_event(self): + @patch("event.EventBroker.publish") + async def test_receive_broadcast_event(self, mock: AsyncMock): _ = await ConnectionManager.add(self.con1, 1, 1) _ = await ConnectionManager.add(self.con2, 1, 1) _ = await ConnectionManager.add(self.con3, 1, 1) @@ -87,21 +73,27 @@ async def test_receive_broadcast_event(self): message = Message(event="broadcast", header={"origin_event": origin_event}, payload=None) - await EventBroker.publish(message) + await ConnectionManager.receive_broadcast_event(message) + + self.con1.send_text.assert_called_once() + self.con2.send_text.assert_called_once() + self.con3.send_text.assert_called_once() + self.con4.send_text.assert_called_once() expected = Message(event=origin_event, payload=None) - self.assertEqual(len(self.con1.send_text.mock_calls), 1) - self.assertEqual(len(self.con2.send_text.mock_calls), 1) - self.assertEqual(len(self.con3.send_text.mock_calls), 1) - self.assertEqual(len(self.con4.send_text.mock_calls), 1) + got1: str = self.con1.send_text.mock_calls[0].args[0] + got2: str = self.con2.send_text.mock_calls[0].args[0] + got3: str = self.con3.send_text.mock_calls[0].args[0] + got4: str = self.con4.send_text.mock_calls[0].args[0] - self.assertEqual(expected.to_str(), self.con1.send_text.mock_calls[0].args[0]) - self.assertEqual(expected.to_str(), self.con2.send_text.mock_calls[0].args[0]) - self.assertEqual(expected.to_str(), self.con3.send_text.mock_calls[0].args[0]) - self.assertEqual(expected.to_str(), self.con4.send_text.mock_calls[0].args[0]) + self.assertEqual(expected.to_str(), got1) + self.assertEqual(expected.to_str(), got2) + self.assertEqual(expected.to_str(), got3) + self.assertEqual(expected.to_str(), got4) - async def test_receive_multicast_event(self): + @patch("event.EventBroker.publish") + async def test_receive_multicast_event(self, mock: AsyncMock): con1 = await ConnectionManager.add(self.con1, 1, 1) con2 = await ConnectionManager.add(self.con2, 1, 1) _ = await ConnectionManager.add(self.con3, 1, 1) @@ -120,15 +112,18 @@ async def test_receive_multicast_event(self): expected = Message(event=origin_event, payload=None) - await EventBroker.publish(message) + await ConnectionManager.receive_multicast_event(message) - self.assertEqual(len(self.con1.send_text.mock_calls), 1) - self.assertEqual(len(self.con2.send_text.mock_calls), 1) - self.assertEqual(len(self.con3.send_text.mock_calls), 0) - self.assertEqual(len(self.con4.send_text.mock_calls), 0) + self.con1.send_text.assert_called_once() + self.con2.send_text.assert_called_once() + self.con3.send_text.assert_not_called() + self.con4.send_text.assert_not_called() - self.assertEqual(expected.to_str(), self.con1.send_text.mock_calls[0].args[0]) - self.assertEqual(expected.to_str(), self.con2.send_text.mock_calls[0].args[0]) + got1: str = self.con1.send_text.mock_calls[0].args[0] + got2: str = self.con1.send_text.mock_calls[0].args[0] + + self.assertEqual(expected.to_str(), got1) + self.assertEqual(expected.to_str(), got2) async def test_handle_message(self): mock = AsyncMock() @@ -143,9 +138,10 @@ async def test_handle_message(self): await ConnectionManager.handle_message(message=message) - self.assertEqual(len(mock.mock_calls), 1) + mock.assert_called_once() + + got: Message[TilesPayload] = mock.mock_calls[0].args[0] - got = mock.mock_calls[0].args[0] self.assertEqual(got.header["sender"], conn_id) self.assertEqual(got.to_str(), message.to_str()) diff --git a/conn/test/conn_test.py b/conn/test/conn_test.py index 21ddf26..096201f 100644 --- a/conn/test/conn_test.py +++ b/conn/test/conn_test.py @@ -33,17 +33,17 @@ def test_create(self): async def test_accept(self): await self.conn_obj.accept() - self.assertEqual(len(self.conn.accept.mock_calls), 1) + self.conn.accept.assert_called_once() async def test_close(self): await self.conn_obj.close() - self.assertEqual(len(self.conn.close.mock_calls), 1) + self.conn.close.assert_called_once() async def test_send(self): msg = Message("example", payload=ExamplePayload(a=0)) await self.conn_obj.send(msg) - self.assertEqual(len(self.conn.send_text.mock_calls), 1) + self.conn.send_text.assert_called_once() self.assertEqual(self.conn.send_text.mock_calls[0].args[0], msg.to_str()) async def test_receive(self): @@ -53,6 +53,6 @@ async def test_receive(self): got = await self.conn_obj.receive() - self.assertEqual(len(self.conn.receive_text.mock_calls), 1) + self.conn.receive_text.assert_called_once() self.assertEqual(msg.event, got.event) self.assertEqual(msg.payload.a, got.payload.a) diff --git a/conn/test/fixtures.py b/conn/test/fixtures.py index c4f352a..f0c2c3f 100644 --- a/conn/test/fixtures.py +++ b/conn/test/fixtures.py @@ -3,11 +3,13 @@ from fastapi.websockets import WebSocket -def create_connection_mock() -> WebSocket: - con = MagicMock() - con.accept = AsyncMock() - con.receive_text = AsyncMock() - con.send_text = AsyncMock() - con.close = AsyncMock() - - return con +class ConnMock(): + def __init__(self): + self.accept = AsyncMock() + self.receive_text = AsyncMock() + self.send_text = AsyncMock() + self.close = AsyncMock() + + +def create_connection_mock() -> ConnMock: + return ConnMock() diff --git a/event/test/event_broker_test.py b/event/test/event_broker_test.py index d5ada3b..23759b6 100644 --- a/event/test/event_broker_test.py +++ b/event/test/event_broker_test.py @@ -41,7 +41,7 @@ async def test_publish(self): await EventBroker.publish(message=message) - self.assertEqual(len(self.handler.receive_a.func.mock_calls), 1) + self.handler.receive_a.func.assert_called_once() mock_message = self.handler.receive_a.func.mock_calls[0].args[0] self.assertEqual(mock_message.event, message.event)