Skip to content

Commit

Permalink
Remove parameter caching inside task class (#370)
Browse files Browse the repository at this point in the history
Closes #369

The validated and processed parameters from a plan request were being
cached inside the `Task` class because they were needed in two separate
threads. Unfortunately, with the introduction of the subprocess (#343),
pickling was causing issues with the cached object (see #369 for more
details).

This PR is the simplest solution: remove the cache and generate the
parameters twice. It also adds regression tests for the bug case in #369
  • Loading branch information
callumforrester committed Feb 7, 2024
1 parent ad943e4 commit 598d789
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def begin_task(self, task_id: str) -> None:
raise KeyError(f"No pending task with ID {task_id}")

def submit_task(self, task: Task) -> str:
task.prepare_params(self._ctx)
task.prepare_params(self._ctx) # Will raise if parameters are invalid
task_id: str = str(uuid.uuid4())
trackable_task = TrackableTask(task_id=task_id, task=task)
self._pending_tasks[task_id] = trackable_task
Expand Down
14 changes: 4 additions & 10 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Mapping, Optional
from typing import Any, Mapping

from pydantic import BaseModel, Field

Expand All @@ -18,25 +18,19 @@ class Task(BlueapiBaseModel):
params: Mapping[str, Any] = Field(
description="Values for parameters to plan, if any", default_factory=dict
)
_prepared_params: Optional[BaseModel] = None

def prepare_params(self, ctx: BlueskyContext) -> None:
self._ensure_params(ctx)
def prepare_params(self, ctx: BlueskyContext) -> BaseModel:
return _lookup_params(ctx, self)

def do_task(self, ctx: BlueskyContext) -> None:
LOGGER.info(f"Asked to run plan {self.name} with {self.params}")

func = ctx.plan_functions[self.name]
prepared_params = self._ensure_params(ctx)
prepared_params = self.prepare_params(ctx)
plan_generator = func(**prepared_params.dict())
wrapped_plan_generator = ctx.wrap(plan_generator)
ctx.run_engine(wrapped_plan_generator)

def _ensure_params(self, ctx: BlueskyContext) -> BaseModel:
if self._prepared_params is None:
self._prepared_params = _lookup_params(ctx, self)
return self._prepared_params


# Here for backward compatibility pending
# https://github.com/DiamondLightSource/blueapi/issues/253
Expand Down
26 changes: 26 additions & 0 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from blueapi.core.bluesky_types import Plan
from blueapi.service.handler import Handler
from blueapi.service.main import get_handler, setup_handler, teardown_handler
from blueapi.service.model import WorkerTask
from blueapi.worker.task import RunPlan
from src.blueapi.worker import WorkerState

Expand Down Expand Up @@ -215,6 +216,31 @@ def test_put_plan_begins_task(handler: Handler, client: TestClient) -> None:
task_json = {"task_id": task_id}
client.put("/worker/task", json=task_json)

resp = client.get("/worker/task")
assert resp.status_code == 200
active_task = WorkerTask(**resp.json())
assert active_task is not None
assert active_task.task_id == task_id
handler.stop()


def test_worker_task_is_none_on_startup(handler: Handler, client: TestClient) -> None:
handler.start()
resp = client.get("/worker/task")
assert resp.status_code == 200
active_task = WorkerTask(**resp.json())
assert active_task.task_id is None
handler.stop()


def test_get_worker_task(handler: Handler, client: TestClient) -> None:
handler.start()
response = client.post("/tasks", json=_TASK.dict())
task_id = response.json()["task_id"]

task_json = {"task_id": task_id}
client.put("/worker/task", json=task_json)

active_task = handler.active_task
assert active_task is not None
assert active_task.task_id == task_id
Expand Down

0 comments on commit 598d789

Please sign in to comment.