diff --git a/conn/internal/conn.py b/conn/internal/conn.py index 90ab6a8..34ab239 100644 --- a/conn/internal/conn.py +++ b/conn/internal/conn.py @@ -1,4 +1,4 @@ -from fastapi.websockets import WebSocket, WebSocketState +from fastapi.websockets import WebSocket, WebSocketState, WebSocketDisconnect from websockets.exceptions import ConnectionClosed from message import Message from dataclasses import dataclass @@ -23,11 +23,11 @@ async def receive(self): return Message.from_str(await self.conn.receive_text()) async def send(self, msg: Message): - if self.conn.client_state == WebSocketState.DISCONNECTED: + if self.conn.application_state == WebSocketState.DISCONNECTED: return try: await self.conn.send_text(msg.to_str()) - except ConnectionClosed: + except (ConnectionClosed, WebSocketDisconnect): # 커넥션이 종료되었는데도 타이밍 문제로 인해 커넥션을 가져왔을 수 있음. return diff --git a/conn/test/fixtures.py b/conn/test/fixtures.py index f0c2c3f..bc2e2f1 100644 --- a/conn/test/fixtures.py +++ b/conn/test/fixtures.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock, AsyncMock -from fastapi.websockets import WebSocket +from fastapi.websockets import WebSocket, WebSocketState class ConnMock(): @@ -9,6 +9,7 @@ def __init__(self): self.receive_text = AsyncMock() self.send_text = AsyncMock() self.close = AsyncMock() + self.application_state = WebSocketState.CONNECTED def create_connection_mock() -> ConnMock: