From 663703f5dd4622f90ae5c25b052e0cd64c41064a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?=
Date: Thu, 21 Nov 2024 18:24:56 +0100
Subject: [PATCH] fix: exit on consumption error

---
 .../icij_worker/tests/worker/test_neo4j.py  |  9 ++--
 .../icij_worker/tests/worker/test_worker.py | 44 ++++++---------
 icij-worker/icij_worker/worker/worker.py    | 54 ++++++++++++++-----
 3 files changed, 62 insertions(+), 45 deletions(-)

diff --git a/icij-worker/icij_worker/tests/worker/test_neo4j.py b/icij-worker/icij_worker/tests/worker/test_neo4j.py
index 358908f..0026d71 100644
--- a/icij-worker/icij_worker/tests/worker/test_neo4j.py
+++ b/icij-worker/icij_worker/tests/worker/test_neo4j.py
@@ -32,6 +32,7 @@
 from icij_worker.tests.conftest import count_locks
 from icij_worker.tests.worker.conftest import make_app
 from icij_worker.utils import neo4j_
+from icij_worker.worker.worker import TaskConsumptionError
 
 
 @pytest.fixture(
@@ -113,15 +114,17 @@ def neo4j_db(group: str) -> str:
     # When
     async with worker:
         # Then
-        with pytest.raises(ClientError) as ex:
+        with pytest.raises(TaskConsumptionError) as ex:
             await worker.consume()
 
-        assert ex.value.code == "Neo.ClientError.Database.DatabaseNotFound"
+        cause = ex.value.__cause__
+        assert isinstance(cause, ClientError)
+        assert cause.code == "Neo.ClientError.Database.DatabaseNotFound"
         expected = (
             "Unable to get a routing table for database 'other-db' because"
             " this database does not exist"
         )
-        assert ex.value.message == expected
+        assert cause.message == expected
 
 
 @pytest.mark.parametrize("worker", [None], indirect=["worker"])
diff --git a/icij-worker/icij_worker/tests/worker/test_worker.py b/icij-worker/icij_worker/tests/worker/test_worker.py
index a0fcb74..f85270b 100644
--- a/icij-worker/icij_worker/tests/worker/test_worker.py
+++ b/icij-worker/icij_worker/tests/worker/test_worker.py
@@ -641,35 +641,6 @@ async def test_worker_should_keep_working_on_fatal_error_in_task_codebase(
         await worker.work_once()
 
 
-async def test_worker_should_stop_working_on_fatal_error_in_worker_codebase(
-    mock_failing_worker: MockWorker,
-):
-    # Given
-    worker = mock_failing_worker
-    task_manager = MockManager(worker.app, worker.db_path)
-    created_at = datetime.now()
-    task = Task(
-        id="some-id",
-        name="fatal_error_task",
-        created_at=created_at,
-        state=TaskState.CREATED,
-    )
-    await task_manager.save_task(task)
-
-    # When/Then
-    await task_manager.enqueue(task)
-    with patch.object(worker, "_consume") as mocked_consume:
-
-        class _FatalError(Exception): ...
-
-        async def _fatal_error_during_consuming():
-            raise _FatalError("i'm fatal")
-
-        mocked_consume.side_effect = _fatal_error_during_consuming
-        with pytest.raises(_FatalError):
-            await worker.work_once()
-
-
 @pytest.mark.parametrize("mock_worker", [{"group": "short"}], indirect=["mock_worker"])
 async def test_worker_should_handle_worker_timeout(mock_worker: MockWorker):
     # Given
@@ -704,6 +675,21 @@ async def _assert_has_state(state: TaskState) -> bool:
     t.cancel()
 
 
+async def test_worker_should_not_exit_loop_on_invalid_task(
+    mock_worker: MockWorker, monkeypatch
+):
+    # Given
+    worker = mock_worker
+
+    async def _failing_consume() -> Task:
+        raise RuntimeError("some consumption error")
+
+    monkeypatch.setattr(MockWorker, "_consume", _failing_consume)
+    # When
+    with fail_if_exception("failed to continue on consumption error"):
+        await worker.work_once()
+
+
 @pytest.mark.parametrize(
     "mock_worker", [{"app": "test_async_app"}], indirect=["mock_worker"]
 )
diff --git a/icij-worker/icij_worker/worker/worker.py b/icij-worker/icij_worker/worker/worker.py
index 4f8d263..c7e52df 100644
--- a/icij-worker/icij_worker/worker/worker.py
+++ b/icij-worker/icij_worker/worker/worker.py
@@ -58,6 +58,9 @@
 WE = TypeVar("WE", bound=WorkerEvent)
 
 
+class TaskConsumptionError(RuntimeError): ...
+
+
 class Worker(
     RegistrableFromConfig,
     EventPublisher,
@@ -175,17 +178,26 @@ def graceful_shutdown(self) -> bool:
     @final
     async def _work_once(self):
         async with self.ack_cm:
+            if self._current is None:  # Consumption failed, skipping
+                return
             self._current = await task_wrapper(self, self._current)
 
     @final
     async def consume(self) -> Task:
-        task = await self._consume()
-        self.debug('Task(id="%s") locked', task.id)
-        async with self._current_lock:
-            self._current = task
-        progress = 0.0
-        update = {"progress": progress, "state": TaskState.RUNNING}
-        self._current = safe_copy(task, update=update)
+        try:
+            task = await self._consume()
+            self.debug('Task(id="%s") locked', task.id)
+            async with self._current_lock:
+                self._current = task
+        except Exception as e:
+            msg = (
+                "failed to consume incoming task, probably due to an IO error"
+                " or a deserialization error"
+            )
+            raise TaskConsumptionError(msg) from e
+        progress = 0.0
+        update = {"progress": progress, "state": TaskState.RUNNING}
+        self._current = safe_copy(task, update=update)
         event = ProgressEvent.from_task(self._current)
         await self.publish_event(event)
         return self._current
@@ -224,6 +236,12 @@ async def _early_ack_cm(self):
                 'Task(id="%s") recoverable error while publishing success',
                 current_id,
             )
+        except TaskConsumptionError:
+            self.error("failed to deserialize incoming task, skipping...")
+            # TODO: change this function into an AsyncContextManager
+            #  We have to yield here, otherwise we get a
+            #  RuntimeError("generator didn't yield")
+            yield
         except Exception as fatal_error:  # pylint: disable=broad-exception-caught
             async with self._current_lock, self._cancel_lock:
                 if self._current is not None:
@@ -231,9 +249,11 @@
                     # let's fail this task and keep working
                     await self._handle_error(fatal_error, fatal=True)
                     return
-            # The error was in the worker's code, something is wrong that won't change
-            # at the next task, let's make the worker crash
-            raise fatal_error
+            msg = (
+                f"the current task is expected to be known when a fatal error"
+                f" occurs, otherwise a {TaskConsumptionError.__name__} is expected"
+            )
+            raise RuntimeError(msg) from fatal_error
 
     @final
     @asynccontextmanager
@@ -262,6 +282,12 @@ async def _late_ack_cm(self):
                 'Task(id="%s") recoverable error while publishing success',
                 current_id,
             )
+        except TaskConsumptionError:
+            self.error("failed to deserialize incoming task, skipping...")
+            # TODO: change this function into an AsyncContextManager
+            #  We have to yield here, otherwise we get a
+            #  RuntimeError("generator didn't yield")
+            yield
         except Exception as fatal_error:  # pylint: disable=broad-exception-caught
             async with self._current_lock, self._cancel_lock:
                 if self._current is not None:
@@ -269,9 +295,11 @@
                     # let's fail this task and keep working
                     await self._handle_error(fatal_error, fatal=True)
                     return
-            # The error was in the worker's code, something is wrong that won't change
-            # at the next task, let's make the worker crash
-            raise fatal_error
+            msg = (
+                f"the current task is expected to be known when a fatal error"
+                f" occurs, otherwise a {TaskConsumptionError.__name__} is expected"
+            )
+            raise RuntimeError(msg) from fatal_error
 
     async def _handle_error(self, error: BaseException, fatal: bool):
         task = self._current
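
The snippet below is a minimal, self-contained sketch of the control flow this patch introduces: a consumption failure is wrapped in a dedicated TaskConsumptionError, logged and skipped, so the worker loop keeps running instead of crashing. The names SketchWorker, bad_consume and main are illustrative only and are not part of the icij_worker API; the sketch deliberately leaves out the real ack/lock machinery.

# Illustrative sketch only: it mimics the patched control flow with made-up names,
# it does not use the icij_worker API.
import asyncio
import logging

logger = logging.getLogger(__name__)


class TaskConsumptionError(RuntimeError):
    """Raised when an incoming task cannot be consumed (IO or deserialization error)."""


class SketchWorker:
    def __init__(self, consume_fn):
        self._consume = consume_fn
        self._current = None

    async def consume(self):
        # Wrap any consumption failure so the worker loop can tell it apart
        # from a failure raised by the task's own code.
        try:
            self._current = await self._consume()
        except Exception as e:
            raise TaskConsumptionError("failed to consume incoming task") from e
        return self._current

    async def work_once(self):
        try:
            await self.consume()
        except TaskConsumptionError:
            # Log, skip the broken message and keep the worker loop alive.
            logger.error("failed to consume incoming task, skipping...")
            return
        # A real worker would execute the task here; the sketch just logs it.
        logger.info("running task %r", self._current)


async def main():
    async def bad_consume():
        raise RuntimeError("some consumption error")

    worker = SketchWorker(bad_consume)
    # The consumption error is logged and swallowed, so the loop can go on.
    await worker.work_once()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())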