From e5f79493cc2ae8d4f778c380247284aa554464ab Mon Sep 17 00:00:00 2001 From: Donkie Date: Fri, 16 Aug 2024 10:28:04 +0200 Subject: [PATCH] Fixed error on trying to send ws message to disconnected client This could result in a major filament usage track failure since the Use http request would fail and Moonraker would keep retrying to save the used filament. However it was actually saved since the db.commit() happens before the websocket message fails. --- spoolman/database/filament.py | 6 +++--- spoolman/database/spool.py | 16 ++++++++++------ spoolman/database/vendor.py | 4 ++-- spoolman/ws.py | 18 +++++++++++++++++- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/spoolman/database/filament.py b/spoolman/database/filament.py index 97898fd7a..73ba448a4 100644 --- a/spoolman/database/filament.py +++ b/spoolman/database/filament.py @@ -76,8 +76,8 @@ async def create( extra=[models.FilamentField(key=k, value=v) for k, v in (extra or {}).items()], ) db.add(filament) - await db.commit() await filament_changed(filament, EventType.ADDED) + await db.commit() return filament @@ -172,8 +172,8 @@ async def update( filament.multi_color_direction = v.value if v is not None else None else: setattr(filament, k, v) - await db.commit() await filament_changed(filament, EventType.UPDATED) + await db.commit() return filament @@ -182,8 +182,8 @@ async def delete(db: AsyncSession, filament_id: int) -> None: filament = await get_by_id(db, filament_id) await db.delete(filament) try: - await db.commit() # Flush immediately so any errors are propagated in this request. await filament_changed(filament, EventType.DELETED) + await db.commit() # Flush immediately so any errors are propagated in this request. except IntegrityError as exc: await db.rollback() raise ItemDeleteError("Failed to delete filament.") from exc diff --git a/spoolman/database/spool.py b/spoolman/database/spool.py index b2a6f84c1..594d39206 100644 --- a/spoolman/database/spool.py +++ b/spoolman/database/spool.py @@ -92,8 +92,8 @@ async def create( extra=[models.SpoolField(key=k, value=v) for k, v in (extra or {}).items()], ) db.add(spool) - await db.commit() await spool_changed(spool, EventType.ADDED) + await db.commit() return spool @@ -228,16 +228,16 @@ async def update( spool.extra = [models.SpoolField(key=k, value=v) for k, v in v.items()] else: setattr(spool, k, v) - await db.commit() await spool_changed(spool, EventType.UPDATED) + await db.commit() return spool async def delete(db: AsyncSession, spool_id: int) -> None: """Delete a spool object.""" spool = await get_by_id(db, spool_id) - await db.delete(spool) await spool_changed(spool, EventType.DELETED) + await db.delete(spool) async def clear_extra_field(db: AsyncSession, key: str) -> None: @@ -291,8 +291,8 @@ async def use_weight(db: AsyncSession, spool_id: int, weight: float) -> models.S spool.first_used = datetime.utcnow().replace(microsecond=0) spool.last_used = datetime.utcnow().replace(microsecond=0) - await db.commit() await spool_changed(spool, EventType.UPDATED) + await db.commit() return spool @@ -337,8 +337,12 @@ async def use_length(db: AsyncSession, spool_id: int, length: float) -> models.S spool.first_used = datetime.utcnow().replace(microsecond=0) spool.last_used = datetime.utcnow().replace(microsecond=0) - await db.commit() await spool_changed(spool, EventType.UPDATED) + + # Commit should be the last action, everything after that must never fail + # Otherwise you can end up in a non-atomic thing where the http use request fails + # but the data still has been committed. + await db.commit() return spool @@ -449,6 +453,6 @@ async def reset_initial_weight(db: AsyncSession, spool_id: int, weight: float) - spool.initial_weight = weight spool.used_weight = 0 - await db.commit() await spool_changed(spool, EventType.UPDATED) + await db.commit() return spool diff --git a/spoolman/database/vendor.py b/spoolman/database/vendor.py index fd943a55b..3d07ba07e 100644 --- a/spoolman/database/vendor.py +++ b/spoolman/database/vendor.py @@ -33,8 +33,8 @@ async def create( extra=[models.VendorField(key=k, value=v) for k, v in (extra or {}).items()], ) db.add(vendor) - await db.commit() await vendor_changed(vendor, EventType.ADDED) + await db.commit() return vendor @@ -101,8 +101,8 @@ async def update( vendor.extra = [models.VendorField(key=k, value=v) for k, v in v.items()] else: setattr(vendor, k, v) - await db.commit() await vendor_changed(vendor, EventType.UPDATED) + await db.commit() return vendor diff --git a/spoolman/ws.py b/spoolman/ws.py index 6b02b8e8a..2107b2e99 100644 --- a/spoolman/ws.py +++ b/spoolman/ws.py @@ -3,6 +3,7 @@ import logging from fastapi import WebSocket +from starlette.websockets import WebSocketState from spoolman.api.v1.models import Event @@ -46,7 +47,22 @@ async def send(self, path: tuple[str, ...], evt: Event) -> None: """Send a message to all websockets in this branch of the tree.""" # Broadcast to all subscribers on this level for websocket in self.subscribers: - await websocket.send_text(evt.json()) + if ( + websocket.client_state == WebSocketState.DISCONNECTED # noqa: PLR1714 + or websocket.application_state == WebSocketState.DISCONNECTED + ): + # A bad disconnection may have occurred + self.remove(path, websocket) + logger.info( + "Forcing disconnection of client %s on pool %s", + websocket.client.host if websocket.client else "?", + ",".join(path), + ) + elif ( + websocket.client_state == WebSocketState.CONNECTED + and websocket.application_state == WebSocketState.CONNECTED + ): + await websocket.send_text(evt.json()) # Send the message further down the tree if len(path) > 0 and path[0] in self.children: