Skip to content

Commit

Permalink
chore: add the project_id attribute to tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Mar 27, 2024
1 parent 095fc77 commit a1cb239
Show file tree
Hide file tree
Showing 21 changed files with 796 additions and 603 deletions.
18 changes: 16 additions & 2 deletions icij-worker/icij_worker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
from .app import AsyncApp
from .task import Task, TaskError, TaskEvent, TaskResult, TaskStatus
from .task_manager import TaskManager
from .worker import Worker, WorkerConfig, WorkerType
from .worker.neo4j import Neo4jWorker

try:
from icij_worker.worker.worker.amqp import AMQPWorker, AMQPWorkerConfig
from icij_worker.event_publisher.amqp import AMQPPublisher
except ImportError:
pass

try:
from icij_worker.worker.neo4j_ import Neo4jWorker, Neo4jWorkerConfig
from icij_worker.event_publisher.neo4j_ import Neo4jEventPublisher
from icij_worker.task_manager.neo4j_ import Neo4JTaskManager
except ImportError:
pass

from .backend import WorkerBackend
from .event_publisher import EventPublisher, Neo4jEventPublisher
from .event_publisher import EventPublisher
10 changes: 0 additions & 10 deletions icij-worker/icij_worker/event_publisher/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1 @@
from .event_publisher import EventPublisher

try:
from .neo4j import Neo4jEventPublisher
except ImportError:
pass

try:
from .amqp import AMQPPublisher
except ImportError:
pass
38 changes: 16 additions & 22 deletions icij-worker/icij_worker/event_publisher/amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from icij_common.logging_utils import LogWithNameMixin
from icij_common.pydantic_utils import LowerCamelCaseModel, NoEnumModel
from icij_worker import Task, TaskError, TaskEvent, TaskResult
from icij_worker import TaskError, TaskEvent, TaskResult
from . import EventPublisher


Expand Down Expand Up @@ -112,13 +112,7 @@ def err_routing(cls) -> Routing:
def _routings(self) -> List[Routing]:
return [self.evt_routing(), self.res_routing(), self.err_routing()]

async def publish_event(self, event: TaskEvent, project: str):
# pylint: disable=unused-argument
# TODO: for now project information is not leverage on the AMQP side which is
# not very convenient as clients will won't know from which project the event
# is coming. This is limitating as for instance when it comes to logs errors,
# such events must be save in separate DBs in order to avoid project data
# leaking to other projects through the DB
async def _publish_event(self, event: TaskEvent):
message = event.json().encode()
await self._publish_message(
message,
Expand All @@ -127,13 +121,14 @@ async def publish_event(self, event: TaskEvent, project: str):
mandatory=False,
)

async def publish_result(self, result: TaskResult, project: str):
# pylint: disable=unused-argument
# TODO: for now project information is not leverage on the AMQP side which is
# not very convenient as clients will won't know from which project the result
# is coming. This is limitating as for instance when as result must probably
# be saved in separate DBs in order to avoid project data leaking to other
# projects through the DB
publish_event_ = _publish_event

async def publish_result(self, result: TaskResult):
# TODO: for now task project information is not leverage on the AMQP side which
# is not very convenient as clients will won't know from which project the
# result is coming. This is limitating as for instance when as result must
# probably be saved in separate DBs in order to avoid project data leaking to
# other projects through the DB
message = result.json().encode()
await self._publish_message(
message,
Expand All @@ -142,13 +137,12 @@ async def publish_result(self, result: TaskResult, project: str):
mandatory=True, # This is important
)

async def publish_error(self, error: TaskError, task: Task, project: str):
# pylint: disable=unused-argument
# TODO: for now project information is not leverage on the AMQP side which is
# not very convenient as clients will won't know from which project the error
# is coming. This is limitating as for instance when as error must probably
# be saved in separate DBs in order to avoid project data leaking to other
# projects through the DB
async def publish_error(self, error: TaskError):
# TODO: for now task project information is not leverage on the AMQP side which
# is not very convenient as clients will won't know from which project the
# result is coming. This is limitating as for instance when as result must
# probably be saved in separate DBs in order to avoid project data leaking to
# other projects through the DB
message = error.json().encode()
await self._publish_message(
message,
Expand Down
10 changes: 8 additions & 2 deletions icij-worker/icij_worker/event_publisher/event_publisher.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from abc import ABC, abstractmethod
from typing import final

from icij_worker import TaskEvent
from icij_worker import Task, TaskEvent


class EventPublisher(ABC):
@final
async def publish_event(self, event: TaskEvent, task: Task):
event = event.with_project_id(task)
await self._publish_event(event)

@abstractmethod
async def publish_event(self, event: TaskEvent, project: str):
async def _publish_event(self, event: TaskEvent):
pass
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Dict, Optional
from typing import AsyncGenerator, Dict, List, Optional

import neo4j
from neo4j.exceptions import ResultNotSingleError
Expand All @@ -10,31 +10,71 @@
TASK_ID,
TASK_NODE,
)
from icij_common.neo4j.migrate import retrieve_projects
from icij_common.neo4j.projects import project_db_session
from . import EventPublisher
from .. import Task, TaskEvent, TaskStatus
from ..exceptions import UnknownTask
from icij_worker.event_publisher.event_publisher import EventPublisher
from icij_worker.task import Task, TaskEvent, TaskStatus
from icij_worker.exceptions import UnknownTask


class Neo4jEventPublisher(EventPublisher):
class Neo4jTaskProjectMixin:
_driver: neo4j.AsyncDriver
_task_projects: Dict[str, str] = dict()

@asynccontextmanager
async def _project_session(
self, project: str
) -> AsyncGenerator[neo4j.AsyncSession, None]:
async with project_db_session(self._driver, project) as sess:
yield sess

async def _get_task_project_id(self, task_id: str) -> str:
if task_id not in self._task_projects:
await self._refresh_task_projects()
try:
return self._task_projects[task_id]
except KeyError as e:
raise UnknownTask(task_id) from e

async def _refresh_task_projects(self):
projects = await retrieve_projects(self._driver)
for p in projects:
async with self._project_session(p) as sess:
# Here we make the assumption that task IDs are unique across
# projects and not per project
task_projects = {
t: p.name for t in await sess.execute_read(_get_task_ids_tx)
}
self._task_projects.update(task_projects)


async def _get_task_ids_tx(tx: neo4j.AsyncTransaction) -> List[str]:
query = f"""MATCH (task:{TASK_NODE})
RETURN task.{TASK_ID} as taskId"""
res = await tx.run(query)
ids = [rec["taskId"] async for rec in res]
return ids


class Neo4jEventPublisher(Neo4jTaskProjectMixin, EventPublisher):
def __init__(self, driver: neo4j.AsyncDriver):
self._driver = driver

async def publish_event(self, event: TaskEvent, project: str):
async def _publish_event(self, event: TaskEvent):
project = event.project_id
if project is None:
msg = (
"neo4j expects project to be provided in order to fetch tasks from"
" the project's DB"
)
raise ValueError(msg)
async with self._project_session(project) as sess:
await _publish_event(sess, event)

@property
def driver(self) -> neo4j.AsyncDriver:
return self._driver

@asynccontextmanager
async def _project_session(
self, project: str
) -> AsyncGenerator[neo4j.AsyncSession, None]:
async with project_db_session(self._driver, project) as sess:
yield sess


async def _publish_event(sess: neo4j.AsyncSession, event: TaskEvent):
event = {k: v for k, v in event.dict(by_alias=True).items() if v is not None}
Expand All @@ -48,6 +88,7 @@ async def _publish_event_tx(
tx: neo4j.AsyncTransaction, event: Dict, error: Optional[Dict]
):
task_id = event["taskId"]
project_id = event["project"]
create_task = f"""MERGE (task:{TASK_NODE} {{{TASK_ID}: $taskId }})
ON CREATE SET task += $createProps"""
status = event.get("status")
Expand All @@ -58,7 +99,7 @@ async def _publish_event_tx(
create_props = Task.mandatory_fields(event_as_event, keep_id=False)
create_props.pop("status", None)
res = await tx.run(create_task, taskId=task_id, createProps=create_props)
tasks = [Task.from_neo4j(rec) async for rec in res]
tasks = [Task.from_neo4j(rec, project_id=project_id) async for rec in res]
task = tasks[0]
resolved = task.resolve_event(event_as_event)
resolved = (
Expand Down
Loading

0 comments on commit a1cb239

Please sign in to comment.