From a1cb2391bb8b1b58a9760959386137988335c006 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Wed, 27 Mar 2024 19:36:09 +0100 Subject: [PATCH] chore: add the project_id attribute to tasks --- icij-worker/icij_worker/__init__.py | 18 +- .../icij_worker/event_publisher/__init__.py | 10 -- .../icij_worker/event_publisher/amqp.py | 38 ++-- .../event_publisher/event_publisher.py | 10 +- .../event_publisher/{neo4j.py => neo4j_.py} | 69 ++++++-- icij-worker/icij_worker/task.py | 98 +++++++++-- .../icij_worker/task_manager/__init__.py | 20 +-- .../task_manager/{neo4j.py => neo4j_.py} | 166 +++++++++++------- icij-worker/icij_worker/tests/conftest.py | 23 ++- .../tests/event_publisher/test_ampq.py | 15 +- .../tests/event_publisher/test_neo4j.py | 45 ++--- .../tests/task_manager/test_neo4j.py | 103 ++++++----- icij-worker/icij_worker/tests/test_task.py | 155 ++++++++-------- .../icij_worker/tests/worker/test_amqp.py | 48 +++-- .../icij_worker/tests/worker/test_neo4j.py | 69 ++++---- .../icij_worker/tests/worker/test_worker.py | 144 +++++++++------ icij-worker/icij_worker/utils/tests.py | 62 ++++--- icij-worker/icij_worker/worker/__init__.py | 10 -- icij-worker/icij_worker/worker/amqp.py | 54 +++--- .../worker/{neo4j.py => neo4j_.py} | 77 ++++---- icij-worker/icij_worker/worker/worker.py | 165 ++++++++--------- 21 files changed, 796 insertions(+), 603 deletions(-) rename icij-worker/icij_worker/event_publisher/{neo4j.py => neo4j_.py} (62%) rename icij-worker/icij_worker/task_manager/{neo4j.py => neo4j_.py} (67%) rename icij-worker/icij_worker/worker/{neo4j.py => neo4j_.py} (85%) diff --git a/icij-worker/icij_worker/__init__.py b/icij-worker/icij_worker/__init__.py index 548ae00a..56ed1942 100644 --- a/icij-worker/icij_worker/__init__.py +++ b/icij-worker/icij_worker/__init__.py @@ -1,6 +1,20 @@ from .app import AsyncApp from .task import Task, TaskError, TaskEvent, TaskResult, TaskStatus +from .task_manager import TaskManager from .worker import Worker, WorkerConfig, WorkerType -from .worker.neo4j import Neo4jWorker + +try: + from icij_worker.worker.worker.amqp import AMQPWorker, AMQPWorkerConfig + from icij_worker.event_publisher.amqp import AMQPPublisher +except ImportError: + pass + +try: + from icij_worker.worker.neo4j_ import Neo4jWorker, Neo4jWorkerConfig + from icij_worker.event_publisher.neo4j_ import Neo4jEventPublisher + from icij_worker.task_manager.neo4j_ import Neo4JTaskManager +except ImportError: + pass + from .backend import WorkerBackend -from .event_publisher import EventPublisher, Neo4jEventPublisher +from .event_publisher import EventPublisher diff --git a/icij-worker/icij_worker/event_publisher/__init__.py b/icij-worker/icij_worker/event_publisher/__init__.py index ea8236c9..168a54aa 100644 --- a/icij-worker/icij_worker/event_publisher/__init__.py +++ b/icij-worker/icij_worker/event_publisher/__init__.py @@ -1,11 +1 @@ from .event_publisher import EventPublisher - -try: - from .neo4j import Neo4jEventPublisher -except ImportError: - pass - -try: - from .amqp import AMQPPublisher -except ImportError: - pass diff --git a/icij-worker/icij_worker/event_publisher/amqp.py b/icij-worker/icij_worker/event_publisher/amqp.py index 55b3460d..0ef6863e 100644 --- a/icij-worker/icij_worker/event_publisher/amqp.py +++ b/icij-worker/icij_worker/event_publisher/amqp.py @@ -17,7 +17,7 @@ from icij_common.logging_utils import LogWithNameMixin from icij_common.pydantic_utils import LowerCamelCaseModel, NoEnumModel -from icij_worker import Task, TaskError, TaskEvent, TaskResult +from icij_worker import TaskError, TaskEvent, TaskResult from . import EventPublisher @@ -112,13 +112,7 @@ def err_routing(cls) -> Routing: def _routings(self) -> List[Routing]: return [self.evt_routing(), self.res_routing(), self.err_routing()] - async def publish_event(self, event: TaskEvent, project: str): - # pylint: disable=unused-argument - # TODO: for now project information is not leverage on the AMQP side which is - # not very convenient as clients will won't know from which project the event - # is coming. This is limitating as for instance when it comes to logs errors, - # such events must be save in separate DBs in order to avoid project data - # leaking to other projects through the DB + async def _publish_event(self, event: TaskEvent): message = event.json().encode() await self._publish_message( message, @@ -127,13 +121,14 @@ async def publish_event(self, event: TaskEvent, project: str): mandatory=False, ) - async def publish_result(self, result: TaskResult, project: str): - # pylint: disable=unused-argument - # TODO: for now project information is not leverage on the AMQP side which is - # not very convenient as clients will won't know from which project the result - # is coming. This is limitating as for instance when as result must probably - # be saved in separate DBs in order to avoid project data leaking to other - # projects through the DB + publish_event_ = _publish_event + + async def publish_result(self, result: TaskResult): + # TODO: for now task project information is not leverage on the AMQP side which + # is not very convenient as clients will won't know from which project the + # result is coming. This is limitating as for instance when as result must + # probably be saved in separate DBs in order to avoid project data leaking to + # other projects through the DB message = result.json().encode() await self._publish_message( message, @@ -142,13 +137,12 @@ async def publish_result(self, result: TaskResult, project: str): mandatory=True, # This is important ) - async def publish_error(self, error: TaskError, task: Task, project: str): - # pylint: disable=unused-argument - # TODO: for now project information is not leverage on the AMQP side which is - # not very convenient as clients will won't know from which project the error - # is coming. This is limitating as for instance when as error must probably - # be saved in separate DBs in order to avoid project data leaking to other - # projects through the DB + async def publish_error(self, error: TaskError): + # TODO: for now task project information is not leverage on the AMQP side which + # is not very convenient as clients will won't know from which project the + # result is coming. This is limitating as for instance when as result must + # probably be saved in separate DBs in order to avoid project data leaking to + # other projects through the DB message = error.json().encode() await self._publish_message( message, diff --git a/icij-worker/icij_worker/event_publisher/event_publisher.py b/icij-worker/icij_worker/event_publisher/event_publisher.py index 489f7263..c99aa61e 100644 --- a/icij-worker/icij_worker/event_publisher/event_publisher.py +++ b/icij-worker/icij_worker/event_publisher/event_publisher.py @@ -1,9 +1,15 @@ from abc import ABC, abstractmethod +from typing import final -from icij_worker import TaskEvent +from icij_worker import Task, TaskEvent class EventPublisher(ABC): + @final + async def publish_event(self, event: TaskEvent, task: Task): + event = event.with_project_id(task) + await self._publish_event(event) + @abstractmethod - async def publish_event(self, event: TaskEvent, project: str): + async def _publish_event(self, event: TaskEvent): pass diff --git a/icij-worker/icij_worker/event_publisher/neo4j.py b/icij-worker/icij_worker/event_publisher/neo4j_.py similarity index 62% rename from icij-worker/icij_worker/event_publisher/neo4j.py rename to icij-worker/icij_worker/event_publisher/neo4j_.py index 9c081eb4..167a1025 100644 --- a/icij-worker/icij_worker/event_publisher/neo4j.py +++ b/icij-worker/icij_worker/event_publisher/neo4j_.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager -from typing import AsyncGenerator, Dict, Optional +from typing import AsyncGenerator, Dict, List, Optional import neo4j from neo4j.exceptions import ResultNotSingleError @@ -10,17 +10,64 @@ TASK_ID, TASK_NODE, ) +from icij_common.neo4j.migrate import retrieve_projects from icij_common.neo4j.projects import project_db_session -from . import EventPublisher -from .. import Task, TaskEvent, TaskStatus -from ..exceptions import UnknownTask +from icij_worker.event_publisher.event_publisher import EventPublisher +from icij_worker.task import Task, TaskEvent, TaskStatus +from icij_worker.exceptions import UnknownTask -class Neo4jEventPublisher(EventPublisher): +class Neo4jTaskProjectMixin: + _driver: neo4j.AsyncDriver + _task_projects: Dict[str, str] = dict() + + @asynccontextmanager + async def _project_session( + self, project: str + ) -> AsyncGenerator[neo4j.AsyncSession, None]: + async with project_db_session(self._driver, project) as sess: + yield sess + + async def _get_task_project_id(self, task_id: str) -> str: + if task_id not in self._task_projects: + await self._refresh_task_projects() + try: + return self._task_projects[task_id] + except KeyError as e: + raise UnknownTask(task_id) from e + + async def _refresh_task_projects(self): + projects = await retrieve_projects(self._driver) + for p in projects: + async with self._project_session(p) as sess: + # Here we make the assumption that task IDs are unique across + # projects and not per project + task_projects = { + t: p.name for t in await sess.execute_read(_get_task_ids_tx) + } + self._task_projects.update(task_projects) + + +async def _get_task_ids_tx(tx: neo4j.AsyncTransaction) -> List[str]: + query = f"""MATCH (task:{TASK_NODE}) +RETURN task.{TASK_ID} as taskId""" + res = await tx.run(query) + ids = [rec["taskId"] async for rec in res] + return ids + + +class Neo4jEventPublisher(Neo4jTaskProjectMixin, EventPublisher): def __init__(self, driver: neo4j.AsyncDriver): self._driver = driver - async def publish_event(self, event: TaskEvent, project: str): + async def _publish_event(self, event: TaskEvent): + project = event.project_id + if project is None: + msg = ( + "neo4j expects project to be provided in order to fetch tasks from" + " the project's DB" + ) + raise ValueError(msg) async with self._project_session(project) as sess: await _publish_event(sess, event) @@ -28,13 +75,6 @@ async def publish_event(self, event: TaskEvent, project: str): def driver(self) -> neo4j.AsyncDriver: return self._driver - @asynccontextmanager - async def _project_session( - self, project: str - ) -> AsyncGenerator[neo4j.AsyncSession, None]: - async with project_db_session(self._driver, project) as sess: - yield sess - async def _publish_event(sess: neo4j.AsyncSession, event: TaskEvent): event = {k: v for k, v in event.dict(by_alias=True).items() if v is not None} @@ -48,6 +88,7 @@ async def _publish_event_tx( tx: neo4j.AsyncTransaction, event: Dict, error: Optional[Dict] ): task_id = event["taskId"] + project_id = event["project"] create_task = f"""MERGE (task:{TASK_NODE} {{{TASK_ID}: $taskId }}) ON CREATE SET task += $createProps""" status = event.get("status") @@ -58,7 +99,7 @@ async def _publish_event_tx( create_props = Task.mandatory_fields(event_as_event, keep_id=False) create_props.pop("status", None) res = await tx.run(create_task, taskId=task_id, createProps=create_props) - tasks = [Task.from_neo4j(rec) async for rec in res] + tasks = [Task.from_neo4j(rec, project_id=project_id) async for rec in res] task = tasks[0] resolved = task.resolve_event(event_as_event) resolved = ( diff --git a/icij-worker/icij_worker/task.py b/icij-worker/icij_worker/task.py index 14cad36e..3aa4dbf6 100644 --- a/icij-worker/icij_worker/task.py +++ b/icij-worker/icij_worker/task.py @@ -4,11 +4,12 @@ import logging import traceback import uuid +from abc import ABC from datetime import datetime from enum import Enum, unique from typing import Any, Dict, Optional -from pydantic import validator +from pydantic import Field, validator from icij_common.neo4j.constants import ( TASK_CANCEL_EVENT_CREATED_AT, @@ -17,6 +18,7 @@ TASK_NODE, ) from icij_common.pydantic_utils import ( + ICIJModel, ISODatetime, LowerCamelCaseModel, NoEnumModel, @@ -98,6 +100,7 @@ def _validate_neo4j_datetime(cls, value: Any) -> datetime: class Task(NoEnumModel, LowerCamelCaseModel, Neo4jDatetimeMixin): id: str + project_id: Optional[str] = Field(None, alias="project") type: str inputs: Optional[Dict[str, Any]] = None status: TaskStatus @@ -156,7 +159,13 @@ def _validate_progress(cls, value: Optional[float]): return value @classmethod - def from_neo4j(cls, record: "neo4j.Record", key="task") -> Task: + def from_neo4j( + cls, + record: "neo4j.Record", + *, + project_id: str, + key: str = "task", + ) -> Task: node = record[key] labels = node.labels node = dict(node) @@ -170,7 +179,9 @@ def from_neo4j(cls, record: "neo4j.Record", key="task") -> Task: node["completedAt"] = node["completedAt"].to_native() if "inputs" in node: node["inputs"] = json.loads(node["inputs"]) - return cls(status=status, **node) + node["status"] = status + node["project"] = project_id + return cls.parse_obj(node) @classmethod def mandatory_fields(cls, event: TaskEvent | Task, keep_id: bool) -> Dict[str, Any]: @@ -201,6 +212,8 @@ def resolve_event(self, event: TaskEvent) -> Optional[TaskEvent]: if not resolved: return None base_resolved = TaskEvent(task_id=event.task_id) + if "error" in resolved: + resolved["error"] = TaskError.parse_obj(resolved["error"]) resolved = safe_copy(base_resolved, update=resolved) return resolved @@ -214,9 +227,20 @@ def _schema(cls, by_alias: bool) -> Dict[str, Any]: return _TASK_SCHEMA[by_alias] -class TaskError(LowerCamelCaseModel): +class WithProjectIDMixin(ICIJModel, ABC): + project_id: Optional[str] + + def with_project_id(self, task: Task) -> WithProjectIDMixin: + if self.project_id is None and task.project_id is not None: + return safe_copy(self, update={"project_id": task.project_id}) + return self + + +class TaskError(LowerCamelCaseModel, WithProjectIDMixin): # This helps to know if an error has already been processed or not id: str + task_id: str + project_id: Optional[str] = Field(None, alias="project") # Follow the "problem detail" spec: https://datatracker.ietf.org/doc/html/rfc9457, # the type is omitted for now since we gave no URI to resolve errors yet title: str @@ -224,7 +248,7 @@ class TaskError(LowerCamelCaseModel): occurred_at: datetime @classmethod - def from_exception(cls, exception: BaseException) -> TaskError: + def from_exception(cls, exception: BaseException, task: Task) -> TaskError: title = exception.__class__.__name__ trace_lines = traceback.format_exception( None, value=exception, tb=exception.__traceback__ @@ -232,20 +256,34 @@ def from_exception(cls, exception: BaseException) -> TaskError: detail = f"{exception}\n{''.join(trace_lines)}" error_id = f"{_id_title(title)}-{uuid.uuid4().hex}" error = TaskError( - id=error_id, title=title, detail=detail, occurred_at=datetime.now() + id=error_id, + task_id=task.id, + project_id=task.project_id, + title=title, + detail=detail, + occurred_at=datetime.now(), ) return error @classmethod - def from_neo4j(cls, record: "neo4j.Record", key="error") -> TaskError: + def from_neo4j( + cls, + record: "neo4j.Record", + *, + task_id: str, + project_id: str, + key: str = "error", + ) -> TaskError: task = dict(record.value(key)) + task.update({"taskId": task_id, "project": project_id}) if "occurredAt" in task: task["occurredAt"] = task["occurredAt"].to_native() - return cls(**task) + return cls.parse_obj(task) -class TaskEvent(NoEnumModel, LowerCamelCaseModel): +class TaskEvent(NoEnumModel, LowerCamelCaseModel, WithProjectIDMixin): task_id: str + project_id: Optional[str] = Field(None, alias="project") task_type: Optional[str] = None status: Optional[TaskStatus] = None progress: Optional[float] = None @@ -267,12 +305,21 @@ def from_error( cls, error: TaskError, task_id: str, retries: Optional[int] = None ) -> TaskEvent: status = TaskStatus.QUEUED if retries is not None else TaskStatus.ERROR - event = TaskEvent(task_id=task_id, status=status, retries=retries, error=error) + event = TaskEvent( + task_id=task_id, + status=status, + retries=retries, + error=error, + project_id=error.project_id, + ) return event -class CancelledTaskEvent(NoEnumModel, LowerCamelCaseModel, Neo4jDatetimeMixin): +class CancelledTaskEvent( + NoEnumModel, LowerCamelCaseModel, Neo4jDatetimeMixin, WithProjectIDMixin +): task_id: str + project_id: Optional[str] = Field(None, alias="project") requeue: bool created_at: datetime @@ -285,29 +332,46 @@ def from_neo4j( cls, record: "neo4j.Record", *, - event_key="event", - task_key="task", + project_id: str, + event_key: str = "event", + task_key: str = "task", ) -> CancelledTaskEvent: task = record.get(task_key) event = record.get(event_key) task_id = task[TASK_ID] requeue = event[TASK_CANCEL_EVENT_REQUEUE] created_at = event[TASK_CANCEL_EVENT_CREATED_AT] - return cls(task_id=task_id, requeue=requeue, created_at=created_at) + return cls( + project_id=project_id, + task_id=task_id, + requeue=requeue, + created_at=created_at, + ) -class TaskResult(LowerCamelCaseModel): +class TaskResult(LowerCamelCaseModel, WithProjectIDMixin): task_id: str + project_id: Optional[str] = Field(None, alias="project") + # TODO: we could use generics here result: object @classmethod def from_neo4j( - cls, record: "neo4j.Record", task_key="task", result_key="result" + cls, + record: "neo4j.Record", + *, + project_id: str, + task_key: str = "task", + result_key: str = "result", ) -> TaskResult: result = record.get(result_key) if result is not None: result = json.loads(result["result"]) - return cls(task_id=record[task_key]["id"], result=result) + return cls(project_id=project_id, task_id=record[task_key]["id"], result=result) + + @classmethod + def from_task(cls, result: object, task: Task) -> TaskResult: + return cls(result=result, task_id=task.id).with_project_id(task) def _id_title(title: str) -> str: diff --git a/icij-worker/icij_worker/task_manager/__init__.py b/icij-worker/icij_worker/task_manager/__init__.py index 4c2d2a13..c7f49dd5 100644 --- a/icij-worker/icij_worker/task_manager/__init__.py +++ b/icij-worker/icij_worker/task_manager/__init__.py @@ -6,44 +6,44 @@ class TaskManager(ABC): @final - async def enqueue(self, task: Task, project: str) -> Task: + async def enqueue(self, task: Task) -> Task: if task.status is not TaskStatus.CREATED: msg = f"invalid status {task.status}, expected {TaskStatus.CREATED}" raise ValueError(msg) - task = await self._enqueue(task, project) + task = await self._enqueue(task) if task.status is not TaskStatus.QUEUED: msg = f"invalid status {task.status}, expected {TaskStatus.QUEUED}" raise ValueError(msg) return task @final - async def cancel(self, *, task_id: str, project: str, requeue: bool): - await self._cancel(task_id=task_id, project=project, requeue=requeue) + async def cancel(self, *, task_id: str, requeue: bool): + await self._cancel(task_id=task_id, requeue=requeue) @abstractmethod - async def _enqueue(self, task: Task, project: str) -> Task: + async def _enqueue(self, task: Task) -> Task: pass @abstractmethod - async def _cancel(self, *, task_id: str, project: str, requeue: bool) -> Task: + async def _cancel(self, *, task_id: str, requeue: bool) -> Task: pass @abstractmethod - async def get_task(self, *, task_id: str, project: str) -> Task: + async def get_task(self, *, task_id: str) -> Task: pass @abstractmethod - async def get_task_errors(self, task_id: str, project: str) -> List[TaskError]: + async def get_task_errors(self, task_id: str) -> List[TaskError]: pass @abstractmethod - async def get_task_result(self, task_id: str, project: str) -> TaskResult: + async def get_task_result(self, task_id: str) -> TaskResult: pass @abstractmethod async def get_tasks( self, - project: str, + project_id: Optional[str] = None, task_type: Optional[str] = None, status: Optional[Union[List[TaskStatus], TaskStatus]] = None, ) -> List[Task]: diff --git a/icij-worker/icij_worker/task_manager/neo4j.py b/icij-worker/icij_worker/task_manager/neo4j_.py similarity index 67% rename from icij-worker/icij_worker/task_manager/neo4j.py rename to icij-worker/icij_worker/task_manager/neo4j_.py index 015f19e4..34b426b0 100644 --- a/icij-worker/icij_worker/task_manager/neo4j.py +++ b/icij-worker/icij_worker/task_manager/neo4j_.py @@ -1,7 +1,6 @@ import json -from contextlib import asynccontextmanager from datetime import datetime -from typing import AsyncGenerator, List, Optional, Union +from typing import Dict, List, Optional, Union import itertools import neo4j @@ -28,8 +27,8 @@ TASK_RESULT_NODE, TASK_TYPE, ) -from icij_common.neo4j.projects import project_db_session from icij_worker import Task, TaskError, TaskResult, TaskStatus +from icij_worker.event_publisher.neo4j_ import Neo4jTaskProjectMixin from icij_worker.exceptions import ( MissingTaskResult, TaskAlreadyExists, @@ -39,6 +38,78 @@ from icij_worker.task_manager import TaskManager +class Neo4JTaskManager(TaskManager, Neo4jTaskProjectMixin): + def __init__(self, driver: neo4j.AsyncDriver, max_queue_size: int): + self._driver = driver + self._max_queue_size = max_queue_size + self._task_projects: Dict[str, str] = dict() + + @property + def driver(self) -> neo4j.AsyncDriver: + return self._driver + + async def get_task(self, *, task_id: str) -> Task: + project_id = await self._get_task_project_id(task_id) + async with self._project_session(project_id) as sess: + return await sess.execute_read( + _get_task_tx, task_id=task_id, project_id=project_id + ) + + async def get_task_errors(self, task_id: str) -> List[TaskError]: + project_id = await self._get_task_project_id(task_id) + async with self._project_session(project_id) as sess: + return await sess.execute_read( + _get_task_errors_tx, task_id=task_id, project_id=project_id + ) + + async def get_task_result(self, task_id: str) -> TaskResult: + project_id = await self._get_task_project_id(task_id) + async with self._project_session(project_id) as sess: + return await sess.execute_read( + _get_task_result_tx, task_id=task_id, project_id=project_id + ) + + async def get_tasks( + self, + project_id: Optional[str] = None, + task_type: Optional[str] = None, + status: Optional[Union[List[TaskStatus], TaskStatus]] = None, + ) -> List[Task]: + if project_id is None: + raise ValueError( + "neo4j expects project to be provided in order to fetch tasks from the" + " project's DB" + ) + async with self._project_session(project_id) as sess: + return await _get_tasks( + sess, status=status, task_type=task_type, project_id=project_id + ) + + async def _enqueue(self, task: Task) -> Task: + project_id = task.project_id + if project_id is None: + raise ValueError( + "neo4j expects project to be provided in order to fetch tasks from the" + " project's DB" + ) + async with self._project_session(project_id) as sess: + inputs = json.dumps(task.inputs) + return await sess.execute_write( + _enqueue_task_tx, + task_id=task.id, + task_type=task.type, + project_id=project_id, + created_at=task.created_at, + max_queue_size=self._max_queue_size, + inputs=inputs, + ) + + async def _cancel(self, *, task_id: str, requeue: bool): + project = await self._get_task_project_id(task_id) + async with self._project_session(project) as sess: + await sess.execute_write(_cancel_task_tx, task_id=task_id, requeue=requeue) + + async def add_support_for_async_task_tx(tx: neo4j.AsyncTransaction): constraint_query = f"""CREATE CONSTRAINT constraint_task_unique_id IF NOT EXISTS @@ -72,83 +143,38 @@ async def add_support_for_async_task_tx(tx: neo4j.AsyncTransaction): await tx.run(task_lock_worker_id_query) -class Neo4JTaskManager(TaskManager): - def __init__(self, driver: neo4j.AsyncDriver, max_queue_size: int): - self._driver = driver - self._max_queue_size = max_queue_size - - @property - def driver(self) -> neo4j.AsyncDriver: - return self._driver - - async def get_task(self, *, task_id: str, project: str) -> Task: - async with project_db_session(self._driver, project) as sess: - return await sess.execute_read(_get_task_tx, task_id=task_id) - - async def get_task_errors(self, task_id: str, project: str) -> List[TaskError]: - async with project_db_session(self._driver, project) as sess: - return await sess.execute_read(_get_task_errors_tx, task_id=task_id) - - async def get_task_result(self, task_id: str, project: str) -> TaskResult: - async with project_db_session(self._driver, project) as sess: - return await sess.execute_read(_get_task_result_tx, task_id=task_id) - - async def get_tasks( - self, - project: str, - task_type: Optional[str] = None, - status: Optional[Union[List[TaskStatus], TaskStatus]] = None, - ) -> List[Task]: - async with project_db_session(self._driver, project) as sess: - return await _get_tasks(sess, status=status, task_type=task_type) - - async def _enqueue(self, task: Task, project: str) -> Task: - async with project_db_session(self._driver, project) as sess: - inputs = json.dumps(task.inputs) - return await sess.execute_write( - _enqueue_task_tx, - task_id=task.id, - task_type=task.type, - created_at=task.created_at, - max_queue_size=self._max_queue_size, - inputs=inputs, - ) - - async def _cancel(self, *, task_id: str, project: str, requeue: bool): - async with project_db_session(self._driver, project) as sess: - await sess.execute_write(_cancel_task_tx, task_id=task_id, requeue=requeue) - - @asynccontextmanager - async def _project_session( - self, project: str - ) -> AsyncGenerator[neo4j.AsyncSession, None]: - async with project_db_session(self._driver, project) as sess: - yield sess - - async def _get_tasks( sess: neo4j.AsyncSession, status: Optional[Union[List[TaskStatus], TaskStatus]], task_type: Optional[str], + project_id: str, ) -> List[Task]: if isinstance(status, TaskStatus): status = [status] if status is not None: status = [s.value for s in status] - return await sess.execute_read(_get_tasks_tx, status=status, task_type=task_type) + return await sess.execute_read( + _get_tasks_tx, status=status, task_type=task_type, project_id=project_id + ) -async def _get_task_tx(tx: neo4j.AsyncTransaction, task_id: str) -> Task: +async def _get_task_tx( + tx: neo4j.AsyncTransaction, *, task_id: str, project_id: str +) -> Task: query = f"MATCH (task:{TASK_NODE} {{ {TASK_ID}: $taskId }}) RETURN task" res = await tx.run(query, taskId=task_id) - tasks = [Task.from_neo4j(t) async for t in res] + tasks = [Task.from_neo4j(t, project_id=project_id) async for t in res] if not tasks: raise UnknownTask(task_id) return tasks[0] async def _get_tasks_tx( - tx: neo4j.AsyncTransaction, status: Optional[List[str]], task_type: Optional[str] + tx: neo4j.AsyncTransaction, + status: Optional[List[str]], + *, + task_type: Optional[str], + project_id: str, ) -> List[Task]: where = "" if task_type: @@ -171,12 +197,12 @@ async def _get_tasks_tx( RETURN task ORDER BY task.{TASK_CREATED_AT} DESC""" res = await tx.run(query, status=status, type=task_type) - tasks = [Task.from_neo4j(t) async for t in res] + tasks = [Task.from_neo4j(t, project_id=project_id) async for t in res] return tasks async def _get_task_errors_tx( - tx: neo4j.AsyncTransaction, task_id: str + tx: neo4j.AsyncTransaction, *, task_id: str, project_id: str ) -> List[TaskError]: query = f"""MATCH (task:{TASK_NODE} {{ {TASK_ID}: $taskId }}) MATCH (error:{TASK_ERROR_NODE})-[:{TASK_ERROR_OCCURRED_TYPE}]->(task) @@ -184,17 +210,22 @@ async def _get_task_errors_tx( ORDER BY error.{TASK_ERROR_OCCURRED_AT} DESC """ res = await tx.run(query, taskId=task_id) - errors = [TaskError.from_neo4j(t) async for t in res] + errors = [ + TaskError.from_neo4j(t, task_id=task_id, project_id=project_id) + async for t in res + ] return errors -async def _get_task_result_tx(tx: neo4j.AsyncTransaction, task_id: str) -> TaskResult: +async def _get_task_result_tx( + tx: neo4j.AsyncTransaction, *, task_id: str, project_id: str +) -> TaskResult: query = f"""MATCH (task:{TASK_NODE} {{ {TASK_ID}: $taskId }}) MATCH (task)-[:{TASK_HAS_RESULT_TYPE}]->(result:{TASK_RESULT_NODE}) RETURN task, result """ res = await tx.run(query, taskId=task_id) - results = [TaskResult.from_neo4j(t) async for t in res] + results = [TaskResult.from_neo4j(t, project_id=project_id) async for t in res] if not results: raise MissingTaskResult(task_id) return results[0] @@ -205,6 +236,7 @@ async def _enqueue_task_tx( *, task_id: str, task_type: str, + project_id: str, created_at: datetime, inputs: str, max_queue_size: int, @@ -236,7 +268,7 @@ async def _enqueue_task_tx( task = await res.single(strict=True) except ConstraintError as e: raise TaskAlreadyExists() from e - return Task.from_neo4j(task) + return Task.from_neo4j(task, project_id=project_id) async def _cancel_task_tx(tx: neo4j.AsyncTransaction, task_id: str, requeue: bool): diff --git a/icij-worker/icij_worker/tests/conftest.py b/icij-worker/icij_worker/tests/conftest.py index d97fafbf..3fedd5fd 100644 --- a/icij-worker/icij_worker/tests/conftest.py +++ b/icij-worker/icij_worker/tests/conftest.py @@ -34,8 +34,8 @@ from icij_common.test_utils import TEST_PROJECT from icij_worker import AsyncApp, Task from icij_worker.event_publisher.amqp import AMQPPublisher -from icij_worker.task import CancelledTaskEvent -from icij_worker.task_manager.neo4j import add_support_for_async_task_tx +from icij_worker.task import CancelledTaskEvent, TaskStatus +from icij_worker.task_manager.neo4j_ import add_support_for_async_task_tx from icij_worker.typing_ import PercentProgress # noinspection PyUnresolvedReferences @@ -106,7 +106,7 @@ async def populate_tasks(neo4j_async_app_driver: neo4j.AsyncDriver) -> List[Task recs_0, _, _ = await neo4j_async_app_driver.execute_query( query_0, now=datetime.now() ) - t_0 = Task.from_neo4j(recs_0[0]) + t_0 = Task.from_neo4j(recs_0[0], project_id=TEST_PROJECT) query_1 = """CREATE (task:_Task:RUNNING { id: 'task-1', type: 'hello_world', @@ -119,7 +119,7 @@ async def populate_tasks(neo4j_async_app_driver: neo4j.AsyncDriver) -> List[Task recs_1, _, _ = await neo4j_async_app_driver.execute_query( query_1, now=datetime.now() ) - t_1 = Task.from_neo4j(recs_1[0]) + t_1 = Task.from_neo4j(recs_1[0], project_id=TEST_PROJECT) return [t_0, t_1] @@ -133,7 +133,7 @@ async def populate_cancel_events( recs_0, _, _ = await neo4j_async_app_driver.execute_query( query_0, now=datetime.now(), taskId=populate_tasks[0].id ) - return [CancelledTaskEvent.from_neo4j(recs_0[0])] + return [CancelledTaskEvent.from_neo4j(recs_0[0], project_id=TEST_PROJECT)] class Recoverable(ValueError): @@ -317,3 +317,16 @@ def can_publish(self) -> bool: @property def event_queue(self) -> str: return self.__class__.evt_routing().default_queue + + +@pytest.fixture(scope="session") +def hello_world_task() -> Task: + task = Task( + id="some-id", + project_id=TEST_PROJECT, + type="hello_world", + inputs={"greeted": "world"}, + status=TaskStatus.CREATED, + created_at=datetime.now(), + ) + return task diff --git a/icij-worker/icij_worker/tests/event_publisher/test_ampq.py b/icij-worker/icij_worker/tests/event_publisher/test_ampq.py index 76f1ce7a..e72fc79e 100644 --- a/icij-worker/icij_worker/tests/event_publisher/test_ampq.py +++ b/icij-worker/icij_worker/tests/event_publisher/test_ampq.py @@ -2,8 +2,8 @@ from aio_pika import ExchangeType, connect_robust from aiormq import ChannelNotFoundEntity -from icij_common.test_utils import TEST_PROJECT -from icij_worker import TaskEvent, TaskStatus +from icij_common.pydantic_utils import safe_copy +from icij_worker import Task, TaskEvent, TaskStatus from icij_worker.event_publisher.amqp import ( AMQPPublisher, Exchange, @@ -33,20 +33,20 @@ @pytest.mark.asyncio -async def test_publish_event(rabbit_mq: str): +async def test_publish_event(rabbit_mq: str, hello_world_task: Task): # Given + task = hello_world_task broker_url = rabbit_mq - project = TEST_PROJECT publisher = TestableAMQPPublisher( broker_url=broker_url, connection_timeout_s=2, reconnection_wait_s=1 ) event = TaskEvent( - task_id="task_id", task_type="hello_world", status=TaskStatus.CREATED + task_id=task.id, task_type="hello_world", status=TaskStatus.CREATED ) # When async with publisher: - await publisher.publish_event(event, project) + await publisher.publish_event(event, task) # Then connection = await connect_robust(url=broker_url) @@ -56,7 +56,8 @@ async def test_publish_event(rabbit_mq: str): async for message in messages: received_event = TaskEvent.parse_raw(message.body) break - assert received_event == event + expected = safe_copy(event, update={"project_id": task.project_id}) + assert received_event == expected async def test_publisher_not_create_and_bind_exchanges_and_queues(rabbit_mq: str): diff --git a/icij-worker/icij_worker/tests/event_publisher/test_neo4j.py b/icij-worker/icij_worker/tests/event_publisher/test_neo4j.py index 1d58d5ba..901636ff 100644 --- a/icij-worker/icij_worker/tests/event_publisher/test_neo4j.py +++ b/icij-worker/icij_worker/tests/event_publisher/test_neo4j.py @@ -7,9 +7,13 @@ from icij_common.pydantic_utils import safe_copy from icij_common.test_utils import TEST_PROJECT -from icij_worker import Task, TaskEvent, TaskStatus -from icij_worker.event_publisher import Neo4jEventPublisher -from icij_worker.task_manager.neo4j import Neo4JTaskManager +from icij_worker import ( + Neo4JTaskManager, + Neo4jEventPublisher, + Task, + TaskEvent, + TaskStatus, +) @pytest.fixture(scope="function") @@ -23,7 +27,6 @@ async def test_worker_publish_event( ): # Given task_manager = Neo4JTaskManager(publisher.driver, max_queue_size=10) - project = TEST_PROJECT task = populate_tasks[0] assert task.status == TaskStatus.QUEUED assert task.progress is None @@ -38,8 +41,8 @@ async def test_worker_publish_event( ) # When - await publisher.publish_event(event=event, project=project) - saved_task = await task_manager.get_task(task_id=task.id, project=project) + await publisher.publish_event(event, task) + saved_task = await task_manager.get_task(task_id=task.id) # Then # Status is not updated by event @@ -52,7 +55,6 @@ async def test_worker_publish_done_task_event_should_not_update_task( publisher: Neo4jEventPublisher, ): # Given - project = TEST_PROJECT query = """CREATE (task:_Task:DONE { id: 'task-0', type: 'hello_world', @@ -64,7 +66,7 @@ async def test_worker_publish_done_task_event_should_not_update_task( async with publisher.driver.session() as sess: res = await sess.run(query, now=datetime.now()) completed = await res.single() - completed = Task.from_neo4j(completed) + completed = Task.from_neo4j(completed, project_id=TEST_PROJECT) task_manager = Neo4JTaskManager(publisher.driver, max_queue_size=10) event = TaskEvent( task_id=completed.id, @@ -74,8 +76,8 @@ async def test_worker_publish_done_task_event_should_not_update_task( ) # When - await publisher.publish_event(event=event, project=project) - saved_task = await task_manager.get_task(task_id=completed.id, project=project) + await publisher.publish_event(event, completed) + saved_task = await task_manager.get_task(task_id=completed.id) # Then assert saved_task == completed @@ -85,11 +87,16 @@ async def test_worker_publish_event_for_unknown_task(publisher: Neo4jEventPublis # This is useful when task is not reserved yet # Given task_manager = Neo4JTaskManager(publisher.driver, max_queue_size=10) - project = TEST_PROJECT - task_id = "some-id" task_type = "hello_world" created_at = datetime.now() + task = Task( + id=task_id, + type=task_type, + project_id=TEST_PROJECT, + created_at=created_at, + status=TaskStatus.QUEUED, + ) event = TaskEvent( task_id=task_id, task_type=task_type, @@ -98,14 +105,11 @@ async def test_worker_publish_event_for_unknown_task(publisher: Neo4jEventPublis ) # When - await publisher.publish_event(event=event, project=project) - saved_task = await task_manager.get_task(task_id=task_id, project=project) + await publisher.publish_event(event, task) + saved_task = await task_manager.get_task(task_id=task_id) # Then - expected = Task( - id=task_id, type=task_type, created_at=created_at, status=TaskStatus.QUEUED - ) - assert saved_task == expected + assert saved_task == task async def test_worker_publish_event_should_use_status_resolution( @@ -113,15 +117,14 @@ async def test_worker_publish_event_should_use_status_resolution( ): # Given task_manager = Neo4JTaskManager(publisher.driver, max_queue_size=10) - project = TEST_PROJECT task = populate_tasks[1] assert task.status is TaskStatus.RUNNING event = TaskEvent(task_id=task.id, status=TaskStatus.CREATED) # When - await publisher.publish_event(event=event, project=project) - saved_task = await task_manager.get_task(task_id=task.id, project=project) + await publisher.publish_event(event, task) + saved_task = await task_manager.get_task(task_id=task.id) # Then assert saved_task == task diff --git a/icij-worker/icij_worker/tests/task_manager/test_neo4j.py b/icij-worker/icij_worker/tests/task_manager/test_neo4j.py index d74000c1..17c3ac83 100644 --- a/icij-worker/icij_worker/tests/task_manager/test_neo4j.py +++ b/icij-worker/icij_worker/tests/task_manager/test_neo4j.py @@ -7,10 +7,9 @@ from icij_common.pydantic_utils import safe_copy from icij_common.test_utils import TEST_PROJECT -from icij_worker import Task, TaskError, TaskResult, TaskStatus +from icij_worker import Neo4JTaskManager, Task, TaskError, TaskResult, TaskStatus from icij_worker.exceptions import MissingTaskResult, TaskAlreadyExists, TaskQueueIsFull from icij_worker.task import CancelledTaskEvent -from icij_worker.task_manager.neo4j import Neo4JTaskManager @pytest_asyncio.fixture(scope="function") @@ -29,7 +28,9 @@ async def _populate_errors( recs_0, _, _ = await neo4j_async_app_driver.execute_query( query_0, taskId=task_with_error.id, now=datetime.now() ) - e_0 = TaskError.from_neo4j(recs_0[0]) + e_0 = TaskError.from_neo4j( + recs_0[0], task_id=task_with_error.id, project_id=TEST_PROJECT + ) query_1 = """MATCH (task:_Task { id: $taskId }) CREATE (error:_TaskError { id: 'error-1', @@ -43,7 +44,9 @@ async def _populate_errors( taskId=task_with_error.id, now=datetime.now(), ) - e_1 = TaskError.from_neo4j(recs_1[0]) + e_1 = TaskError.from_neo4j( + recs_1[0], task_id=task_with_error.id, project_id=TEST_PROJECT + ) return list(zip(populate_tasks, [[], [e_0, e_1]])) @@ -66,8 +69,8 @@ async def _populate_results( recs_0, _, _ = await neo4j_async_app_driver.execute_query( query_1, now=now, after=after ) - t_2 = Task.from_neo4j(recs_0[0]) - r_2 = TaskResult.from_neo4j(recs_0[0]) + t_2 = Task.from_neo4j(recs_0[0], project_id=TEST_PROJECT) + r_2 = TaskResult.from_neo4j(recs_0[0], project_id=t_2.project_id) tasks = populate_tasks + [t_2] return list(zip(tasks, [None, None, r_2])) @@ -77,17 +80,17 @@ async def test_task_manager_get_task( ): # Given task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=10) - project = TEST_PROJECT second_task = populate_tasks[1] # When - task = await task_manager.get_task(task_id=second_task.id, project=project) + task = await task_manager.get_task(task_id=second_task.id) task = task.dict(by_alias=True) # Then expected_task = Task( id="task-1", type="hello_world", + project_id=TEST_PROJECT, inputs={"greeted": "1"}, status=TaskStatus.RUNNING, progress=66.6, @@ -108,11 +111,10 @@ async def test_task_manager_get_completed_task( # pylint: disable=invalid-name # Given task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=10) - project = TEST_PROJECT last_task = _populate_results[-1][0] # When - task = await task_manager.get_task(task_id=last_task.id, project=project) + task = await task_manager.get_task(task_id=last_task.id) # Then assert isinstance(task.completed_at, datetime) @@ -144,7 +146,7 @@ async def test_task_manager_get_tasks( # When tasks = await task_manager.get_tasks( - project=project, status=statuses, task_type=task_type + project_id=project, status=statuses, task_type=task_type ) tasks = sorted(tasks, key=lambda t: t.id) @@ -162,12 +164,16 @@ async def test_task_manager_get_tasks( [ TaskError( id="error-0", + task_id="task-1", + project_id=TEST_PROJECT, title="error", detail="with details", occurred_at=datetime.now(), ), TaskError( id="error-1", + task_id="task-1", + project_id=TEST_PROJECT, title="error", detail="same error again", occurred_at=datetime.now(), @@ -184,13 +190,10 @@ async def test_get_task_errors( ): # pylint: disable=invalid-name # Given - project = TEST_PROJECT task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=10) # When - retrieved_errors = await task_manager.get_task_errors( - task_id=task_id, project=project - ) + retrieved_errors = await task_manager.get_task_errors(task_id=task_id) # Then retrieved_errors = [e.dict(by_alias=True) for e in retrieved_errors] @@ -208,7 +211,10 @@ async def test_get_task_errors( [ ("task-0", None), ("task-1", None), - ("task-2", TaskResult(task_id="task-2", result="Hello 2")), + ( + "task-2", + TaskResult(task_id="task-2", project_id=TEST_PROJECT, result="Hello 2"), + ), ], ) async def test_task_manager_get_task_result( @@ -219,7 +225,6 @@ async def test_task_manager_get_task_result( ): # pylint: disable=invalid-name # Given - project = TEST_PROJECT task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=10) # When/ Then @@ -228,26 +233,21 @@ async def test_task_manager_get_task_result( f'Result of task "{task_id}" couldn\'t be found, did it complete ?' ) with pytest.raises(MissingTaskResult, match=expected_msg): - await task_manager.get_task_result(task_id=task_id, project=project) + await task_manager.get_task_result(task_id=task_id) else: - result = await task_manager.get_task_result(task_id=task_id, project=project) + result = await task_manager.get_task_result(task_id=task_id) assert result == expected_result -async def test_task_manager_enqueue(neo4j_async_app_driver: neo4j.AsyncDriver): +async def test_task_manager_enqueue( + neo4j_async_app_driver: neo4j.AsyncDriver, hello_world_task: Task +): # Given - project = TEST_PROJECT task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=10) - task = Task( - id="some-id", - type="hello_world", - status=TaskStatus.CREATED, - created_at=datetime.now(), - inputs={"greeted": "world"}, - ) + task = hello_world_task # When - queued = await task_manager.enqueue(task, project) + queued = await task_manager.enqueue(task) # Then update = {"status": TaskStatus.QUEUED} @@ -255,11 +255,10 @@ async def test_task_manager_enqueue(neo4j_async_app_driver: neo4j.AsyncDriver): assert queued == expected -async def test_task_manager_enqueue_should_raise_for_existing_task( +async def test_task_manager_enqueue_should_raise_for_missing_project( neo4j_async_app_driver: neo4j.AsyncDriver, ): # Given - project = TEST_PROJECT task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=10) task = Task( id="some-id", @@ -268,34 +267,47 @@ async def test_task_manager_enqueue_should_raise_for_existing_task( created_at=datetime.now(), inputs={"greeted": "world"}, ) - await task_manager.enqueue(task, project) + + # When/Then + expected = ( + "neo4j expects project to be provided in order to fetch tasks from" + " the project's DB" + ) + with pytest.raises(ValueError, match=expected): + await task_manager.enqueue(task) + + +async def test_task_manager_enqueue_should_raise_for_existing_task( + neo4j_async_app_driver: neo4j.AsyncDriver, hello_world_task: Task +): + # Given + task = hello_world_task + task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=10) + await task_manager.enqueue(task) # When/Then with pytest.raises(TaskAlreadyExists): - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) @pytest.mark.parametrize("requeue", [True, False]) async def test_task_manager_cancel( - neo4j_async_app_driver: neo4j.AsyncDriver, requeue: bool + neo4j_async_app_driver: neo4j.AsyncDriver, requeue: bool, hello_world_task: Task ): # Given driver = neo4j_async_app_driver - project = TEST_PROJECT + task = hello_world_task task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=10) - task = Task( - id="some-id", type="hello", status=TaskStatus.CREATED, created_at=datetime.now() - ) # When - task = await task_manager.enqueue(task, project) - await task_manager.cancel(task_id=task.id, project=project, requeue=requeue) + task = await task_manager.enqueue(task) + await task_manager.cancel(task_id=task.id, requeue=requeue) query = """MATCH (task:_Task { id: $taskId })-[ :_CANCELLED_BY]->(event:_CancelEvent) RETURN task, event""" recs, _, _ = await driver.execute_query(query, taskId=task.id) assert len(recs) == 1 - event = CancelledTaskEvent.from_neo4j(recs[0]) + event = CancelledTaskEvent.from_neo4j(recs[0], project_id=task.project_id) # Then assert event.task_id == task.id assert event.created_at is not None @@ -303,14 +315,11 @@ async def test_task_manager_cancel( async def test_task_manager_enqueue_should_raise_when_queue_full( - neo4j_async_app_driver: neo4j.AsyncDriver, + neo4j_async_app_driver: neo4j.AsyncDriver, hello_world_task: Task ): - project = TEST_PROJECT task_manager = Neo4JTaskManager(neo4j_async_app_driver, max_queue_size=-1) - task = Task( - id="some-id", type="hello", status=TaskStatus.CREATED, created_at=datetime.now() - ) + task = hello_world_task # When with pytest.raises(TaskQueueIsFull): - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) diff --git a/icij-worker/icij_worker/tests/test_task.py b/icij-worker/icij_worker/tests/test_task.py index 25f71277..e049b990 100644 --- a/icij-worker/icij_worker/tests/test_task.py +++ b/icij-worker/icij_worker/tests/test_task.py @@ -24,81 +24,81 @@ def test_precedence_sanity_check(): @pytest.mark.parametrize( "task,event,expected_resolved", [ - # Update the status - ( - Task( - id="task-id", - type="hello_world", - status=TaskStatus.CREATED, - created_at=_CREATED_AT, - ), - TaskEvent(task_id="task-id", status=TaskStatus.RUNNING), - TaskEvent(task_id="task-id", status=TaskStatus.RUNNING), - ), - # Task type is not updated - ( - Task( - id="task-id", - type="hello_world", - status=TaskStatus.CREATED, - created_at=_CREATED_AT, - ), - TaskEvent(task_id="task-id", task_type="goodbye_world"), - None, - ), - # Status is updated when not in a final state - ( - Task( - id="task-id", - type="hello_world", - status=TaskStatus.CREATED, - created_at=_CREATED_AT, - ), - TaskEvent(task_id="task-id", status=TaskStatus.QUEUED), - TaskEvent(task_id="task-id", status=TaskStatus.QUEUED), - ), - ( - Task( - id="task-id", - type="hello_world", - status=TaskStatus.QUEUED, - created_at=_CREATED_AT, - ), - TaskEvent(task_id="task-id", status=TaskStatus.RUNNING), - TaskEvent(task_id="task-id", status=TaskStatus.RUNNING), - ), - ( - Task( - id="task-id", - type="hello_world", - status=TaskStatus.RUNNING, - created_at=_CREATED_AT, - ), - TaskEvent(task_id="task-id", status=TaskStatus.DONE), - TaskEvent(task_id="task-id", status=TaskStatus.DONE), - ), - # Update the progress - ( - Task( - id="task-id", - type="hello_world", - status=TaskStatus.CREATED, - created_at=_CREATED_AT, - ), - TaskEvent(task_id="task-id", progress=50.0), - TaskEvent(task_id="task-id", progress=50.0), - ), - # Update retries - ( - Task( - id="task-id", - type="hello_world", - status=TaskStatus.CREATED, - created_at=_CREATED_AT, - ), - TaskEvent(task_id="task-id", retries=4), - TaskEvent(task_id="task-id", retries=4), - ), + # # Update the status + # ( + # Task( + # id="task-id", + # type="hello_world", + # status=TaskStatus.CREATED, + # created_at=_CREATED_AT, + # ), + # TaskEvent(task_id="task-id", status=TaskStatus.RUNNING), + # TaskEvent(task_id="task-id", status=TaskStatus.RUNNING), + # ), + # # Task type is not updated + # ( + # Task( + # id="task-id", + # type="hello_world", + # status=TaskStatus.CREATED, + # created_at=_CREATED_AT, + # ), + # TaskEvent(task_id="task-id", task_type="goodbye_world"), + # None, + # ), + # # Status is updated when not in a final state + # ( + # Task( + # id="task-id", + # type="hello_world", + # status=TaskStatus.CREATED, + # created_at=_CREATED_AT, + # ), + # TaskEvent(task_id="task-id", status=TaskStatus.QUEUED), + # TaskEvent(task_id="task-id", status=TaskStatus.QUEUED), + # ), + # ( + # Task( + # id="task-id", + # type="hello_world", + # status=TaskStatus.QUEUED, + # created_at=_CREATED_AT, + # ), + # TaskEvent(task_id="task-id", status=TaskStatus.RUNNING), + # TaskEvent(task_id="task-id", status=TaskStatus.RUNNING), + # ), + # ( + # Task( + # id="task-id", + # type="hello_world", + # status=TaskStatus.RUNNING, + # created_at=_CREATED_AT, + # ), + # TaskEvent(task_id="task-id", status=TaskStatus.DONE), + # TaskEvent(task_id="task-id", status=TaskStatus.DONE), + # ), + # # Update the progress + # ( + # Task( + # id="task-id", + # type="hello_world", + # status=TaskStatus.CREATED, + # created_at=_CREATED_AT, + # ), + # TaskEvent(task_id="task-id", progress=50.0), + # TaskEvent(task_id="task-id", progress=50.0), + # ), + # # Update retries + # ( + # Task( + # id="task-id", + # type="hello_world", + # status=TaskStatus.CREATED, + # created_at=_CREATED_AT, + # ), + # TaskEvent(task_id="task-id", retries=4), + # TaskEvent(task_id="task-id", retries=4), + # ), # Update error ( Task( @@ -111,6 +111,7 @@ def test_precedence_sanity_check(): task_id="task-id", error=TaskError( id="error-id", + task_id="task-id", title="some-error", detail="some details", occurred_at=_ERROR_OCCURRED_AT, @@ -120,6 +121,7 @@ def test_precedence_sanity_check(): task_id="task-id", error=TaskError( id="error-id", + task_id="task-id", title="some-error", detail="some details", occurred_at=_ERROR_OCCURRED_AT, @@ -165,6 +167,7 @@ def test_precedence_sanity_check(): retries=4, error=TaskError( id="error-id", + task_id="task-id", title="some-error", detail="some details", occurred_at=_ERROR_OCCURRED_AT, @@ -189,6 +192,7 @@ def test_precedence_sanity_check(): retries=4, error=TaskError( id="error-id", + task_id="task-id", title="some-error", detail="some details", occurred_at=_ERROR_OCCURRED_AT, @@ -213,6 +217,7 @@ def test_precedence_sanity_check(): retries=4, error=TaskError( id="error-id", + task_id="task-id", title="some-error", detail="some details", occurred_at=_ERROR_OCCURRED_AT, diff --git a/icij-worker/icij_worker/tests/worker/test_amqp.py b/icij-worker/icij_worker/tests/worker/test_amqp.py index 6c0af80d..ae7860c6 100644 --- a/icij-worker/icij_worker/tests/worker/test_amqp.py +++ b/icij-worker/icij_worker/tests/worker/test_amqp.py @@ -109,7 +109,7 @@ def error_routing(cls) -> Routing: return AMQPPublisher.err_routing() @property - def cancelled(self) -> Dict[str, Dict[str, CancelledTaskEvent]]: + def cancelled(self) -> Dict[str, CancelledTaskEvent]: return self._cancelled def _create_publisher(self): @@ -235,7 +235,7 @@ async def test_worker_consume_task( ): await asyncio.wait([consume_task], timeout=consume_timeout) expected_task = safe_copy(populate_tasks[0], update={"progress": 0.0}) - consumed, _ = consume_task.result() + consumed = consume_task.result() assert consumed == expected_task @@ -258,7 +258,7 @@ async def _received_event() -> bool: assert await async_true_after(_received_event, after_s=cancel_timeout), failure assert len(amqp_worker.cancelled) == 1 - received_event = next(iter(amqp_worker.cancelled.values())).pop("some-id") + received_event = amqp_worker.cancelled.pop("some-id") assert received_event == expected_event @@ -266,13 +266,11 @@ async def test_worker_negatively_acknowledge( populate_tasks: List[Task], amqp_worker: TestableAMQPWorker, rabbit_mq: str ): # pylint: disable=protected-access,unused-argument - # Given - project = TEST_PROJECT # When async with amqp_worker: # Then - task, _ = await amqp_worker.consume() - await amqp_worker.negatively_acknowledge(task, project, requeue=False) + task = await amqp_worker.consume() + await amqp_worker.negatively_acknowledge(task, requeue=False) dlq_name = amqp_worker.task_routing.dead_letter_routing.default_queue @@ -288,13 +286,12 @@ async def test_worker_negatively_acknowledge_and_requeue( ): # pylint: disable=protected-access,unused-argument # Given - project = TEST_PROJECT n_tasks = len(populate_tasks) # When async with amqp_worker: # Then - task, _ = await amqp_worker.consume() - await amqp_worker.negatively_acknowledge(task, project, requeue=True) + task = await amqp_worker.consume() + await amqp_worker.negatively_acknowledge(task, requeue=True) # Check that we can poll the task again task_ids = set() for _ in range(n_tasks): @@ -304,7 +301,7 @@ async def test_worker_negatively_acknowledge_and_requeue( f"failed to consume task in less than {consume_timeout}s" ): await asyncio.wait([consume_task], timeout=consume_timeout) - task_ids.add(consume_task.result()[0].id) + task_ids.add(consume_task.result().id) assert task.id in task_ids @@ -317,14 +314,11 @@ async def test_worker_negatively_acknowledge_and_cancel( ): # pylint: disable=protected-access,unused-argument # Given - project = TEST_PROJECT # When async with amqp_worker: # Then - task, _ = await amqp_worker.consume() - await amqp_worker.negatively_acknowledge( - task, project, requeue=requeue, cancel=True - ) + task = await amqp_worker.consume() + await amqp_worker.negatively_acknowledge(task, requeue=requeue, cancel=True) task_routing = amqp_worker.task_routing async def _requeued(queue_name: str, n: int) -> bool: @@ -344,17 +338,19 @@ async def _requeued(queue_name: str, n: int) -> bool: async def test_publish_event( - test_async_app: AsyncApp, amqp_worker: TestableAMQPWorker, rabbit_mq: str + test_async_app: AsyncApp, + amqp_worker: TestableAMQPWorker, + rabbit_mq: str, + hello_world_task: Task, ): # pylint: disable=protected-access,unused-argument # Given broker_url = rabbit_mq - project = TEST_PROJECT - - event = TaskEvent(task_id="some_task", progress=50.0) + task = hello_world_task + event = TaskEvent(task_id=task.id, progress=50.0) # When async with amqp_worker: - await amqp_worker.publish_event(event, project) + await amqp_worker.publish_event(event, task) # Then connection = await connect_robust(url=broker_url) @@ -365,7 +361,8 @@ async def test_publish_event( async for message in messages: received_event = TaskEvent.parse_raw(message.body) break - assert received_event == event + expected = safe_copy(event, update={"project_id": TEST_PROJECT}) + assert received_event == expected async def test_publish_error( @@ -374,10 +371,10 @@ async def test_publish_error( # pylint: disable=unused-argument # Given broker_url = rabbit_mq - project = TEST_PROJECT task = populate_tasks[0] error = TaskError( id="error-id", + task_id=task.id, title="someErrorTitle", detail="with_details", occurred_at=datetime.now(), @@ -385,7 +382,7 @@ async def test_publish_error( # When async with amqp_worker: - await amqp_worker.save_error(error=error, task=task, project=project) + await amqp_worker.save_error(error=error) # Then connection = await connect_robust(url=broker_url) channel = await connection.channel() @@ -404,13 +401,12 @@ async def test_publish_result( # pylint: disable=unused-argument # Given broker_url = rabbit_mq - project = TEST_PROJECT task = populate_tasks[0] result = TaskResult(task_id=task.id, result="hello world !") # When async with amqp_worker: - await amqp_worker.save_result(result, project=project) + await amqp_worker.save_result(result) # Then connection = await connect_robust(url=broker_url) channel = await connection.channel() diff --git a/icij-worker/icij_worker/tests/worker/test_neo4j.py b/icij-worker/icij_worker/tests/worker/test_neo4j.py index f344e25e..698a33fd 100644 --- a/icij-worker/icij_worker/tests/worker/test_neo4j.py +++ b/icij-worker/icij_worker/tests/worker/test_neo4j.py @@ -11,6 +11,7 @@ from icij_common.test_utils import TEST_PROJECT, fail_if_exception from icij_worker import ( AsyncApp, + Neo4JTaskManager, Neo4jWorker, Task, TaskError, @@ -19,7 +20,6 @@ TaskStatus, ) from icij_worker.task import CancelledTaskEvent -from icij_worker.task_manager.neo4j import Neo4JTaskManager @pytest.fixture(scope="function") @@ -76,9 +76,9 @@ async def test_worker_consume_cancel_event( await asyncio.wait([task], timeout=timeout) if not task.done(): pytest.fail(f"failed to consume task in less than {timeout}s") - event, project = task.result() + event = task.result() assert event == populate_cancel_events[0] - assert project == TEST_PROJECT + assert event.project_id == TEST_PROJECT async def test_worker_negatively_acknowledge( @@ -88,17 +88,17 @@ async def test_worker_negatively_acknowledge( # Given task_manager = Neo4JTaskManager(worker.driver, max_queue_size=10) # When - task, project = await worker.consume() - n_locks = await _count_locks(worker.driver, project=project) + task = await worker.consume() + n_locks = await _count_locks(worker.driver, project=task.project_id) assert n_locks == 1 - await worker.negatively_acknowledge(task, project, requeue=False) - nacked = await task_manager.get_task(task_id=task.id, project=project) + await worker.negatively_acknowledge(task, requeue=False) + nacked = await task_manager.get_task(task_id=task.id) # Then update = {"status": TaskStatus.ERROR} expected_nacked = safe_copy(task, update=update) assert nacked == expected_nacked - n_locks = await _count_locks(worker.driver, project=project) + n_locks = await _count_locks(worker.driver, project=task.project_id) assert n_locks == 0 @@ -111,6 +111,7 @@ async def test_worker_negatively_acknowledge_and_requeue( project = TEST_PROJECT created_at = datetime.now() task = Task( + project_id=project, id="some-id", type="hello_world", created_at=created_at, @@ -120,17 +121,17 @@ async def test_worker_negatively_acknowledge_and_requeue( assert n_locks == 0 # When - await task_manager.enqueue(task, project) - task, project = await worker.consume() + await task_manager.enqueue(task) + task = await worker.consume() n_locks = await _count_locks(worker.driver, project=project) assert n_locks == 1 # Let's publish some event to increment the progress and check that it's reset # correctly to 0 event = TaskEvent(task_id=task.id, progress=50.0) - await worker.publish_event(event, project) + await worker.publish_event(event, task) with_progress = safe_copy(task, update={"progress": event.progress}) - await worker.negatively_acknowledge(task, project, requeue=True) - nacked = await task_manager.get_task(task_id=task.id, project=project) + await worker.negatively_acknowledge(task, requeue=True) + nacked = await task_manager.get_task(task_id=task.id) # Then update = {"status": TaskStatus.QUEUED, "progress": 0.0, "retries": 1.0} @@ -142,7 +143,7 @@ async def test_worker_negatively_acknowledge_and_requeue( @pytest.mark.parametrize("requeue", [True, False]) async def test_worker_negatively_acknowledge_and_cancel( - populate_tasks: List[Task], worker: Neo4jWorker, requeue: bool + worker: Neo4jWorker, requeue: bool ): # pylint: disable=unused-argument # Given @@ -152,6 +153,7 @@ async def test_worker_negatively_acknowledge_and_cancel( task = Task( id="some-id", type="hello_world", + project_id=TEST_PROJECT, created_at=created_at, status=TaskStatus.CREATED, ) @@ -159,17 +161,17 @@ async def test_worker_negatively_acknowledge_and_cancel( assert n_locks == 0 # When - await task_manager.enqueue(task, project) - task, project = await worker.consume() + await task_manager.enqueue(task) + task = await worker.consume() n_locks = await _count_locks(worker.driver, project=project) assert n_locks == 1 # Let's publish some event to increment the progress and check that it's reset # correctly to 0 event = TaskEvent(task_id=task.id, progress=50.0) - await worker.publish_event(event, project) + await worker.publish_event(event, task) with_progress = safe_copy(task, update={"progress": event.progress}) - await worker.negatively_acknowledge(task, project, cancel=True, requeue=requeue) - nacked = await task_manager.get_task(task_id=task.id, project=project) + await worker.negatively_acknowledge(task, cancel=True, requeue=requeue) + nacked = await task_manager.get_task(task_id=task.id) # Then nacked = nacked.dict(exclude_unset=True) @@ -193,12 +195,12 @@ async def test_worker_save_result(populate_tasks: List[Task], worker: Neo4jWorke task = populate_tasks[0] assert task.status == TaskStatus.QUEUED result = "hello everyone" - task_result = TaskResult(task_id=task.id, result=result) + task_result = TaskResult(task_id=task.id, project_id=project, result=result) # When - await worker.save_result(result=task_result, project=project) - saved_task = await task_manager.get_task(task_id=task.id, project=project) - saved_result = await task_manager.get_task_result(task_id=task.id, project=project) + await worker.save_result(result=task_result) + saved_task = await task_manager.get_task(task_id=task.id) + saved_result = await task_manager.get_task_result(task_id=task.id) # Then assert saved_task == task @@ -213,14 +215,14 @@ async def test_worker_should_raise_when_saving_existing_result( task = populate_tasks[0] assert task.status == TaskStatus.QUEUED result = "hello everyone" - task_result = TaskResult(task_id=task.id, result=result) + task_result = TaskResult(task_id=task.id, project_id=project, result=result) # When - await worker.save_result(result=task_result, project=project) + await worker.save_result(result=task_result) # Then expected = "Attempted to save result for task task-0 but found existing result" with pytest.raises(ValueError, match=expected): - await worker.save_result(result=task_result, project=project) + await worker.save_result(result=task_result) async def test_worker_acknowledgment_cm( @@ -233,11 +235,11 @@ async def test_worker_acknowledgment_cm( # When async with worker.acknowledgment_cm(): await worker.consume() - task = await task_manager.get_task(task_id=created.id, project=TEST_PROJECT) + task = await task_manager.get_task(task_id=created.id) assert task.status is TaskStatus.RUNNING # Then - task = await task_manager.get_task(task_id=created.id, project=TEST_PROJECT) + task = await task_manager.get_task(task_id=created.id) update = {"progress": 100.0, "status": TaskStatus.DONE} expected_task = safe_copy(task, update=update).dict(by_alias=True) expected_task.pop("completedAt") @@ -258,16 +260,19 @@ async def test_worker_save_error(populate_tasks: List[Task], worker: Neo4jWorker project = TEST_PROJECT error = TaskError( id="error-id", + task_id=populate_tasks[0].id, + project_id=project, title="someErrorTitle", detail="with_details", occurred_at=datetime.now(), ) # When - task, _ = await worker.consume() - await worker.save_error(error=error, task=task, project=project) - saved_task = await task_manager.get_task(task_id=task.id, project=project) - saved_errors = await task_manager.get_task_errors(task_id=task.id, project=project) + task = await worker.consume() + await worker.save_error(error=error) + await worker.publish_error_event(error, task) + saved_task = await task_manager.get_task(task_id=task.id) + saved_errors = await task_manager.get_task_errors(task_id=task.id) # Then # We don't expect the task status to be updated by saving the error, the negative diff --git a/icij-worker/icij_worker/tests/worker/test_worker.py b/icij-worker/icij_worker/tests/worker/test_worker.py index de807aa2..5ae43933 100644 --- a/icij-worker/icij_worker/tests/worker/test_worker.py +++ b/icij-worker/icij_worker/tests/worker/test_worker.py @@ -12,10 +12,9 @@ import pytest from icij_common.pydantic_utils import safe_copy -from icij_common.test_utils import async_true_after, fail_if_exception +from icij_common.test_utils import TEST_PROJECT, async_true_after, fail_if_exception from icij_worker import AsyncApp, Task, TaskError, TaskEvent, TaskResult, TaskStatus from icij_worker.exceptions import TaskAlreadyCancelled -from icij_worker.tests.conftest import TEST_PROJECT from icij_worker.utils.tests import MockManager, MockWorker from icij_worker.worker.worker import add_missing_args @@ -35,10 +34,10 @@ async def test_work_once_asyncio_task(mock_worker: MockWorker): # Given worker = mock_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() task = Task( id="some-id", + project_id=TEST_PROJECT, type="hello_world", created_at=created_at, status=TaskStatus.CREATED, @@ -46,17 +45,18 @@ async def test_work_once_asyncio_task(mock_worker: MockWorker): ) # When - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) await worker.work_once() - saved_task = await task_manager.get_task(task_id=task.id, project=project) - saved_errors = await task_manager.get_task_errors(task_id=task.id, project=project) - saved_result = await task_manager.get_task_result(task_id=task.id, project=project) + saved_task = await task_manager.get_task(task_id=task.id) + saved_errors = await task_manager.get_task_errors(task_id=task.id) + saved_result = await task_manager.get_task_result(task_id=task.id) # Then assert not saved_errors expected_task = Task( id="some-id", + project_id=TEST_PROJECT, type="hello_world", progress=100, created_at=created_at, @@ -71,11 +71,17 @@ async def test_work_once_asyncio_task(mock_worker: MockWorker): expected_task.pop("completedAt") assert saved_task == expected_task expected_events = [ - TaskEvent(task_id="some-id", status=TaskStatus.RUNNING, progress=0.0), - TaskEvent(task_id="some-id", progress=0.1), - TaskEvent(task_id="some-id", progress=0.99), TaskEvent( task_id="some-id", + project_id=TEST_PROJECT, + status=TaskStatus.RUNNING, + progress=0.0, + ), + TaskEvent(task_id="some-id", project_id=TEST_PROJECT, progress=0.1), + TaskEvent(task_id="some-id", project_id=TEST_PROJECT, progress=0.99), + TaskEvent( + task_id="some-id", + project_id=TEST_PROJECT, status=TaskStatus.DONE, progress=100.0, completed_at=completed_at, @@ -83,7 +89,9 @@ async def test_work_once_asyncio_task(mock_worker: MockWorker): ] assert worker.published_events == expected_events - expected_result = TaskResult(task_id="some-id", result="Hello world !") + expected_result = TaskResult( + task_id="some-id", project_id=TEST_PROJECT, result="Hello world !" + ) assert saved_result == expected_result @@ -91,10 +99,10 @@ async def test_work_once_run_sync_task(mock_worker: MockWorker): # Given worker = mock_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() task = Task( id="some-id", + project_id=TEST_PROJECT, type="hello_world_sync", created_at=created_at, status=TaskStatus.CREATED, @@ -102,11 +110,11 @@ async def test_work_once_run_sync_task(mock_worker: MockWorker): ) # When - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) await worker.work_once() - saved_task = await task_manager.get_task(task_id=task.id, project=project) - saved_result = await task_manager.get_task_result(task_id=task.id, project=project) - saved_errors = await task_manager.get_task_errors(task_id=task.id, project=project) + saved_task = await task_manager.get_task(task_id=task.id) + saved_result = await task_manager.get_task_result(task_id=task.id) + saved_errors = await task_manager.get_task_errors(task_id=task.id) # Then assert not saved_errors @@ -114,6 +122,7 @@ async def test_work_once_run_sync_task(mock_worker: MockWorker): expected_task = Task( id="some-id", type="hello_world_sync", + project_id=TEST_PROJECT, progress=100, created_at=created_at, status=TaskStatus.DONE, @@ -127,9 +136,15 @@ async def test_work_once_run_sync_task(mock_worker: MockWorker): expected_task.pop("completedAt") assert saved_task == expected_task expected_events = [ - TaskEvent(task_id="some-id", status=TaskStatus.RUNNING, progress=0.0), TaskEvent( task_id="some-id", + project_id=TEST_PROJECT, + status=TaskStatus.RUNNING, + progress=0.0, + ), + TaskEvent( + task_id="some-id", + project_id=TEST_PROJECT, status=TaskStatus.DONE, progress=100.0, completed_at=completed_at, @@ -137,7 +152,9 @@ async def test_work_once_run_sync_task(mock_worker: MockWorker): ] assert worker.published_events == expected_events - expected_result = TaskResult(task_id="some-id", result="Hello world !") + expected_result = TaskResult( + task_id="some-id", project_id=TEST_PROJECT, result="Hello world !" + ) assert saved_result == expected_result @@ -147,33 +164,34 @@ async def test_task_wrapper_should_recover_from_recoverable_error( # Given worker = mock_failing_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() task = Task( id="some-id", + project_id=TEST_PROJECT, type="recovering_task", created_at=created_at, status=TaskStatus.CREATED, ) # When/Then - task = await task_manager.enqueue(task, project) + task = await task_manager.enqueue(task) assert task.status is TaskStatus.QUEUED await worker.work_once() - retried_task = await task_manager.get_task(task_id=task.id, project=project) + retried_task = await task_manager.get_task(task_id=task.id) assert retried_task.status is TaskStatus.QUEUED assert retried_task.retries == 1 await worker.work_once() - saved_task = await task_manager.get_task(task_id=task.id, project=project) - saved_result = await task_manager.get_task_result(task_id=task.id, project=project) - saved_errors = await task_manager.get_task_errors(task_id=task.id, project=project) + saved_task = await task_manager.get_task(task_id=task.id) + saved_result = await task_manager.get_task_result(task_id=task.id) + saved_errors = await task_manager.get_task_errors(task_id=task.id) # Then expected_task = Task( id="some-id", type="recovering_task", + project_id=TEST_PROJECT, progress=100, created_at=created_at, status=TaskStatus.DONE, @@ -190,24 +208,42 @@ async def test_task_wrapper_should_recover_from_recoverable_error( # No error should be saved assert not saved_errors # However we expect the worker to have logged them somewhere in the events - expected_result = TaskResult(task_id="some-id", result="i told you i could recover") + expected_result = TaskResult( + task_id="some-id", project_id=TEST_PROJECT, result="i told you i could recover" + ) assert saved_result == expected_result expected_events = [ - TaskEvent(task_id="some-id", status=TaskStatus.RUNNING, progress=0.0), TaskEvent( task_id="some-id", + project_id=TEST_PROJECT, + status=TaskStatus.RUNNING, + progress=0.0, + ), + TaskEvent( + task_id="some-id", + project_id=TEST_PROJECT, status=TaskStatus.QUEUED, retries=1, progress=None, # The progress should be left as is waiting before retry error=TaskError( - id="", title="Recoverable", detail="", occurred_at=datetime.now() + id="", + task_id="some-id", + title="Recoverable", + detail="", + occurred_at=datetime.now(), ), ), - TaskEvent(task_id="some-id", status=TaskStatus.RUNNING, progress=0.0), - TaskEvent(task_id="some-id", progress=0.0), TaskEvent( task_id="some-id", + project_id=TEST_PROJECT, + status=TaskStatus.RUNNING, + progress=0.0, + ), + TaskEvent(task_id="some-id", project_id=TEST_PROJECT, progress=0.0), + TaskEvent( + task_id="some-id", + project_id=TEST_PROJECT, status=TaskStatus.DONE, progress=100.0, completed_at=completed_at, @@ -233,7 +269,6 @@ async def test_task_wrapper_should_handle_non_recoverable_error( # Given worker = mock_failing_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() task = Task( id="some-id", @@ -243,12 +278,10 @@ async def test_task_wrapper_should_handle_non_recoverable_error( ) # When - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) await worker.work_once() - saved_errors = await task_manager.get_task_errors( - task_id="some-id", project=project - ) - saved_task = await task_manager.get_task(task_id="some-id", project=project) + saved_errors = await task_manager.get_task_errors(task_id="some-id") + saved_task = await task_manager.get_task(task_id="some-id") # Then expected_task = Task( @@ -272,7 +305,11 @@ async def test_task_wrapper_should_handle_non_recoverable_error( task_id="some-id", status=TaskStatus.ERROR, error=TaskError( - id="", title="ValueError", detail="", occurred_at=datetime.now() + id="", + task_id="some-id", + title="ValueError", + detail="", + occurred_at=datetime.now(), ), ), ] @@ -295,7 +332,6 @@ async def test_task_wrapper_should_handle_unregistered_task(mock_worker: MockWor # Given worker = mock_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() task = Task( id="some-id", @@ -305,12 +341,10 @@ async def test_task_wrapper_should_handle_unregistered_task(mock_worker: MockWor ) # When - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) await worker.work_once() - saved_task = await task_manager.get_task(task_id="some-id", project=project) - saved_errors = await task_manager.get_task_errors( - task_id="some-id", project=project - ) + saved_task = await task_manager.get_task(task_id="some-id") + saved_errors = await task_manager.get_task_errors(task_id="some-id") # Then expected_task = Task( @@ -334,6 +368,7 @@ async def test_task_wrapper_should_handle_unregistered_task(mock_worker: MockWor status=TaskStatus.ERROR, error=TaskError( id="error-id", + task_id="some-id", title="UnregisteredTask", detail="", occurred_at=datetime.now(), @@ -359,7 +394,6 @@ async def test_work_once_should_not_run_already_cancelled_task(mock_worker: Mock # Given worker = mock_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() task = Task( id="some-id", @@ -369,10 +403,10 @@ async def test_work_once_should_not_run_already_cancelled_task(mock_worker: Mock ) # When cancelled = safe_copy(task, update={"status": TaskStatus.CANCELLED}) - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) # We mock the fact the task is still received but cancelled right after with pytest.raises(TaskAlreadyCancelled): - with patch.object(worker, "consume", return_value=(cancelled, project)): + with patch.object(worker, "consume", return_value=cancelled): await worker.work_once() @@ -382,7 +416,6 @@ async def test_cancel_running_task(mock_worker: MockWorker, requeue: bool): # Given worker = mock_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() duration = 10 task = Task( @@ -401,19 +434,19 @@ async def test_cancel_running_task(mock_worker: MockWorker, requeue: bool): worker._work_once_task = t asyncio_tasks.add(t) - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) after_s = 2.0 async def _assert_running() -> bool: - saved = await task_manager.get_task(task_id=task.id, project=project) + saved = await task_manager.get_task(task_id=task.id) return saved.status is TaskStatus.RUNNING failure_msg = f"Failed to run task in less than {after_s}" assert await async_true_after(_assert_running, after_s=after_s), failure_msg - await task_manager.cancel(task_id=task.id, project=project, requeue=requeue) + await task_manager.cancel(task_id=task.id, requeue=requeue) async def _assert_has_status(status: TaskStatus) -> bool: - saved = await task_manager.get_task(task_id=task.id, project=project) + saved = await task_manager.get_task(task_id=task.id) return saved.status is status expected_status = TaskStatus.QUEUED if requeue else TaskStatus.CANCELLED @@ -431,7 +464,6 @@ async def test_worker_should_terminate_task_and_cancellation_event_loops( # Given worker = mock_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() duration = 100 task = Task( @@ -443,7 +475,7 @@ async def test_worker_should_terminate_task_and_cancellation_event_loops( ) # When - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) asyncio_tasks = set() async with worker: work_forever_task = asyncio.create_task(worker.work_forever_async()) @@ -452,7 +484,7 @@ async def test_worker_should_terminate_task_and_cancellation_event_loops( asyncio_tasks.add(work_forever_task) async def _assert_running() -> bool: - saved = await task_manager.get_task(task_id=task.id, project=project) + saved = await task_manager.get_task(task_id=task.id) return saved.status is TaskStatus.RUNNING after_s = 2.0 @@ -514,7 +546,6 @@ async def test_worker_should_keep_working_on_fatal_error_in_task_codebase( # Given worker = mock_failing_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() task = Task( id="some-id", @@ -524,7 +555,7 @@ async def test_worker_should_keep_working_on_fatal_error_in_task_codebase( ) # When/Then - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) with fail_if_exception("fatal_error_task"): await worker.work_once() @@ -535,7 +566,6 @@ async def test_worker_should_stop_working_on_fatal_error_in_worker_codebase( # Given worker = mock_failing_worker task_manager = MockManager(worker.db_path, max_queue_size=10) - project = TEST_PROJECT created_at = datetime.now() task = Task( id="some-id", @@ -545,7 +575,7 @@ async def test_worker_should_stop_working_on_fatal_error_in_worker_codebase( ) # When/Then - await task_manager.enqueue(task, project) + await task_manager.enqueue(task) with patch.object(worker, "_consume") as mocked_consume: class _FatalError(Exception): ... diff --git a/icij-worker/icij_worker/utils/tests.py b/icij-worker/icij_worker/utils/tests.py index 6033c445..33f753ae 100644 --- a/icij-worker/icij_worker/utils/tests.py +++ b/icij-worker/icij_worker/utils/tests.py @@ -17,7 +17,6 @@ Dict, List, Optional, - Tuple, Type, TypeVar, Union, @@ -72,6 +71,7 @@ class DBMixin(ABC): def __init__(self, db_path: Path): self._db_path = db_path + self._task_projects: Dict[str, str] = dict() @property def db_path(self) -> Path: @@ -84,9 +84,21 @@ def _read(self): return json.loads(self._db_path.read_text()) @staticmethod - def _task_key(task_id: str, project: str) -> str: + def _task_key(task_id: str, project: Optional[str]) -> str: return str((task_id, project)) + def _get_task_project(self, task_id) -> str: + if task_id not in self._task_projects: + db = self._read() + self._task_projects = dict( + eval(k) # pylint: disable=eval-used + for k in db[self._task_collection].keys() + ) + try: + return self._task_projects[task_id] + except KeyError as e: + raise UnknownTask(task_id) from e + @classmethod def fresh_db(cls, db_path: Path): db = { @@ -176,8 +188,8 @@ def __init__(self, db_path: Path, max_queue_size: int): super().__init__(db_path) self._max_queue_size = max_queue_size - async def _enqueue(self, task: Task, project: str) -> Task: - key = self._task_key(task_id=task.id, project=project) + async def _enqueue(self, task: Task) -> Task: + key = self._task_key(task_id=task.id, project=task.project_id) db = self._read() tasks = db[self._task_collection] n_queued = sum( @@ -193,7 +205,8 @@ async def _enqueue(self, task: Task, project: str) -> Task: self._write(db) return task - async def _cancel(self, *, task_id: str, project: str, requeue: bool): + async def _cancel(self, *, task_id: str, requeue: bool): + project = self._get_task_project(task_id) key = self._task_key(task_id=task_id, project=project) event = CancelledTaskEvent( task_id=task_id, requeue=requeue, created_at=datetime.now() @@ -202,7 +215,8 @@ async def _cancel(self, *, task_id: str, project: str, requeue: bool): db[self._cancel_event_collection][key] = event.dict() self._write(db) - async def get_task(self, *, task_id: str, project: str) -> Task: + async def get_task(self, *, task_id: str) -> Task: + project = self._get_task_project(task_id) key = self._task_key(task_id=task_id, project=project) db = self._read() try: @@ -211,7 +225,8 @@ async def get_task(self, *, task_id: str, project: str) -> Task: except KeyError as e: raise UnknownTask(task_id) from e - async def get_task_errors(self, task_id: str, project: str) -> List[TaskError]: + async def get_task_errors(self, task_id: str) -> List[TaskError]: + project = self._get_task_project(task_id) key = self._task_key(task_id=task_id, project=project) db = self._read() errors = db[self._error_collection] @@ -219,7 +234,8 @@ async def get_task_errors(self, task_id: str, project: str) -> List[TaskError]: errors = [TaskError(**err) for err in errors] return errors - async def get_task_result(self, task_id: str, project: str) -> TaskResult: + async def get_task_result(self, task_id: str) -> TaskResult: + project = self._get_task_project(task_id) key = self._task_key(task_id=task_id, project=project) db = self._read() results = db[self._result_collection] @@ -230,7 +246,7 @@ async def get_task_result(self, task_id: str, project: str) -> TaskResult: async def get_tasks( self, - project: str, + project_id: Optional[str] = None, task_type: Optional[str] = None, status: Optional[Union[List[TaskStatus], TaskStatus]] = None, ) -> List[Task]: @@ -252,7 +268,7 @@ def __init__(self, db_path: Path): super().__init__(db_path) self.published_events = [] - async def publish_event(self, event: TaskEvent, project: str): + async def _publish_event(self, event: TaskEvent): self.published_events.append(event) # Let's simulate that we have an event handler which will reflect some event # into the DB, we could not do it. In this case tests should not expect that @@ -260,6 +276,7 @@ async def publish_event(self, event: TaskEvent, project: str): # published_events (which could be enough). # Here we choose to reflect the change in the DB since its closer to what # will happen IRL and test integration further + project = self._get_task_project(event.task_id) key = self._task_key(task_id=event.task_id, project=project) db = self._read() try: @@ -359,14 +376,16 @@ def _from_config(cls, config: MockWorkerConfig, **extras) -> MockWorker: def _to_config(self) -> MockWorkerConfig: return MockWorkerConfig(db_path=self._db_path) - async def _save_result(self, result: TaskResult, project: str): + async def _save_result(self, result: TaskResult): + project = self._get_task_project(result.task_id) task_key = self._task_key(task_id=result.task_id, project=project) db = self._read() db[self._result_collection][task_key] = result self._write(db) - async def _save_error(self, error: TaskError, task: Task, project: str): - task_key = self._task_key(task_id=task.id, project=project) + async def _save_error(self, error: TaskError): + project = self._get_task_project(task_id=error.task_id) + task_key = self._task_key(task_id=error.task_id, project=project) db = self._read() errors = db[self._error_collection].get(task_key) if errors is None: @@ -393,7 +412,8 @@ def _get_db_result(self, task_id: str, project: str) -> TaskResult: except KeyError as e: raise UnknownTask(task_id) from e - async def _acknowledge(self, task: Task, project: str, completed_at: datetime): + async def _acknowledge(self, task: Task, completed_at: datetime): + project = self._get_task_project(task.id) key = self._task_key(task.id, project) db = self._read() tasks = db[self._task_collection] @@ -410,9 +430,8 @@ async def _acknowledge(self, task: Task, project: str, completed_at: datetime): tasks[key] = safe_copy(saved_task, update=update) self._write(db) - async def _negatively_acknowledge( - self, nacked: Task, project: str, *, cancelled: bool - ): + async def _negatively_acknowledge(self, nacked: Task, *, cancelled: bool): + project = self._get_task_project(nacked.id) key = self._task_key(nacked.id, project) db = self._read() tasks = db[self._task_collection] @@ -431,7 +450,7 @@ async def _negatively_acknowledge( tasks[key].update(update) self._write(db) - async def _consume(self) -> Tuple[Task, str]: + async def _consume(self) -> Task: return await self._consume_( self._task_collection, Task, @@ -439,7 +458,7 @@ async def _consume(self) -> Tuple[Task, str]: order=lambda t: t.created_at, ) - async def _consume_cancelled(self) -> Tuple[CancelledTaskEvent, str]: + async def _consume_cancelled(self) -> CancelledTaskEvent: return await self._consume_( self._cancel_event_collection, CancelledTaskEvent, @@ -452,7 +471,7 @@ async def _consume_( consumed_cls: Type[R], select: Optional[Callable[[R], bool]] = None, order: Optional[Callable[[R], Any]] = None, - ) -> Tuple[R, str]: + ) -> R: while "i'm waiting until I find something interesting": db = self._read() selected = db[collection] @@ -464,8 +483,7 @@ async def _consume_( k, t = min(selected, key=lambda x: order(x[1])) else: k, t = selected[0] - project = eval(k)[1] # pylint: disable=eval-used - return t, project + return t await asyncio.sleep(self._task_queue_poll_interval_s) async def work_once(self): diff --git a/icij-worker/icij_worker/worker/__init__.py b/icij-worker/icij_worker/worker/__init__.py index 489e9be2..8ba9d4f0 100644 --- a/icij-worker/icij_worker/worker/__init__.py +++ b/icij-worker/icij_worker/worker/__init__.py @@ -3,16 +3,6 @@ from .config import WorkerConfig from .worker import Worker -try: - from .neo4j import Neo4jWorker, Neo4jWorkerConfig, Neo4jEventPublisher -except ImportError: - pass - -try: - from .amqp import AMQPWorker, AMQPWorkerConfig -except ImportError: - pass - @unique class WorkerType(str, Enum): diff --git a/icij-worker/icij_worker/worker/amqp.py b/icij-worker/icij_worker/worker/amqp.py index 6d31a069..9fd0dd90 100644 --- a/icij-worker/icij_worker/worker/amqp.py +++ b/icij-worker/icij_worker/worker/amqp.py @@ -1,9 +1,9 @@ from __future__ import annotations -from contextlib import AsyncExitStack +from contextlib import AbstractAsyncContextManager, AsyncExitStack from datetime import datetime from functools import lru_cache -from typing import ClassVar, Dict, Optional, Tuple, Type +from typing import ClassVar, Dict, Optional, Type, cast from aio_pika import RobustQueue, connect_robust from aio_pika.abc import ( @@ -27,8 +27,12 @@ WorkerConfig, WorkerType, ) -from icij_worker.event_publisher import AMQPPublisher -from icij_worker.event_publisher.amqp import Exchange, RobustConnection, Routing +from icij_worker.event_publisher.amqp import ( + AMQPPublisher, + Exchange, + RobustConnection, + Routing, +) from icij_worker.task import CancelledTaskEvent from icij_worker.utils.from_config import T @@ -81,14 +85,12 @@ def __init__( inactive_after_s: Optional[float] = None, handle_signals: bool = True, teardown_dependencies: bool = False, - **kwargs, ): super().__init__( app, worker_id, handle_signals=handle_signals, teardown_dependencies=teardown_dependencies, - **kwargs, ) self._cancel_event_queue_name = ( f"{self._cancel_event_routing().default_queue}" f"-{self._id}" @@ -121,6 +123,9 @@ async def _aenter__(self): create_queue=self._declare_exchanges, exchange_type=ExchangeType.DIRECT, ) + self._task_queue_iterator = cast( + AbstractAsyncContextManager, self._task_queue_iterator + ) await self._exit_stack.enter_async_context(self._task_queue_iterator) self._cancel_evt_connection = await self._make_connection() await self._exit_stack.enter_async_context(self._cancel_evt_connection) @@ -132,6 +137,9 @@ async def _aenter__(self): created_queue_name=self._cancel_event_queue_name, exchange_type=ExchangeType.FANOUT, ) + self._cancel_evt_queue_iterator = cast( + AbstractAsyncContextManager, self._cancel_evt_queue_iterator + ) await self._exit_stack.enter_async_context(self._cancel_evt_queue_iterator) async def _make_connection(self) -> AbstractRobustConnection: @@ -156,7 +164,9 @@ async def _get_queue_iterator( durable_queue: bool = True, ) -> AbstractQueueIterator: channel: AbstractRobustChannel = await connection.channel() - await self._exit_stack.enter_async_context(channel) + await self._exit_stack.enter_async_context( + cast(AbstractAsyncContextManager, channel) + ) arguments = None if self._declare_exchanges: if routing.dead_letter_routing is not None: @@ -188,18 +198,17 @@ async def _get_queue_iterator( kwargs["timeout"] = self._inactive_after_s return queue.iterator(**kwargs) - async def _consume(self) -> Tuple[Task, str]: + async def _consume(self) -> Task: message: AbstractIncomingMessage = await self._task_messages_it.__anext__() - # TODO: handle project deserialization here task = Task.parse_raw(message.body) self._delivered[task.id] = message - return task, _PROJECT_PLACEHOLDER + return task - async def _consume_cancelled(self) -> Tuple[CancelledTaskEvent, str]: + async def _consume_cancelled(self) -> CancelledTaskEvent: message: AbstractIncomingMessage = await self._cancel_events_it.__anext__() # TODO: handle project deserialization here event = CancelledTaskEvent.parse_raw(message.body) - return event, _PROJECT_PLACEHOLDER + return event @property def _task_messages_it(self) -> AbstractQueueIterator: @@ -221,9 +230,7 @@ def _cancel_events_it(self) -> AbstractQueueIterator: raise ValueError(msg) return self._cancel_evt_queue_iterator - async def _acknowledge( - self, task: Task, project: str, completed_at: datetime - ) -> Task: + async def _acknowledge(self, task: Task, completed_at: datetime) -> Task: message = self._delivered[task.id] await message.ack() acked = safe_copy( @@ -236,21 +243,20 @@ async def _acknowledge( ) return acked - async def _negatively_acknowledge( - self, nacked: Task, project: str, *, cancelled: bool - ): + async def _negatively_acknowledge(self, nacked: Task, *, cancelled: bool): + # pylint: disable=unused-argument message = self._delivered[nacked.id] requeue = nacked.status is TaskStatus.QUEUED await message.nack(requeue=requeue) - async def publish_event(self, event: TaskEvent, project: str): - await self._publisher.publish_event(event, project) + async def _publish_event(self, event: TaskEvent): + await self._publisher.publish_event_(event) - async def _save_result(self, result: TaskResult, project: str): - await self._publisher.publish_result(result, project) + async def _save_result(self, result: TaskResult): + await self._publisher.publish_result(result) - async def _save_error(self, error: TaskError, task: Task, project: str): - await self._publisher.publish_error(error, project, project) + async def _save_error(self, error: TaskError): + await self._publisher.publish_error(error) @classmethod @lru_cache(maxsize=1) diff --git a/icij-worker/icij_worker/worker/neo4j.py b/icij-worker/icij_worker/worker/neo4j_.py similarity index 85% rename from icij-worker/icij_worker/worker/neo4j.py rename to icij-worker/icij_worker/worker/neo4j_.py index 06c152c0..f3ba1d2b 100644 --- a/icij-worker/icij_worker/worker/neo4j.py +++ b/icij-worker/icij_worker/worker/neo4j_.py @@ -3,16 +3,13 @@ import asyncio import functools import json -from contextlib import asynccontextmanager from datetime import datetime from typing import ( - AsyncGenerator, Awaitable, Callable, ClassVar, Dict, Optional, - Tuple, TypeVar, ) @@ -40,7 +37,6 @@ TASK_RETRIES, ) from icij_common.neo4j.migrate import retrieve_projects -from icij_common.neo4j.projects import project_db_session from icij_common.pydantic_utils import ICIJModel, jsonable_encoder from icij_worker import ( AsyncApp, @@ -52,9 +48,9 @@ WorkerConfig, WorkerType, ) -from ..event_publisher.neo4j import Neo4jEventPublisher -from ..exceptions import TaskAlreadyReserved, UnknownTask -from ..task import CancelledTaskEvent +from icij_worker.event_publisher.neo4j_ import Neo4jEventPublisher +from icij_worker.exceptions import TaskAlreadyReserved, UnknownTask +from icij_worker.task import CancelledTaskEvent _TASK_MANDATORY_FIELDS_BY_ALIAS = { f for f in Task.schema(by_alias=True)["required"] if f != "id" @@ -126,21 +122,19 @@ def _from_config(cls, config: Neo4jWorkerConfig, **extras) -> Neo4jWorker: worker.set_config(config) return worker - async def _consume(self) -> Tuple[Task, str]: + async def _consume(self) -> Task: return await self._consume_( _consume_task_tx, refresh_interval_s=self._new_tasks_refresh_interval_s, ) - async def _consume_cancelled(self) -> Tuple[CancelledTaskEvent, str]: + async def _consume_cancelled(self) -> CancelledTaskEvent: return await self._consume_( _consume_cancelled_task_tx, refresh_interval_s=self._cancelled_tasks_refresh_interval_s, ) - async def _consume_( - self, consume_tx: ConsumeT, refresh_interval_s: float - ) -> Tuple[T, str]: + async def _consume_(self, consume_tx: ConsumeT, refresh_interval_s: float) -> T: projects = [] refresh_projects_i = 0 while "i'm waiting until I find something interesting": @@ -150,15 +144,14 @@ async def _consume_( projects = await retrieve_projects(self._driver) for p in projects: async with self._project_session(p.name) as sess: - received = await sess.execute_write(consume_tx, worker_id=self.id) + tx = functools.partial(consume_tx, project_id=p.name) + received = await sess.execute_write(tx, worker_id=self.id) if received is not None: - return received, p.name + return received await asyncio.sleep(refresh_interval_s) refresh_projects_i += 1 - async def _negatively_acknowledge( - self, nacked: Task, project: str, *, cancelled: bool - ): + async def _negatively_acknowledge(self, nacked: Task, *, cancelled: bool): if nacked.status is TaskStatus.QUEUED: nack_fn = functools.partial( _nack_and_requeue_task_tx, retries=nacked.retries, cancelled=cancelled @@ -169,25 +162,39 @@ async def _negatively_acknowledge( ) else: nack_fn = _nack_task_tx + project = self._get_task_project_id(nacked.id) async with self._project_session(project) as sess: await sess.execute_write(nack_fn, task_id=nacked.id, worker_id=self.id) - async def _save_result(self, result: TaskResult, project: str): - async with self._project_session(project) as sess: + async def _save_result(self, result: TaskResult): + project_id = result.project_id + if project_id is None: + raise ValueError( + "neo4j expects project to be provided in order to fetch tasks from the" + " project's DB" + ) + async with self._project_session(project_id) as sess: res_str = json.dumps(jsonable_encoder(result.result)) await sess.execute_write( _save_result_tx, task_id=result.task_id, result=res_str ) - async def _save_error(self, error: TaskError, task: Task, project: str): - async with self._project_session(project) as sess: + async def _save_error(self, error: TaskError): + project_id = error.project_id + if project_id is None: + raise ValueError( + "neo4j expects project to be provided in order to fetch tasks from the" + " project's DB" + ) + async with self._project_session(project_id) as sess: await sess.execute_write( _save_error_tx, - task_id=task.id, + task_id=error.task_id, error_props=error.dict(by_alias=True), ) - async def _acknowledge(self, task: Task, project: str, completed_at: datetime): + async def _acknowledge(self, task: Task, completed_at: datetime): + project = self._get_task_project_id(task.id) async with self._project_session(project) as sess: await sess.execute_write( _acknowledge_task_tx, @@ -196,19 +203,12 @@ async def _acknowledge(self, task: Task, project: str, completed_at: datetime): completed_at=completed_at, ) - @asynccontextmanager - async def _project_session( - self, project: str - ) -> AsyncGenerator[neo4j.AsyncSession, None]: - async with project_db_session(self._driver, project) as sess: - yield sess - async def _aexit__(self, exc_type, exc_val, exc_tb): await self._driver.__aexit__(exc_type, exc_val, exc_tb) async def _consume_task_tx( - tx: neo4j.AsyncTransaction, worker_id: str + tx: neo4j.AsyncTransaction, *, worker_id: str, project_id: str ) -> Optional[Task]: query = f"""MATCH (t:{TASK_NODE}:`{TaskStatus.QUEUED.value}`) WITH t @@ -228,11 +228,11 @@ async def _consume_task_tx( return None except ConstraintError as e: raise TaskAlreadyReserved() from e - return Task.from_neo4j(task) + return Task.from_neo4j(task, project_id=project_id) async def _consume_cancelled_task_tx( - tx: neo4j.AsyncTransaction, **_ + tx: neo4j.AsyncTransaction, project_id: str, **_ ) -> Optional[CancelledTaskEvent]: get_event_query = f"""MATCH (task:{TASK_NODE})-[ :{TASK_CANCELLED_BY_EVENT_REL} @@ -247,7 +247,7 @@ async def _consume_cancelled_task_tx( event = await res.single(strict=True) except ResultNotSingleError: return None - return CancelledTaskEvent.from_neo4j(event) + return CancelledTaskEvent.from_neo4j(event, project_id=project_id) async def _acknowledge_task_tx( @@ -301,7 +301,7 @@ async def _nack_and_requeue_task_tx( retries: int, cancelled: bool, ): - clean_cancelled_query = """ + clean_cancelled_query = f""" WITH task, lock OPTIONAL MATCH (task)-[ :{TASK_CANCELLED_BY_EVENT_REL} @@ -381,6 +381,7 @@ async def _save_error_tx( CREATE (error:{TASK_ERROR_NODE} $errorProps)-[:{TASK_ERROR_OCCURRED_TYPE}]->(task) RETURN task, error""" res = await tx.run(query, taskId=task_id, errorProps=error_props) - records = [rec async for rec in res] - if not records: - raise UnknownTask(task_id) + try: + await res.single(strict=True) + except ResultNotSingleError as e: + raise UnknownTask(task_id) from e diff --git a/icij-worker/icij_worker/worker/worker.py b/icij-worker/icij_worker/worker/worker.py index 54644af0..3600161b 100644 --- a/icij-worker/icij_worker/worker/worker.py +++ b/icij-worker/icij_worker/worker/worker.py @@ -6,7 +6,6 @@ import logging import traceback from abc import abstractmethod -from collections import defaultdict from contextlib import AbstractAsyncContextManager, asynccontextmanager from copy import deepcopy from datetime import datetime @@ -40,10 +39,10 @@ UnknownTask, UnregisteredTask, ) +from icij_worker.task import CancelledTaskEvent from icij_worker.utils import Registrable from icij_worker.worker.process import HandleSignalsMixin -from ..event_publisher import EventPublisher -from ..task import CancelledTaskEvent +from icij_worker.event_publisher.event_publisher import EventPublisher logger = logging.getLogger(__name__) @@ -77,8 +76,8 @@ def __init__( self._work_once_task: Optional[asyncio.Task] = None self._watch_cancelled_task: Optional[asyncio.Task] = None self._already_exiting = False - self._current: Optional[Tuple[Task, str]] = None - self._cancelled: Dict[str, Dict[str, CancelledTaskEvent]] = defaultdict(dict) + self._current: Optional[Task] = None + self._cancelled: Dict[str, CancelledTaskEvent] = dict() self._cancelling: Optional[str] = None self._config: Optional[C] = None # We use asyncio lock, not thread lock, since the worker is supposed run in a @@ -151,9 +150,9 @@ async def _work_forever(self): self.info("tried to consume a cancelled task, skipping...") continue - def _get_cancel_event(self, task: Task, project: str) -> CancelledTaskEvent: + def _get_cancel_event(self, task: Task) -> CancelledTaskEvent: try: - return self._cancelled[project][task.id] + return self._cancelled[task.id] except KeyError as e: raise UnknownTask(task_id=task.id, worker_id=self._id) from e @@ -168,21 +167,21 @@ def graceful_shutdown(self) -> bool: @final async def _work_once(self): async with self.acknowledgment_cm(): - task, project = await self.consume() - await task_wrapper(self, task, project) + task = await self.consume() + await task_wrapper(self, task) @final - async def consume(self) -> Tuple[Task, str]: - task, project = await self._consume() + async def consume(self) -> Task: + task = await self._consume() self.debug('Task(id="%s") locked', task.id) async with self._current_lock: - self._current = task, project + self._current = task progress = 0.0 update = {"progress": progress} task = safe_copy(task, update=update) event = TaskEvent(task_id=task.id, progress=progress, status=TaskStatus.RUNNING) - await self.publish_event(event, project) - return task, project + await self.publish_event(event, task) + return task @final @asynccontextmanager @@ -190,7 +189,7 @@ async def acknowledgment_cm(self): try: yield async with self._current_lock: - await self.acknowledge(*self._current) + await self.acknowledge(self._current) except asyncio.CancelledError as e: await self._handle_cancel_event(e) except (TaskAlreadyCancelled, TaskAlreadyReserved) as e: @@ -198,31 +197,33 @@ async def acknowledgment_cm(self): raise e except RecoverableError: async with self._current_lock: - self.error('Task(id="%s") encountered error', self._current[0].id) - await self.negatively_acknowledge(*self._current, requeue=True) + self.error('Task(id="%s") encountered error', self._current.id) + await self.negatively_acknowledge(self._current, requeue=True) except Exception as fatal_error: # pylint: disable=broad-exception-caught async with self._current_lock: if self._current is not None: # The error is due to the current task, other tasks might success, # let's fail this task and keep working - await self._handle_fatal_error(fatal_error, *self._current) + await self._handle_fatal_error(fatal_error, self._current) return # The error was in the worker's code, something is wrong that won't change # at the next task, let's make the worker crash raise fatal_error - async def _handle_fatal_error( - self, fatal_error: BaseException, task: Task, project: str - ): + async def _handle_fatal_error(self, fatal_error: BaseException, task: Task): if isinstance(fatal_error, MaxRetriesExceeded): self.error('Task(id="%s") exceeded max retries, nacking it...', task.id) else: self.error( 'fatal error during Task(id="%s") execution, nacking it...', task.id ) - task_error = TaskError.from_exception(fatal_error) - await self.save_error(error=task_error, task=task, project=project) - await self.negatively_acknowledge(task, project, requeue=False) + task_error = TaskError.from_exception(fatal_error, task) + await self.save_error(error=task_error) + # Once the error has been saved, we notify the event consumers, they are + # responsible for reflecting the fact that the error has occurred wherever + # relevant. The source of truth will be error storage + await self.publish_error_event(task_error, task) + await self.negatively_acknowledge(task, requeue=False) async def _handle_cancel_event(self, e: asyncio.CancelledError): async with self._current_lock: @@ -237,38 +238,37 @@ async def _handle_cancel_event(self, e: asyncio.CancelledError): " in between, discarding cancel event !" ) return - current_t, current_p = self._current - event = self._get_cancel_event(current_t, current_p) - self.info('Task(id="%s") cancellation requested !', current_t.id) + event = self._get_cancel_event(self._current) + self.info('Task(id="%s") cancellation requested !', self._current.id) async with self._current_lock: await self._negatively_acknowledge_running_task(event.requeue, cancel=True) @final - async def acknowledge(self, task: Task, project: str): + async def acknowledge(self, task: Task): completed_at = datetime.now() self.info('Task(id="%s") acknowledging...', task.id) - await self._acknowledge(task, project, completed_at) + await self._acknowledge(task, completed_at) self.info('Task(id="%s") acknowledged', task.id) self.debug('Task(id="%s") publishing acknowledgement event', task.id) - event = TaskEvent( + task_event = TaskEvent( task_id=task.id, status=TaskStatus.DONE, progress=100, completed_at=completed_at, + project_id=task.project_id, ) + event = task_event # Tell the listeners that the task succeeded - await self.publish_event(event, project) + await self.publish_event(event, task) self.info('Task(id="%s") successful !', task.id) self._current = None @abstractmethod - async def _acknowledge( - self, task: Task, project: str, completed_at: datetime - ) -> Task: ... + async def _acknowledge(self, task: Task, completed_at: datetime) -> Task: ... @final async def negatively_acknowledge( - self, task: Task, project: str, *, requeue: bool, cancel: bool = False + self, task: Task, *, requeue: bool, cancel: bool = False ): self.info( "negatively acknowledging Task(id=%s) with (requeue=%s)...", @@ -287,18 +287,16 @@ async def negatively_acknowledge( else: update = {"status": TaskStatus.ERROR} nacked = safe_copy(task, update=update) - await self._negatively_acknowledge(nacked, project, cancelled=cancel) + await self._negatively_acknowledge(nacked, cancelled=cancel) self.info( "Task(id=%s) negatively acknowledged (requeue=%s)!", nacked.id, requeue ) self._current = None @abstractmethod - async def _negatively_acknowledge( - self, nacked: Task, project: str, *, cancelled: bool - ): ... + async def _negatively_acknowledge(self, nacked: Task, *, cancelled: bool): ... - # TODO: the cancel parameter is needed in some implementation to know whether + # TODO: the cancelled parameter is needed in some implementation to know whether # the cancellation happened, it which case the implem might want to clean # cancellation events. Another alternative would be to add a _clean_cancelled # method and use it in the negatively_acknowledge method right after calling @@ -306,56 +304,45 @@ async def _negatively_acknowledge( # it's nice to nack and clean in a single transaction @abstractmethod - async def _consume(self) -> Tuple[Task, str]: ... + async def _consume(self) -> Task: ... @abstractmethod - async def _save_result(self, result: TaskResult, project: str): + async def _save_result(self, result: TaskResult): """Save the result in a safe place""" @abstractmethod - async def _save_error(self, error: TaskError, task: Task, project: str): + async def _save_error(self, error: TaskError): """Save the error in a safe place""" @final - async def save_result(self, result: TaskResult, project: str): + async def save_result(self, result: TaskResult): self.info('Task(id="%s") saving result...', result.task_id) - await self._save_result(result, project) + await self._save_result(result) self.info('Task(id="%s") result saved !', result.task_id) @final - async def save_error(self, error: TaskError, task: Task, project: str): - self.error('Task(id="%s"): %s\n%s', task.id, error.title, error.detail) + async def save_error(self, error: TaskError): + self.error('Task(id="%s"): %s\n%s', error.task_id, error.title, error.detail) # Save the error in the appropriate location - self.debug('Task(id="%s") saving error', task.id) - await self._save_error(error, task, project) - # Once the error has been saved, we notify the event consumers, they are - # responsible for reflecting the fact that the error has occurred wherever - # relevant. The source of truth will be error storage - await self.publish_error_event(error=error, task_id=task.id, project=project) + self.debug('Task(id="%s") saving error', error.task_id) + await self._save_error(error) @final async def publish_error_event( - self, - *, - error: TaskError, - task_id: str, - project: str, - retries: Optional[int] = None, + self, error: TaskError, task: Task, retries: Optional[int] = None ): # Tell the listeners that the task failed - self.debug('Task(id="%s") publish error event', task_id) - event = TaskEvent.from_error(error, task_id, retries) - await self.publish_event(event, project) + self.debug('Task(id="%s") publish error event', task.id) + event = TaskEvent.from_error(error, task.id, retries) + await self.publish_event(event, task) @final - async def _publish_progress(self, progress: float, task: Task, project: str): + async def _publish_progress(self, progress: float, task: Task): event = TaskEvent(progress=progress, task_id=task.id) - await self.publish_event(event, project) + await self.publish_event(event, task) @final - def parse_task( - self, task: Task, project: str - ) -> Tuple[Callable, Tuple[Type[Exception], ...]]: + def parse_task(self, task: Task) -> Tuple[Callable, Tuple[Type[Exception], ...]]: registered = _retrieve_registered_task(task, self._app) recoverable = registered.recover_from task_fn = registered.task @@ -364,9 +351,7 @@ def parse_task( for param in signature(task_fn).parameters.values() ) if supports_progress: - publish_progress = functools.partial( - self._publish_progress, task=task, project=project - ) + publish_progress = functools.partial(self._publish_progress, task=task) task_fn = functools.partial(task_fn, progress=publish_progress) return task_fn, recoverable @@ -374,8 +359,8 @@ def parse_task( async def _watch_cancelled(self): try: while True: - cancel_event, project = await self._consume_cancelled() - self._cancelled[project][cancel_event.task_id] = cancel_event + cancel_event = await self._consume_cancelled() + self._cancelled[cancel_event.task_id] = cancel_event async with self._current_lock, self._cancellation_lock: if self._current is None: continue @@ -385,7 +370,7 @@ async def _watch_cancelled(self): ): continue self._cancelling = cancel_event.task_id - task, project = self._current + task = self._current if task.id != cancel_event.task_id: continue self.info( @@ -403,7 +388,7 @@ async def _watch_cancelled(self): raise e @abstractmethod - async def _consume_cancelled(self) -> Tuple[CancelledTaskEvent, str]: ... + async def _consume_cancelled(self) -> CancelledTaskEvent: ... @final def check_retries(self, retries: int, task: Task): @@ -473,10 +458,8 @@ async def _negatively_acknowledge_running_task( self, requeue: bool, cancel: bool = False ): if self._current: - task, project = self._current - await self.negatively_acknowledge( - task, project, requeue=requeue, cancel=cancel - ) + task = self._current + await self.negatively_acknowledge(task, requeue=requeue, cancel=cancel) @final async def shutdown(self): @@ -499,18 +482,16 @@ def _retrieve_registered_task( return registered -async def task_wrapper(worker: Worker, task: Task, project: str): +async def task_wrapper(worker: Worker, task: Task): # Skips if already reserved if task.status is TaskStatus.CANCELLED: worker.info('Task(id="%s") already cancelled skipping it !', task.id) raise TaskAlreadyCancelled(task_id=task.id) # Parse task to retrieve recoverable errors and max retries - task_fn, recoverable_errors = worker.parse_task(task, project) - task_inputs = add_missing_args( - task_fn, task.inputs, config=worker.config, project=project - ) + task_fn, recoverable_errors = worker.parse_task(task) + task_inputs = add_missing_args(task_fn, task.inputs, config=worker.config) # Retry task until success, fatal error or max retry exceeded - await _retry_task(worker, task, task_fn, task_inputs, project, recoverable_errors) + await _retry_task(worker, task, task_fn, task_inputs, recoverable_errors) async def _retry_task( @@ -518,14 +499,13 @@ async def _retry_task( task: Task, task_fn: Callable, task_inputs: Dict, - project: str, recoverable_errors: Tuple[Type[Exception], ...], ): retries = task.retries or 0 if retries: # In the case of the retry, let's reset the progress event = TaskEvent(task_id=task.id, progress=0.0) - await worker.publish_event(event, project) + await worker.publish_event(event, task) try: task_res = task_fn(**task_inputs) if isawaitable(task_res): @@ -533,17 +513,12 @@ async def _retry_task( except recoverable_errors as e: # This will throw a MaxRetriesExceeded when necessary worker.check_retries(retries, task) - error = TaskError.from_exception(e) - await worker.publish_error_event( - error=error, - task_id=task.id, - project=project, - retries=retries + 1, - ) + error = TaskError.from_exception(e, task) + await worker.publish_error_event(error, task, retries=retries + 1) raise RecoverableError() from e worker.info('Task(id="%s") complete, saving result...', task.id) - result = TaskResult(task_id=task.id, result=task_res) - await worker.save_result(result, project) + result = TaskResult.from_task(task=task, result=task_res) + await worker.save_result(result) return