Skip to content

Commit

Permalink
fix: exit on consumption error
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Nov 21, 2024
1 parent 7ebc7de commit 663703f
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 45 deletions.
9 changes: 6 additions & 3 deletions icij-worker/icij_worker/tests/worker/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from icij_worker.tests.conftest import count_locks
from icij_worker.tests.worker.conftest import make_app
from icij_worker.utils import neo4j_
from icij_worker.worker.worker import TaskConsumptionError


@pytest.fixture(
Expand Down Expand Up @@ -113,15 +114,17 @@ def neo4j_db(group: str) -> str:
# When
async with worker:
# Then
with pytest.raises(ClientError) as ex:
with pytest.raises(TaskConsumptionError) as ex:
await worker.consume()

assert ex.value.code == "Neo.ClientError.Database.DatabaseNotFound"
cause = ex.value.__cause__
assert isinstance(cause, ClientError)
assert cause.code == "Neo.ClientError.Database.DatabaseNotFound"
expected = (
"Unable to get a routing table for database 'other-db' because"
" this database does not exist"
)
assert ex.value.message == expected
assert cause.message == expected


@pytest.mark.parametrize("worker", [None], indirect=["worker"])
Expand Down
44 changes: 15 additions & 29 deletions icij-worker/icij_worker/tests/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,35 +641,6 @@ async def test_worker_should_keep_working_on_fatal_error_in_task_codebase(
await worker.work_once()


async def test_worker_should_stop_working_on_fatal_error_in_worker_codebase(
    mock_failing_worker: MockWorker,
):
    """A failure inside the worker's own code (not the task) must crash the worker."""
    # Given
    worker = mock_failing_worker
    task_manager = MockManager(worker.app, worker.db_path)
    task = Task(
        id="some-id",
        name="fatal_error_task",
        created_at=datetime.now(),
        state=TaskState.CREATED,
    )
    await task_manager.save_task(task)
    await task_manager.enqueue(task)

    class _FatalError(Exception): ...

    # When/Then: an error raised while consuming (worker codebase, no current
    # task) must bubble up instead of being swallowed
    with patch.object(worker, "_consume", side_effect=_FatalError("i'm fatal")):
        with pytest.raises(_FatalError):
            await worker.work_once()


@pytest.mark.parametrize("mock_worker", [{"group": "short"}], indirect=["mock_worker"])
async def test_worker_should_handle_worker_timeout(mock_worker: MockWorker):
# Given
Expand Down Expand Up @@ -704,6 +675,21 @@ async def _assert_has_state(state: TaskState) -> bool:
t.cancel()


async def test_worker_should_not_exit_loop_on_invalid_task(
    mock_worker: MockWorker, monkeypatch
):
    """The worker must survive a task consumption error and keep its work loop alive."""
    # Given
    worker = mock_worker

    # The function is patched onto the *class*, so it is bound at lookup time
    # and receives the worker instance as `self`. Without the `self` parameter
    # the call raises TypeError rather than the intended RuntimeError, and the
    # test passes for the wrong reason.
    async def _failing_consume(self) -> Task:
        raise RuntimeError("some consumption error")

    monkeypatch.setattr(MockWorker, "_consume", _failing_consume)
    # When
    with fail_if_exception("failed to continue on consumption error"):
        await worker.work_once()


@pytest.mark.parametrize(
"mock_worker", [{"app": "test_async_app"}], indirect=["mock_worker"]
)
Expand Down
54 changes: 41 additions & 13 deletions icij-worker/icij_worker/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
WE = TypeVar("WE", bound=WorkerEvent)


class TaskConsumptionError(RuntimeError): ...


class Worker(
RegistrableFromConfig,
EventPublisher,
Expand Down Expand Up @@ -175,17 +178,26 @@ def graceful_shutdown(self) -> bool:
@final
async def _work_once(self):
    """Process the currently consumed task within the ack context manager."""
    async with self.ack_cm:
        if self._current is None:  # Consumption failed, skipping
            return
        self._current = await task_wrapper(self, self._current)

@final
async def consume(self) -> Task:
    """Consume the next task, mark it RUNNING and publish a progress event.

    Returns:
        The consumed task, updated to ``TaskState.RUNNING`` with 0.0 progress.

    Raises:
        TaskConsumptionError: when receiving or deserializing the incoming
            task fails; the original error is chained as ``__cause__`` so the
            worker loop can log it and keep working.
    """
    try:
        task = await self._consume()
        self.debug('Task(id="%s") locked', task.id)
        async with self._current_lock:
            self._current = task
    except Exception as e:
        msg = (
            "failed to consume incoming task probably due to an IO error"
            " or deserialization error"
        )
        raise TaskConsumptionError(msg) from e
    progress = 0.0
    update = {"progress": progress, "state": TaskState.RUNNING}
    self._current = safe_copy(task, update=update)
    event = ProgressEvent.from_task(self._current)
    await self.publish_event(event)
    return self._current
Expand Down Expand Up @@ -224,16 +236,24 @@ async def _early_ack_cm(self):
'Task(id="%s") recoverable error while publishing success',
current_id,
)
except TaskConsumptionError:
self.error("failed to deserialize incoming task, skipping...")
# TODO: change this function into an AsyncContextManager
# We have to yield here, otherwise we get a
# RuntimeError("generator didn't yield")
yield
except Exception as fatal_error: # pylint: disable=broad-exception-caught
async with self._current_lock, self._cancel_lock:
if self._current is not None:
# The error is due to the current task, other tasks might succeed;
# let's fail this task and keep working
await self._handle_error(fatal_error, fatal=True)
return
# The error was in the worker's code, something is wrong that won't change
# at the next task, let's make the worker crash
raise fatal_error
msg = (
f"current task is expected to be known when fatal error occur,"
f" otherwise a {TaskConsumptionError.__name__} is expected"
)
raise RuntimeError(msg) from fatal_error

@final
@asynccontextmanager
Expand Down Expand Up @@ -262,16 +282,24 @@ async def _late_ack_cm(self):
'Task(id="%s") recoverable error while publishing success',
current_id,
)
except TaskConsumptionError:
self.error("failed to deserialize incoming task, skipping...")
# TODO: change this function into an AsyncContextManager
# We have to yield here, otherwise we get a
# RuntimeError("generator didn't yield")
yield
except Exception as fatal_error: # pylint: disable=broad-exception-caught
async with self._current_lock, self._cancel_lock:
if self._current is not None:
# The error is due to the current task, other tasks might succeed;
# let's fail this task and keep working
await self._handle_error(fatal_error, fatal=True)
return
# The error was in the worker's code, something is wrong that won't change
# at the next task, let's make the worker crash
raise fatal_error
msg = (
f"current task is expected to be known when fatal error occur,"
f" otherwise a {TaskConsumptionError.__name__} is expected"
)
raise RuntimeError(msg) from fatal_error

async def _handle_error(self, error: BaseException, fatal: bool):
task = self._current
Expand Down

0 comments on commit 663703f

Please sign in to comment.