From 971dda10adc68691513f7e5a452915306d813f8d Mon Sep 17 00:00:00 2001 From: Stanislaw Malinowski Date: Fri, 17 May 2024 11:48:56 +0100 Subject: [PATCH] apply discussed changes in one go --- src/blueapi/service/handler.py | 3 +++ src/blueapi/service/handler_base.py | 10 +++++++++ src/blueapi/service/main.py | 16 ++++++++++++++ src/blueapi/worker/reworker.py | 16 +++++++++++++- src/blueapi/worker/worker.py | 18 ++++++++++++++++ tests/service/test_rest_api.py | 33 +++++++++++++++++++++++++++++ 6 files changed, 95 insertions(+), 1 deletion(-) diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index 24cdd312a..f0513b348 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -115,6 +115,9 @@ def begin_task(self, task: WorkerTask) -> WorkerTask: self._worker.begin_task(task.task_id) return task + def get_tasks_by_status(self, status:str) -> list[TrackableTask[Task]]: + return self._worker.get_tasks_by_status(status) + @property def active_task(self) -> TrackableTask | None: return self._worker.get_active_task() diff --git a/src/blueapi/service/handler_base.py b/src/blueapi/service/handler_base.py index 0671ebad8..02bbee04d 100644 --- a/src/blueapi/service/handler_base.py +++ b/src/blueapi/service/handler_base.py @@ -49,6 +49,16 @@ def clear_task(self, task_id: str) -> str: def begin_task(self, task: WorkerTask) -> WorkerTask: """Trigger a task. Will fail if the worker is busy""" + @abstractmethod + def get_tasks_by_status(self, status: str) -> list[TrackableTask[Task]]: + """ + Retrieve a list of tasks based on their status. + Args: + str: The status to filter tasks by. + Returns: + list[TrackableTask[T]]: A list of tasks that match the given status. + """ + @property @abstractmethod def active_task(self) -> TrackableTask | None: diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 9c0551cd7..2a3bb2ca4 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -6,6 +6,7 @@ Depends, FastAPI, HTTPException, + Query, Request, Response, status, @@ -154,6 +155,21 @@ def submit_task( ) from e +@app.get("/tasks") +def get_tasks( + status: str = Query("unstarted", description="The status of the tasks to retrieve"), + handler: BlueskyHandler = Depends(get_handler), +) -> list[TrackableTask]: + """ + Retrieve tasks based on their status. The default status is 'unstarted'. + """ + try: + tasks = handler.get_tasks_by_status(status) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) # noqa: B904 + return tasks + + @app.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK) def delete_submitted_task( task_id: str, diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index 97a523613..04af65326 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -28,7 +28,7 @@ ) from .multithread import run_worker_in_own_thread from .task import Task -from .worker import TrackableTask, Worker +from .worker import TaskStatusEnum, TrackableTask, Worker from .worker_errors import WorkerAlreadyStartedError, WorkerBusyError LOGGER = logging.getLogger(__name__) @@ -136,6 +136,20 @@ def submit_task(self, task: Task) -> str: self._tasks[task_id] = trackable_task return task_id + def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask[Task]]: + if status == TaskStatusEnum.UNDERWAY: + return [ + task + for task in self._tasks.values() + if not task.is_complete and not task.is_pending + ] + elif status == TaskStatusEnum.PENDING: + return [task for task in self._tasks.values() if task.is_pending] + elif status == TaskStatusEnum.COMPLETE: + return [task for task in self._tasks.values() if task.is_complete] + else: + raise ValueError("Unknown status") + def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: if self.state is not WorkerState.IDLE: raise WorkerBusyError(f"Worker is in state {self.state}") diff --git a/src/blueapi/worker/worker.py b/src/blueapi/worker/worker.py index 026806074..63abe1158 100644 --- a/src/blueapi/worker/worker.py +++ b/src/blueapi/worker/worker.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from enum import Enum from typing import Generic, TypeVar from pydantic import Field @@ -11,6 +12,13 @@ T = TypeVar("T") +class TaskStatusEnum(str, Enum): + PENDING = "PENDING" + COMPLETE = "COMPLETE" + ERROR = "ERROR" + UNDERWAY = "UNDERWAY" + + class TrackableTask(BlueapiBaseModel, Generic[T]): """ A representation of a task that the worker recognizes @@ -107,6 +115,16 @@ def submit_task(self, task: T) -> str: str: A unique ID to refer to this task """ + @abstractmethod + def get_tasks_by_status(self, status: str) -> list[TrackableTask[T]]: + """ + Retrieve a list of tasks based on their status. + Args: + str: The status to filter tasks by. + Returns: + list[TrackableTask[T]]: A list of tasks that match the given status. + """ + @abstractmethod def start(self) -> None: """ diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index ca416bae8..b26795218 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -14,6 +14,7 @@ from blueapi.service.model import WorkerTask from blueapi.worker import WorkerState from blueapi.worker.task import Task +from blueapi.worker.worker import TrackableTask _TASK = Task(name="count", params={"detectors": ["x"]}) @@ -223,6 +224,38 @@ def test_put_plan_begins_task(handler: Handler, client: TestClient) -> None: handler.stop() +tasks_data = [ + TrackableTask( + task_id="1", task=Task(name="first_task"), is_complete=False, is_pending=False + ), + TrackableTask( + task_id="2", task=Task(name="first_task"), is_complete=False, is_pending=True + ), +] + + +def test_get_unstarted_tasks(handler: Handler, client: TestClient): + handler.start() + # handler.tasks = tasks_data # overriding the property + handler._worker.get_tasks_by_status = Mock(return_value=tasks_data) + response = client.get("/tasks?status=unstarted") + assert response.status_code == 200 + assert ( + len(response.json()) == 1 + ) # As per our mock data, only 1 task should be 'unstarted' + assert ( + response.json()[0]["task_id"] == "1" + ) # Check that the correct task ID is returned + + +def test_get_tasks_bad_status(handler: Handler, client: TestClient): + handler.start() + # handler.tasks = tasks_data + handler._worker.get_tasks_by_status = Mock(return_value=tasks_data) + response = client.get("/tasks?status=invalid") + assert response.status_code == 400 + assert "Unsupported status" in response.json()["detail"] + def test_worker_task_is_none_on_startup(handler: Handler, client: TestClient) -> None: handler.start() resp = client.get("/worker/task")