diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index 0ff3a6ec2..659d6c5bb 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -129,7 +129,7 @@ def begin_task(self, task_id: str) -> None: raise KeyError(f"No pending task with ID {task_id}") def submit_task(self, task: Task) -> str: - task.prepare_params(self._ctx) + task.prepare_params(self._ctx) # Will raise if parameters are invalid task_id: str = str(uuid.uuid4()) trackable_task = TrackableTask(task_id=task_id, task=task) self._pending_tasks[task_id] = trackable_task diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index d1b4de07d..84f5a21cd 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Mapping, Optional +from typing import Any, Mapping from pydantic import BaseModel, Field @@ -18,25 +18,19 @@ class Task(BlueapiBaseModel): params: Mapping[str, Any] = Field( description="Values for parameters to plan, if any", default_factory=dict ) - _prepared_params: Optional[BaseModel] = None - def prepare_params(self, ctx: BlueskyContext) -> None: - self._ensure_params(ctx) + def prepare_params(self, ctx: BlueskyContext) -> BaseModel: + return _lookup_params(ctx, self) def do_task(self, ctx: BlueskyContext) -> None: LOGGER.info(f"Asked to run plan {self.name} with {self.params}") func = ctx.plan_functions[self.name] - prepared_params = self._ensure_params(ctx) + prepared_params = self.prepare_params(ctx) plan_generator = func(**prepared_params.dict()) wrapped_plan_generator = ctx.wrap(plan_generator) ctx.run_engine(wrapped_plan_generator) - def _ensure_params(self, ctx: BlueskyContext) -> BaseModel: - if self._prepared_params is None: - self._prepared_params = _lookup_params(ctx, self) - return self._prepared_params - # Here for backward compatibility pending # https://github.com/DiamondLightSource/blueapi/issues/253 diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index c72103302..ea3aa3deb 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -12,6 +12,7 @@ from blueapi.core.bluesky_types import Plan from blueapi.service.handler import Handler from blueapi.service.main import get_handler, setup_handler, teardown_handler +from blueapi.service.model import WorkerTask from blueapi.worker.task import RunPlan from src.blueapi.worker import WorkerState @@ -215,6 +216,31 @@ def test_put_plan_begins_task(handler: Handler, client: TestClient) -> None: task_json = {"task_id": task_id} client.put("/worker/task", json=task_json) + resp = client.get("/worker/task") + assert resp.status_code == 200 + active_task = WorkerTask(**resp.json()) + assert active_task is not None + assert active_task.task_id == task_id + handler.stop() + + +def test_worker_task_is_none_on_startup(handler: Handler, client: TestClient) -> None: + handler.start() + resp = client.get("/worker/task") + assert resp.status_code == 200 + active_task = WorkerTask(**resp.json()) + assert active_task.task_id is None + handler.stop() + + +def test_get_worker_task(handler: Handler, client: TestClient) -> None: + handler.start() + response = client.post("/tasks", json=_TASK.dict()) + task_id = response.json()["task_id"] + + task_json = {"task_id": task_id} + client.put("/worker/task", json=task_json) + active_task = handler.active_task assert active_task is not None assert active_task.task_id == task_id