Skip to content

Commit

Permalink
apply discussed changes in one go
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-dot committed May 17, 2024
1 parent 5950772 commit 971dda1
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/blueapi/service/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def begin_task(self, task: WorkerTask) -> WorkerTask:
self._worker.begin_task(task.task_id)
return task

def get_tasks_by_status(self, status:str) -> list[TrackableTask[Task]]:
return self._worker.get_tasks_by_status(status)

@property
def active_task(self) -> TrackableTask | None:
return self._worker.get_active_task()
Expand Down
10 changes: 10 additions & 0 deletions src/blueapi/service/handler_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ def clear_task(self, task_id: str) -> str:
def begin_task(self, task: WorkerTask) -> WorkerTask:
"""Trigger a task. Will fail if the worker is busy"""

@abstractmethod
def get_tasks_by_status(self, status: str) -> list[TrackableTask[Task]]:
"""
Retrieve a list of tasks based on their status.
Args:
str: The status to filter tasks by.
Returns:
list[TrackableTask[T]]: A list of tasks that match the given status.
"""

@property
@abstractmethod
def active_task(self) -> TrackableTask | None:
Expand Down
16 changes: 16 additions & 0 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Depends,
FastAPI,
HTTPException,
Query,
Request,
Response,
status,
Expand Down Expand Up @@ -154,6 +155,21 @@ def submit_task(
) from e


@app.get("/tasks")
def get_tasks(
status: str = Query("unstarted", description="The status of the tasks to retrieve"),
handler: BlueskyHandler = Depends(get_handler),
) -> list[TrackableTask]:
"""
Retrieve tasks based on their status. The default status is 'unstarted'.
"""
try:
tasks = handler.get_tasks_by_status(status)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) # noqa: B904
return tasks


@app.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK)
def delete_submitted_task(
task_id: str,
Expand Down
16 changes: 15 additions & 1 deletion src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from .multithread import run_worker_in_own_thread
from .task import Task
from .worker import TrackableTask, Worker
from .worker import TaskStatusEnum, TrackableTask, Worker
from .worker_errors import WorkerAlreadyStartedError, WorkerBusyError

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -136,6 +136,20 @@ def submit_task(self, task: Task) -> str:
self._tasks[task_id] = trackable_task
return task_id

def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask[Task]]:
if status == TaskStatusEnum.UNDERWAY:
return [
task
for task in self._tasks.values()
if not task.is_complete and not task.is_pending
]
elif status == TaskStatusEnum.PENDING:
return [task for task in self._tasks.values() if task.is_pending]
elif status == TaskStatusEnum.COMPLETE:
return [task for task in self._tasks.values() if task.is_complete]
else:
raise ValueError("Unknown status")

def _submit_trackable_task(self, trackable_task: TrackableTask) -> None:
if self.state is not WorkerState.IDLE:
raise WorkerBusyError(f"Worker is in state {self.state}")
Expand Down
18 changes: 18 additions & 0 deletions src/blueapi/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Generic, TypeVar

from pydantic import Field
Expand All @@ -11,6 +12,13 @@
T = TypeVar("T")


class TaskStatusEnum(str, Enum):
PENDING = "PENDING"
COMPLETE = "COMPLETE"
ERROR = "ERROR"
UNDERWAY = "UNDERWAY"


class TrackableTask(BlueapiBaseModel, Generic[T]):
"""
A representation of a task that the worker recognizes
Expand Down Expand Up @@ -107,6 +115,16 @@ def submit_task(self, task: T) -> str:
str: A unique ID to refer to this task
"""

@abstractmethod
def get_tasks_by_status(self, status: str) -> list[TrackableTask[T]]:
"""
Retrieve a list of tasks based on their status.
Args:
str: The status to filter tasks by.
Returns:
list[TrackableTask[T]]: A list of tasks that match the given status.
"""

@abstractmethod
def start(self) -> None:
"""
Expand Down
33 changes: 33 additions & 0 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from blueapi.service.model import WorkerTask
from blueapi.worker import WorkerState
from blueapi.worker.task import Task
from blueapi.worker.worker import TrackableTask

_TASK = Task(name="count", params={"detectors": ["x"]})

Expand Down Expand Up @@ -223,6 +224,38 @@ def test_put_plan_begins_task(handler: Handler, client: TestClient) -> None:
handler.stop()


tasks_data = [
TrackableTask(
task_id="1", task=Task(name="first_task"), is_complete=False, is_pending=False
),
TrackableTask(
task_id="2", task=Task(name="first_task"), is_complete=False, is_pending=True
),
]


def test_get_unstarted_tasks(handler: Handler, client: TestClient):
handler.start()
# handler.tasks = tasks_data # overriding the property
handler._worker.get_tasks_by_status = Mock(return_value=tasks_data)
response = client.get("/tasks?status=unstarted")
assert response.status_code == 200
assert (
len(response.json()) == 1
) # As per our mock data, only 1 task should be 'unstarted'
assert (
response.json()[0]["task_id"] == "1"
) # Check that the correct task ID is returned


def test_get_tasks_bad_status(handler: Handler, client: TestClient):
handler.start()
# handler.tasks = tasks_data
handler._worker.get_tasks_by_status = Mock(return_value=tasks_data)
response = client.get("/tasks?status=invalid")
assert response.status_code == 400
assert "Unsupported status" in response.json()["detail"]

def test_worker_task_is_none_on_startup(handler: Handler, client: TestClient) -> None:
handler.start()
resp = client.get("/worker/task")
Expand Down

0 comments on commit 971dda1

Please sign in to comment.