Skip to content

Commit

Permalink
add a test and fix the string param passing into enum
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-dot committed May 17, 2024
1 parent ce40afd commit 0e52ac1
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/blueapi/service/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from blueapi.worker.event import WorkerState
from blueapi.worker.reworker import TaskWorker
from blueapi.worker.task import Task
from blueapi.worker.worker import TrackableTask, Worker
from blueapi.worker.worker import TaskStatusEnum, TrackableTask, Worker

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -115,8 +115,8 @@ def begin_task(self, task: WorkerTask) -> WorkerTask:
self._worker.begin_task(task.task_id)
return task

def get_tasks_by_status(self, status:str) -> list[TrackableTask[Task]]:
return self._worker.get_tasks_by_status(status)
def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask[Task]]:
return self._worker.get_tasks_by_status(status)

@property
def active_task(self) -> TrackableTask | None:
Expand Down
4 changes: 2 additions & 2 deletions src/blueapi/service/handler_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from blueapi.service.model import DeviceModel, PlanModel, WorkerTask
from blueapi.worker.event import WorkerState
from blueapi.worker.task import Task
from blueapi.worker.worker import TrackableTask
from blueapi.worker.worker import TaskStatusEnum, TrackableTask


class BlueskyHandler(ABC):
Expand Down Expand Up @@ -50,7 +50,7 @@ def begin_task(self, task: WorkerTask) -> WorkerTask:
"""Trigger a task. Will fail if the worker is busy"""

@abstractmethod
def get_tasks_by_status(self, status: str) -> list[TrackableTask[Task]]:
def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask[Task]]:
"""
Retrieve a list of tasks based on their status.
Args:
Expand Down
4 changes: 4 additions & 0 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from blueapi.config import ApplicationConfig
from blueapi.worker import Task, TrackableTask, WorkerState
from blueapi.worker.worker import TaskStatusEnum

from .handler_base import BlueskyHandler
from .model import (
Expand Down Expand Up @@ -163,6 +164,9 @@ def get_tasks(
"""
Retrieve tasks based on their status. The default status is 'unstarted'.
"""
if status not in TaskStatusEnum.__members__:
raise HTTPException(status_code=400, detail="Invalid status query parameter")

try:
tasks = handler.get_tasks_by_status(status)
except ValueError as e:
Expand Down
1 change: 1 addition & 0 deletions src/blueapi/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
T = TypeVar("T")


# NOTE this is interim until the refactor
class TaskStatusEnum(str, Enum):
PENDING = "PENDING"
COMPLETE = "COMPLETE"
Expand Down
1 change: 1 addition & 0 deletions tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def test_get_tasks_bad_status(handler: Handler, client: TestClient):
assert response.status_code == 400
assert "Unsupported status" in response.json()["detail"]


def test_worker_task_is_none_on_startup(handler: Handler, client: TestClient) -> None:
handler.start()
resp = client.get("/worker/task")
Expand Down
43 changes: 43 additions & 0 deletions tests/worker/test_reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
WorkerEvent,
WorkerState,
)
from blueapi.worker.worker import TaskStatusEnum

_SIMPLE_TASK = Task(name="sleep", params={"time": 0.0})
_LONG_TASK = Task(name="sleep", params={"time": 1.0})
Expand Down Expand Up @@ -215,6 +216,48 @@ def test_produces_worker_events(worker: Worker, num_runs: int) -> None:
assert_run_produces_worker_events(events, worker, task_id)


@pytest.mark.parametrize(
"status, expected_task_ids",
[
(TaskStatusEnum.UNDERWAY, ["task1"]),
(TaskStatusEnum.PENDING, ["task2"]),
(TaskStatusEnum.COMPLETE, ["task3"]),
],
)
def test_get_tasks_by_status(worker: Worker, status, expected_task_ids):
worker._tasks = {
"task1": TrackableTask(
task_id="task1",
task=Task(
name="set_absolute", params={"movable": "fake_device", "value": 4.0}
),
is_complete=False,
is_pending=False,
),
"task2": TrackableTask(
task_id="task2",
task=Task(
name="set_absolute", params={"movable": "fake_device", "value": 4.0}
),
is_complete=False,
is_pending=True,
),
"task3": TrackableTask(
task_id="task3",
task=Task(
name="set_absolute", params={"movable": "fake_device", "value": 4.0}
),
is_complete=True,
is_pending=False,
),
}

result = worker.get_tasks_by_status(status)
result_ids = [task_id for task_id, task in worker._tasks.items() if task in result]

assert result_ids == expected_task_ids


def _sleep_events(task_id: str) -> list[WorkerEvent]:
return [
WorkerEvent(
Expand Down

0 comments on commit 0e52ac1

Please sign in to comment.