diff --git a/docs/developer/explanations/decisions/0002-no-queues.rst b/docs/developer/explanations/decisions/0002-no-queues.rst new file mode 100644 index 000000000..ebc5ef9c8 --- /dev/null +++ b/docs/developer/explanations/decisions/0002-no-queues.rst @@ -0,0 +1,27 @@ +2. No Queues +============ + +Date: 2023-05-22 + +Status +------ + +Accepted + +Context +------- + +We considered whether this service should hold and execute a queue of tasks. + +Decision +-------- + +We will not hold any queues. The worker can execute one task at a time and will return +an error if asked to execute one task while another is running. Queueing should be the +responsibility of a different service. + +Consequences +------------ + +The API must be kept queue-free, although transactions are permitted where the server +caches requests. diff --git a/src/blueapi/cli/amq.py b/src/blueapi/cli/amq.py index 8006bee4e..9cce7b2b5 100644 --- a/src/blueapi/cli/amq.py +++ b/src/blueapi/cli/amq.py @@ -61,7 +61,7 @@ def on_progress_event_wrapper( task_response = self.app.send_and_receive( "worker.run", {"name": name, "params": params}, reply_type=TaskResponse ).result(5.0) - task_id = task_response.task_name + task_id = task_response.task_id if timeout is not None: complete.wait(timeout) diff --git a/src/blueapi/cli/updates.py b/src/blueapi/cli/updates.py index 51a7b4f9f..d9279b5b6 100644 --- a/src/blueapi/cli/updates.py +++ b/src/blueapi/cli/updates.py @@ -43,15 +43,15 @@ def _update(self, name: str, view: StatusView) -> None: class CliEventRenderer: - _task_name: Optional[str] + _task_id: Optional[str] _pbar_renderer: ProgressBarRenderer def __init__( self, - task_name: Optional[str] = None, + task_id: Optional[str] = None, pbar_renderer: Optional[ProgressBarRenderer] = None, ) -> None: - self._task_name = task_name + self._task_id = task_id if pbar_renderer is None: pbar_renderer = ProgressBarRenderer() self._pbar_renderer = pbar_renderer @@ -65,14 +65,14 @@ def on_worker_event(self, event: WorkerEvent) -> None: print(str(event.state)) def _relates_to_task(self, event: Union[WorkerEvent, ProgressEvent]) -> bool: - if self._task_name is None: + if self._task_id is None: return True elif isinstance(event, WorkerEvent): return ( event.task_status is not None - and event.task_status.task_name == self._task_name + and event.task_status.task_id == self._task_id ) elif isinstance(event, ProgressEvent): - return event.task_name == self._task_name + return event.task_id == self._task_id else: return False diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index df6efb162..9039ee8e1 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -70,8 +70,9 @@ def submit_task( handler: Handler = Depends(get_handler), ): """Submit a task onto the worker queue.""" - handler.worker.submit_task(name, RunPlan(name=name, params=task)) - return TaskResponse(task_name=name) + task_id = handler.worker.submit_task(RunPlan(name=name, params=task)) + handler.worker.begin_task(task_id) + return TaskResponse(task_id=task_id) def start(config: ApplicationConfig): diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index 17cd57c06..b5599fcd0 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -80,4 +80,4 @@ class TaskResponse(BlueapiBaseModel): Acknowledgement that a task has started, includes its ID """ - task_name: str = Field(description="Unique identifier for the task") + task_id: str = Field(description="Unique identifier for the task") diff --git
a/src/blueapi/worker/__init__.py b/src/blueapi/worker/__init__.py index bde984831..78309230e 100644 --- a/src/blueapi/worker/__init__.py +++ b/src/blueapi/worker/__init__.py @@ -2,7 +2,8 @@ from .multithread import run_worker_in_own_thread from .reworker import RunEngineWorker from .task import RunPlan, Task -from .worker import Worker +from .worker import TrackableTask, Worker +from .worker_busy_error import WorkerBusyError __all__ = [ "run_worker_in_own_thread", @@ -15,4 +16,6 @@ "StatusView", "ProgressEvent", "TaskStatus", + "TrackableTask", + "WorkerBusyError", ] diff --git a/src/blueapi/worker/event.py b/src/blueapi/worker/event.py index 6193f4068..9e9b7e8e3 100644 --- a/src/blueapi/worker/event.py +++ b/src/blueapi/worker/event.py @@ -88,7 +88,7 @@ class ProgressEvent(BlueapiBaseModel): such as moving motors and exposing detectors. """ - task_name: str + task_id: str statuses: Mapping[str, StatusView] = Field(default_factory=dict) @@ -97,7 +97,7 @@ class TaskStatus(BlueapiBaseModel): Status of a task the worker is running. """ - task_name: str + task_id: str task_complete: bool task_failed: bool diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index 13012db3e..349afa3ae 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -25,8 +25,8 @@ WorkerState, ) from .multithread import run_worker_in_own_thread -from .task import ActiveTask, Task -from .worker import Worker +from .task import Task +from .worker import TrackableTask, Worker from .worker_busy_error import WorkerBusyError LOGGER = logging.getLogger(__name__) @@ -47,14 +47,13 @@ class RunEngineWorker(Worker[Task]): _ctx: BlueskyContext _stop_timeout: float - _transaction_lock: RLock - _pending_transaction: Optional[ActiveTask] + _pending_tasks: Dict[str, TrackableTask] _state: WorkerState _errors: List[str] _warnings: List[str] _task_queue: Queue # type: ignore - _current: Optional[ActiveTask] + _current: Optional[TrackableTask] _status_lock: RLock _status_snapshot: Dict[str, StatusView] _completed_statuses: Set[str] @@ -73,8 +72,7 @@ def __init__( self._ctx = ctx self._stop_timeout = stop_timeout - self._transaction_lock = RLock() - self._pending_transaction = None + self._pending_tasks = {} self._state = WorkerState.from_bluesky_state(ctx.run_engine.state) self._errors = [] @@ -91,52 +89,33 @@ def __init__( self._stopping = Event() self._stopped = Event() - def begin_transaction(self, task: Task) -> str: - task_id: str = str(uuid.uuid4()) - with self._transaction_lock: - if self._pending_transaction is not None: - raise WorkerBusyError("There is already a transaction in progress") - self._pending_transaction = ActiveTask(task_id, task) - return task_id - - def clear_transaction(self) -> str: - with self._transaction_lock: - if self._pending_transaction is None: - raise Exception("No transaction to clear") - - task_id = self._pending_transaction.task_id - self._pending_transaction = None - return task_id + def clear_task(self, task_id: str) -> bool: + if task_id in self._pending_tasks: + del self._pending_tasks[task_id] + return True + else: + return False - def commit_transaction(self, task_id: str) -> None: - with self._transaction_lock: - if self._pending_transaction is None: - raise Exception("No transaction to commit") + def get_pending_tasks(self) -> List[TrackableTask[Task]]: + return list(self._pending_tasks.values()) - pending_id = self._pending_transaction.task_id - if pending_id == task_id: - self._submit_active_task(self._pending_transaction) - else: - raise 
KeyError( - "Not committing the transaction requested, asked to commit" - f"{task_id} when {pending_id} is in progress" - ) - - def get_pending(self) -> Optional[Task]: - with self._transaction_lock: - if self._pending_transaction is None: - return None - else: - return self._pending_transaction.task + def begin_task(self, task_id: str) -> None: + task = self._pending_tasks.get(task_id) + if task is not None: + self._submit_trackable_task(task) + else: + raise KeyError(f"No pending task with ID {task_id}") - def submit_task(self, task_id: str, task: Task) -> None: - active_task = ActiveTask(task_id, task) - self._submit_active_task(active_task) + def submit_task(self, task: Task) -> str: + task_id: str = str(uuid.uuid4()) + trackable_task = TrackableTask(task_id=task_id, task=task) + self._pending_tasks[task_id] = trackable_task + return task_id - def _submit_active_task(self, active_task: ActiveTask) -> None: - LOGGER.info(f"Submitting: {active_task}") + def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: + LOGGER.info(f"Submitting: {trackable_task}") try: - self._task_queue.put_nowait(active_task) + self._task_queue.put_nowait(trackable_task) except Full: LOGGER.error("Cannot submit task while another is running") raise WorkerBusyError("Cannot submit task while another is running") @@ -181,8 +160,8 @@ def _cycle_with_error_handling(self) -> None: def _cycle(self) -> None: try: LOGGER.info("Awaiting task") - next_task: Union[ActiveTask, KillSignal] = self._task_queue.get() - if isinstance(next_task, ActiveTask): + next_task: Union[TrackableTask, KillSignal] = self._task_queue.get() + if isinstance(next_task, TrackableTask): LOGGER.info(f"Got new task: {next_task}") self._current = next_task # Informing mypy that the task is not None self._current.task.do_task(self._ctx) @@ -243,7 +222,7 @@ def _report_status( warnings = self._warnings if self._current is not None: task_status = TaskStatus( - task_name=self._current.task_id, + task_id=self._current.task_id, task_complete=self._current.is_complete, task_failed=self._current.is_error or bool(errors), ) @@ -336,10 +315,10 @@ def _publish_status_snapshot(self) -> None: else: self._progress_events.publish( ProgressEvent( - task_name=self._current.name, + task_id=self._current.task_id, statuses=self._status_snapshot, ), - self._current.name, + self._current.task_id, ) diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index f1b0e3e5e..fdec51202 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -1,6 +1,5 @@ import logging from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Any, Mapping from pydantic import BaseModel, Field, parse_obj_as @@ -65,11 +64,3 @@ def _lookup_params( model = plan.model return parse_obj_as(model, params) - - -@dataclass -class ActiveTask: - task_id: str - task: Task - is_complete: bool = False - is_error: bool = False diff --git a/src/blueapi/worker/worker.py b/src/blueapi/worker/worker.py index d9431710a..dffc32b69 100644 --- a/src/blueapi/worker/worker.py +++ b/src/blueapi/worker/worker.py @@ -1,13 +1,25 @@ from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar +from typing import Generic, List, TypeVar from blueapi.core import DataEvent, EventStream +from blueapi.utils import BlueapiBaseModel from .event import ProgressEvent, WorkerEvent T = TypeVar("T") +class TrackableTask(BlueapiBaseModel, Generic[T]): + """ + A representation of a task that the worker recognizes + """ + + task_id: str + 
task: T + is_complete: bool = False + is_error: bool = False + + class Worker(ABC, Generic[T]): """ Entity that takes and runs tasks. Intended to be a central, @@ -15,56 +27,47 @@ class Worker(ABC, Generic[T]): """ @abstractmethod - def begin_transaction(self, __task: T) -> str: + def get_pending_tasks(self) -> List[TrackableTask[T]]: """ - Begin a new transaction, lock the worker with a pending task, - do not allow new transactions until this one is run or cleared. - - Args: - __task: The task to run if this transaction is committed + Return a list of all tasks pending on the worker, + any one of which can be triggered with begin_task. Returns: - str: An ID for the task + List[TrackableTask[T]]: List of task objects """ @abstractmethod - def clear_transaction(self) -> str: + def clear_task(self, task_id: str) -> bool: """ - Clear any existing transaction. Raise an error if - unable. + Remove a pending task from the worker + Args: + task_id: The ID of the task to be removed Returns: - str: The ID of the task cleared + bool: True if the task existed in the first place """ @abstractmethod - def commit_transaction(self, __task_id: str) -> None: + def begin_task(self, task_id: str) -> None: """ - Commit the pending transaction and run the - embedded task + Trigger a pending task. Will fail if the worker is busy. Args: - __task_id: The ID of the task to run, must match - the pending transaction + task_id: The ID of the task to be triggered + Throws: + WorkerBusyError: If the worker is already running a task. + KeyError: If the task ID does not exist """ @abstractmethod - def get_pending(self) -> Optional[T]: - """_summary_ - - Returns: - Optional[Task]: _description_ + def submit_task(self, task: T) -> str: """ - - @abstractmethod - def submit_task(self, __task_id: str, __task: T) -> None: - """ - Submit a task to be run + Submit a task to be run on begin_task Args: - __name (str): name of the plan to be run - __task (T): The task to run - __correlation_id (str): unique identifier of the task + task: A description of the task + Returns: + str: A unique ID to refer to this task """ @abstractmethod diff --git a/tests/conftest.py b/tests/conftest.py index 1174e7859..c9619a8d7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ from blueapi.service.main import app from blueapi.worker.reworker import RunEngineWorker +_TIMEOUT = 10.0 + def pytest_addoption(parser): parser.addoption( @@ -65,3 +67,8 @@ def handler() -> MockHandler: @pytest.fixture(scope="session") def client(handler: MockHandler) -> TestClient: return Client(handler).client + + +@pytest.fixture(scope="session") +def timeout() -> float: + return _TIMEOUT diff --git a/tests/core/test_event.py b/tests/core/test_event.py index 9f832de01..21335d916 100644 --- a/tests/core/test_event.py +++ b/tests/core/test_event.py @@ -7,8 +7,6 @@ from blueapi.core import EventPublisher -_TIMEOUT: float = 10.0 - @dataclass class MyEvent: @@ -20,22 +18,22 @@ def publisher() -> EventPublisher[MyEvent]: return EventPublisher() -def test_publishes_event(publisher: EventPublisher[MyEvent]) -> None: +def test_publishes_event(timeout: float, publisher: EventPublisher[MyEvent]) -> None: event = MyEvent("a") f: Future = Future() publisher.subscribe(lambda r, _: f.set_result(r)) publisher.publish(event) - assert f.result(timeout=_TIMEOUT) == event + assert f.result(timeout=timeout) == event -def test_multi_subscriber(publisher: EventPublisher[MyEvent]) -> None: +def test_multi_subscriber(timeout: float, publisher: EventPublisher[MyEvent]) -> None: 
event = MyEvent("a") f1: Future = Future() f2: Future = Future() publisher.subscribe(lambda r, _: f1.set_result(r)) publisher.subscribe(lambda r, _: f2.set_result(r)) publisher.publish(event) - assert f1.result(timeout=_TIMEOUT) == f2.result(timeout=_TIMEOUT) == event + assert f1.result(timeout=timeout) == f2.result(timeout=timeout) == event def test_can_unsubscribe(publisher: EventPublisher[MyEvent]) -> None: @@ -67,13 +65,13 @@ def test_can_unsubscribe_all(publisher: EventPublisher[MyEvent]) -> None: assert list(_drain(q)) == [event_a, event_a, event_c] -def test_correlation_id(publisher: EventPublisher[MyEvent]) -> None: +def test_correlation_id(timeout: float, publisher: EventPublisher[MyEvent]) -> None: event = MyEvent("a") correlation_id = "foobar" f: Future = Future() publisher.subscribe(lambda _, c: f.set_result(c)) publisher.publish(event, correlation_id) - assert f.result(timeout=_TIMEOUT) == correlation_id + assert f.result(timeout=timeout) == correlation_id def _drain(queue: Queue) -> Iterable: diff --git a/tests/messaging/test_stomptemplate.py b/tests/messaging/test_stomptemplate.py index 66e76ffbf..054d47072 100644 --- a/tests/messaging/test_stomptemplate.py +++ b/tests/messaging/test_stomptemplate.py @@ -9,7 +9,6 @@ from blueapi.config import StompConfig from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate -_TIMEOUT: float = 10.0 _COUNT = itertools.count() @@ -41,7 +40,7 @@ def test_topic(template: MessagingTemplate) -> str: @pytest.mark.stomp -def test_send(template: MessagingTemplate, test_queue: str) -> None: +def test_send(template: MessagingTemplate, timeout: float, test_queue: str) -> None: f: Future = Future() def callback(ctx: MessageContext, message: str) -> None: @@ -49,11 +48,13 @@ def callback(ctx: MessageContext, message: str) -> None: template.subscribe(test_queue, callback) template.send(test_queue, "test_message") - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_to_topic(template: MessagingTemplate, test_topic: str) -> None: +def test_send_to_topic( + template: MessagingTemplate, timeout: float, test_topic: str +) -> None: f: Future = Future() def callback(ctx: MessageContext, message: str) -> None: @@ -61,11 +62,13 @@ def callback(ctx: MessageContext, message: str) -> None: template.subscribe(test_topic, callback) template.send(test_topic, "test_message") - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_on_reply(template: MessagingTemplate, test_queue: str) -> None: +def test_send_on_reply( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) f: Future = Future() @@ -74,18 +77,20 @@ def callback(ctx: MessageContext, message: str) -> None: f.set_result(message) template.send(test_queue, "test_message", callback) - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_and_receive(template: MessagingTemplate, test_queue: str) -> None: +def test_send_and_receive( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @pytest.mark.stomp -def test_listener(template: MessagingTemplate, test_queue: str) -> None: +def test_listener(template: MessagingTemplate, timeout: float, 
test_queue: str) -> None: @template.listener(test_queue) def server(ctx: MessageContext, message: str) -> None: reply_queue = ctx.reply_destination @@ -93,7 +98,7 @@ def server(ctx: MessageContext, message: str) -> None: raise RuntimeError("reply queue is None") template.send(reply_queue, "ack") - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @@ -108,7 +113,11 @@ class Foo(BaseModel): [("test", str), (1, int), (Foo(a=1, b="test"), Foo)], ) def test_deserialization( - template: MessagingTemplate, test_queue: str, message: Any, message_type: Type + template: MessagingTemplate, + timeout: float, + test_queue: str, + message: Any, + message_type: Type, ) -> None: def server(ctx: MessageContext, message: message_type) -> None: # type: ignore reply_queue = ctx.reply_destination @@ -118,37 +127,39 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore template.subscribe(test_queue, server) reply = template.send_and_receive(test_queue, message, message_type).result( - timeout=_TIMEOUT + timeout=timeout ) assert reply == message @pytest.mark.stomp def test_subscribe_before_connect( - disconnected_template: MessagingTemplate, test_queue: str + disconnected_template: MessagingTemplate, timeout: float, test_queue: str ) -> None: acknowledge(disconnected_template, test_queue) disconnected_template.connect() reply = disconnected_template.send_and_receive(test_queue, "test", str).result( - timeout=_TIMEOUT + timeout=timeout ) assert reply == "ack" @pytest.mark.stomp -def test_reconnect(template: MessagingTemplate, test_queue: str) -> None: +def test_reconnect( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" template.disconnect() template.connect() - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @pytest.mark.stomp def test_correlation_id( - template: MessagingTemplate, test_queue: str, test_queue_2: str + template: MessagingTemplate, timeout: float, test_queue: str, test_queue_2: str ) -> None: correlation_id = "foobar" q: Queue = Queue() @@ -164,9 +175,9 @@ def client(ctx: MessageContext, msg: str) -> None: template.subscribe(test_queue_2, client) template.send(test_queue, "test", None, correlation_id) - ctx_req: MessageContext = q.get(timeout=_TIMEOUT) + ctx_req: MessageContext = q.get(timeout=timeout) assert ctx_req.correlation_id == correlation_id - ctx_ack: MessageContext = q.get(timeout=_TIMEOUT) + ctx_ack: MessageContext = q.get(timeout=timeout) assert ctx_ack.correlation_id == correlation_id diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index 5035c51f4..4363792d2 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -74,13 +74,15 @@ class MyDevice: def test_put_plan_submits_task(handler: Handler, client: TestClient) -> None: task_json = {"detectors": ["x"]} - task_name = "count" + plan_name = "count" submitted_tasks = {} + task_id = "fake-task" - def on_submit(name: str, task: Task): - submitted_tasks[name] = task + def on_submit(task: Task) -> str: + submitted_tasks[task_id] = 
task + return task_id handler.worker.submit_task.side_effect = on_submit # type: ignore - client.put(f"/task/{task_name}", json=task_json) - assert submitted_tasks == {task_name: RunPlan(name=task_name, params=task_json)} + client.put(f"/task/{plan_name}", json=task_json) + assert submitted_tasks == {task_id: RunPlan(name=plan_name, params=task_json)} diff --git a/tests/worker/test_reworker.py b/tests/worker/test_reworker.py index d04fd42ff..409e1b026 100644 --- a/tests/worker/test_reworker.py +++ b/tests/worker/test_reworker.py @@ -1,112 +1,252 @@ import itertools +import threading from concurrent.futures import Future from typing import Callable, Iterable, List, Optional, TypeVar +import mock import pytest from blueapi.config import EnvironmentConfig, Source, SourceKind from blueapi.core import BlueskyContext, EventStream from blueapi.worker import ( + ProgressEvent, RunEngineWorker, RunPlan, Task, TaskStatus, + TrackableTask, Worker, + WorkerBusyError, WorkerEvent, WorkerState, ) -from blueapi.worker.event import ProgressEvent -from blueapi.worker.worker_busy_error import WorkerBusyError + + +class SleepMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super(SleepMock, self).__call__(*args, **kwargs) + + +_SIMPLE_TASK = RunPlan(name="sleep", params={"time": 10.0}) +_LONG_TASK = RunPlan(name="sleep", params={"time": 200.0}) +_INDEFINITE_TASK = RunPlan( + name="set_absolute", + params={"movable": "fake_device", "value": 4.0}, +) + + +class FakeDevice: + event: threading.Event + + @property + def name(self) -> str: + return "fake_device" + + def __init__(self) -> None: + self.event = threading.Event() + + def set(self, pos: float) -> None: + self.event.wait() + self.event.clear() @pytest.fixture -def context() -> BlueskyContext: - ctx = BlueskyContext() - ctx_config = EnvironmentConfig() - ctx_config.sources.append( - Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices") - ) - ctx.with_config(ctx_config) - return ctx +def fake_device() -> FakeDevice: + return FakeDevice() @pytest.fixture -def worker(context: BlueskyContext) -> Iterable[Worker[Task]]: - worker = RunEngineWorker(context) - yield worker - worker.stop() +def context(fake_device: FakeDevice) -> BlueskyContext: + with mock.patch("bluesky.run_engine.asyncio.sleep", new_callable=SleepMock): + ctx = BlueskyContext() + ctx_config = EnvironmentConfig() + ctx_config.sources.append( + Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices") + ) + ctx.device(fake_device) + ctx.with_config(ctx_config) + yield ctx -def test_stop_doesnt_hang(worker: Worker) -> None: - worker.start() +@pytest.fixture +def inert_worker(context: BlueskyContext) -> Worker[Task]: + return RunEngineWorker(context, stop_timeout=2.0) + + +@pytest.fixture +def worker(inert_worker: Worker[Task]) -> Iterable[Worker[Task]]: + inert_worker.start() + yield inert_worker + inert_worker.stop() + + +def test_stop_doesnt_hang(inert_worker: Worker) -> None: + inert_worker.start() + inert_worker.stop() -def test_stop_is_idempontent_if_worker_not_started(worker: Worker) -> None: - ... 
+def test_stop_is_idempontent_if_worker_not_started(inert_worker: Worker) -> None: + inert_worker.stop() -def test_multi_stop(worker: Worker) -> None: - worker.start() - worker.stop() +def test_multi_stop(inert_worker: Worker) -> None: + inert_worker.start() + inert_worker.stop() + inert_worker.stop() -def test_multi_start(worker: Worker) -> None: - worker.start() + +def test_multi_start(inert_worker: Worker) -> None: + inert_worker.start() with pytest.raises(Exception): - worker.start() - - -def test_runs_plan(worker: Worker) -> None: - assert_run_produces_worker_events( - [ - WorkerEvent( - state=WorkerState.RUNNING, - task_status=TaskStatus( - task_name="test", task_complete=False, task_failed=False - ), - errors=[], - warnings=[], + inert_worker.start() + inert_worker.stop() + + +def test_submit_task(worker: Worker) -> None: + assert worker.get_pending_tasks() == [] + task_id = worker.submit_task(_SIMPLE_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id, task=_SIMPLE_TASK) + ] + + +def test_submit_multiple_tasks(worker: Worker) -> None: + assert worker.get_pending_tasks() == [] + task_id_1 = worker.submit_task(_SIMPLE_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id_1, task=_SIMPLE_TASK) + ] + task_id_2 = worker.submit_task(_LONG_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id_1, task=_SIMPLE_TASK), + TrackableTask(task_id=task_id_2, task=_LONG_TASK), + ] + + +def test_stop_with_task_pending(inert_worker: Worker) -> None: + inert_worker.start() + inert_worker.submit_task(_SIMPLE_TASK) + inert_worker.stop() + + +def test_clear_task(worker: Worker) -> None: + task_id = worker.submit_task(_SIMPLE_TASK) + assert worker.get_pending_tasks() == [ + TrackableTask(task_id=task_id, task=_SIMPLE_TASK) + ] + assert worker.clear_task(task_id) + assert worker.get_pending_tasks() == [] + + +def test_clear_nonexistant_task(worker: Worker) -> None: + assert not worker.clear_task("foo") + + +def test_does_not_allow_simultaneous_running_tasks( + worker: Worker, + fake_device: FakeDevice, +) -> None: + task_ids = [ + worker.submit_task(_INDEFINITE_TASK), + worker.submit_task(_INDEFINITE_TASK), + ] + with pytest.raises(WorkerBusyError): + for task_id in task_ids: + worker.begin_task(task_id) + fake_device.event.set() + + +@pytest.mark.parametrize("num_runs", [0, 1, 2]) +def test_produces_worker_events(worker: Worker, timeout: float, num_runs: int) -> None: + task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)] + event_sequences = [_sleep_events(task_id) for task_id in task_ids] + + for task_id, events in zip(task_ids, event_sequences): + assert_run_produces_worker_events(events, worker, task_id, timeout) + + +def _sleep_events(task_id: str) -> List[WorkerEvent]: + return [ + WorkerEvent( + state=WorkerState.RUNNING, + task_status=TaskStatus( + task_id=task_id, task_complete=False, task_failed=False ), - WorkerEvent( - state=WorkerState.IDLE, - task_status=TaskStatus( - task_name="test", task_complete=False, task_failed=False - ), - errors=[], - warnings=[], + errors=[], + warnings=[], + ), + WorkerEvent( + state=WorkerState.IDLE, + task_status=TaskStatus( + task_id=task_id, task_complete=False, task_failed=False ), - WorkerEvent( - state=WorkerState.IDLE, - task_status=TaskStatus( - task_name="test", task_complete=True, task_failed=False - ), - errors=[], - warnings=[], + errors=[], + warnings=[], + ), + WorkerEvent( + state=WorkerState.IDLE, + task_status=TaskStatus( + task_id=task_id, 
task_complete=True, task_failed=False ), - ], - worker, + errors=[], + warnings=[], + ), + ] + + +def test_no_additional_progress_events_after_complete(worker: Worker, timeout: float): + """ + See https://github.com/bluesky/ophyd/issues/1115 + """ + + progress_events: List[ProgressEvent] = [] + worker.progress_events.subscribe(lambda event, id: progress_events.append(event)) + + task: Task = RunPlan( + name="move", params={"moves": {"additional_status_device": 5.0}} + ) + task_id = worker.submit_task(task) + begin_task_and_wait_until_complete(worker, task_id, timeout) + + # Extract all the display_name fields from the events + list_of_dict_keys = [pe.statuses.values() for pe in progress_events] + status_views = [item for sublist in list_of_dict_keys for item in sublist] + display_names = [view.display_name for view in status_views] + + assert "STATUS_AFTER_FINISH" not in display_names + + +# +# Worker helpers +# + + +def assert_run_produces_worker_events( + expected_events: List[WorkerEvent], worker: Worker, task_id: str, timeout: float +) -> None: + assert ( + begin_task_and_wait_until_complete(worker, task_id, timeout) == expected_events ) -def submit_task_and_wait_until_complete( - worker: Worker, task: Task, timeout: float = 5.0 +def begin_task_and_wait_until_complete( + worker: Worker, + task_id: str, + timeout: float, ) -> List[WorkerEvent]: events: "Future[List[WorkerEvent]]" = take_events( worker.worker_events, lambda event: event.is_complete(), ) + worker.begin_task(task_id) - worker.submit_task("test", task) return events.result(timeout=timeout) -def assert_run_produces_worker_events( - expected_events: List[WorkerEvent], - worker: Worker, - task: Task = RunPlan(name="sleep", params={"time": 0.0}), -) -> None: - worker.start() - assert submit_task_and_wait_until_complete(worker, task) == expected_events +# +# Event stream helpers +# E = TypeVar("E") @@ -136,34 +276,3 @@ def on_event(event: E, event_id: Optional[str]) -> None: sub = stream.subscribe(on_event) future.add_done_callback(lambda _: stream.unsubscribe(sub)) return future - - -def test_worker_only_accepts_one_task_on_queue(worker: Worker): - worker.start() - task: Task = RunPlan(name="sleep", params={"time": 1.0}) - - worker.submit_task("first_task", task) - with pytest.raises(WorkerBusyError): - worker.submit_task("second_task", task) - - -def test_no_additional_progress_events_after_complete(worker: Worker): - """ - See https://github.com/bluesky/ophyd/issues/1115 - """ - worker.start() - - progress_events: List[ProgressEvent] = [] - worker.progress_events.subscribe(lambda event, id: progress_events.append(event)) - - task: Task = RunPlan( - name="move", params={"moves": {"additional_status_device": 5.0}} - ) - submit_task_and_wait_until_complete(worker, task) - - # Exctract all the display_name fields from the events - list_of_dict_keys = [pe.statuses.values() for pe in progress_events] - status_views = [item for sublist in list_of_dict_keys for item in sublist] - display_names = [view.display_name for view in status_views] - - assert "STATUS_AFTER_FINISH" not in display_names
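
Usage sketch (illustrative, not applied by this patch): the reworked worker API above replaces the transaction flow with submit/begin semantics. submit_task now generates and returns a task ID, begin_task triggers the pending task and raises WorkerBusyError if another task is already running, and clear_task discards an unstarted task. The context setup and the "sleep" plan below are borrowed from the test fixtures and may need adjusting to the configured sources.

    from blueapi.config import EnvironmentConfig
    from blueapi.core import BlueskyContext
    from blueapi.worker import RunEngineWorker, RunPlan

    # Context and worker construction mirror the test fixtures above.
    ctx = BlueskyContext()
    ctx.with_config(EnvironmentConfig())  # assumes a "sleep" plan is registered
    worker = RunEngineWorker(ctx, stop_timeout=2.0)
    worker.start()

    # Submitting only registers the task and returns a generated ID;
    # nothing runs until begin_task is called.
    task_id = worker.submit_task(RunPlan(name="sleep", params={"time": 0.0}))
    assert worker.get_pending_tasks()[0].task_id == task_id

    # Trigger the pending task. Calling begin_task for another task while this
    # one is still running raises WorkerBusyError instead of queueing (ADR 0002).
    worker.begin_task(task_id)

    # Pending tasks that were never begun can be discarded.
    unused_id = worker.submit_task(RunPlan(name="sleep", params={"time": 10.0}))
    assert worker.clear_task(unused_id)

    worker.stop()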