From 454cc9540caa4fe272a3206714d0089d4caa6a4b Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Mon, 15 May 2023 16:37:47 +0100 Subject: [PATCH 1/3] Add transaction mode to worker class --- src/blueapi/worker/reworker.py | 4 ++++ src/blueapi/worker/worker.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index bec23f768..6bb6e3bcf 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -226,7 +226,11 @@ def _report_status( warnings = self._warnings if self._current is not None: task_status = TaskStatus( +<<<<<<< HEAD task_id=self._current.task_id, +======= + task_name=self._current.task_id, +>>>>>>> 185ad3ba (Add transaction mode to worker class) task_complete=self._current.is_complete, task_failed=self._current.is_error or bool(errors), ) diff --git a/src/blueapi/worker/worker.py b/src/blueapi/worker/worker.py index d935ee6a7..237d22487 100644 --- a/src/blueapi/worker/worker.py +++ b/src/blueapi/worker/worker.py @@ -1,5 +1,9 @@ from abc import ABC, abstractmethod +<<<<<<< HEAD from typing import Generic, List, TypeVar +======= +from typing import Generic, Optional, TypeVar +>>>>>>> 185ad3ba (Add transaction mode to worker class) from blueapi.core import DataEvent, EventStream from blueapi.utils import BlueapiBaseModel From 2fd5ead8b9fa5394da86faf8cd941b78a1b6fc5c Mon Sep 17 00:00:00 2001 From: Rose Yemelyanova Date: Tue, 23 May 2023 14:54:23 +0000 Subject: [PATCH 2/3] mocking out the RunEngine asyncio.sleep method in tests, standardising timeouts --- src/blueapi/worker/reworker.py | 4 -- src/blueapi/worker/worker.py | 4 -- tests/conftest.py | 7 +++ tests/core/test_event.py | 14 +++--- tests/messaging/test_stomptemplate.py | 53 ++++++++++++--------- tests/worker/test_reworker.py | 68 +++++++++++++++++++-------- 6 files changed, 93 insertions(+), 57 deletions(-) diff --git a/src/blueapi/worker/reworker.py b/src/blueapi/worker/reworker.py index 6bb6e3bcf..bec23f768 100644 --- a/src/blueapi/worker/reworker.py +++ b/src/blueapi/worker/reworker.py @@ -226,11 +226,7 @@ def _report_status( warnings = self._warnings if self._current is not None: task_status = TaskStatus( -<<<<<<< HEAD task_id=self._current.task_id, -======= - task_name=self._current.task_id, ->>>>>>> 185ad3ba (Add transaction mode to worker class) task_complete=self._current.is_complete, task_failed=self._current.is_error or bool(errors), ) diff --git a/src/blueapi/worker/worker.py b/src/blueapi/worker/worker.py index 237d22487..d935ee6a7 100644 --- a/src/blueapi/worker/worker.py +++ b/src/blueapi/worker/worker.py @@ -1,9 +1,5 @@ from abc import ABC, abstractmethod -<<<<<<< HEAD from typing import Generic, List, TypeVar -======= -from typing import Generic, Optional, TypeVar ->>>>>>> 185ad3ba (Add transaction mode to worker class) from blueapi.core import DataEvent, EventStream from blueapi.utils import BlueapiBaseModel diff --git a/tests/conftest.py b/tests/conftest.py index 101dd09c3..7f4df841c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ from blueapi.service.main import app from src.blueapi.core import BlueskyContext +_TIMEOUT = 10.0 + def pytest_addoption(parser): parser.addoption( @@ -58,3 +60,8 @@ def no_op(): @pytest.fixture(scope="session") def client(handler: Handler) -> TestClient: return Client(handler).client + + +@pytest.fixture(scope="session") +def timeout() -> float: + return _TIMEOUT diff --git a/tests/core/test_event.py b/tests/core/test_event.py index 9f832de01..21335d916 100644 --- a/tests/core/test_event.py +++ b/tests/core/test_event.py @@ -7,8 +7,6 @@ from blueapi.core import EventPublisher -_TIMEOUT: float = 10.0 - @dataclass class MyEvent: @@ -20,22 +18,22 @@ def publisher() -> EventPublisher[MyEvent]: return EventPublisher() -def test_publishes_event(publisher: EventPublisher[MyEvent]) -> None: +def test_publishes_event(timeout: float, publisher: EventPublisher[MyEvent]) -> None: event = MyEvent("a") f: Future = Future() publisher.subscribe(lambda r, _: f.set_result(r)) publisher.publish(event) - assert f.result(timeout=_TIMEOUT) == event + assert f.result(timeout=timeout) == event -def test_multi_subscriber(publisher: EventPublisher[MyEvent]) -> None: +def test_multi_subscriber(timeout: float, publisher: EventPublisher[MyEvent]) -> None: event = MyEvent("a") f1: Future = Future() f2: Future = Future() publisher.subscribe(lambda r, _: f1.set_result(r)) publisher.subscribe(lambda r, _: f2.set_result(r)) publisher.publish(event) - assert f1.result(timeout=_TIMEOUT) == f2.result(timeout=_TIMEOUT) == event + assert f1.result(timeout=timeout) == f2.result(timeout=timeout) == event def test_can_unsubscribe(publisher: EventPublisher[MyEvent]) -> None: @@ -67,13 +65,13 @@ def test_can_unsubscribe_all(publisher: EventPublisher[MyEvent]) -> None: assert list(_drain(q)) == [event_a, event_a, event_c] -def test_correlation_id(publisher: EventPublisher[MyEvent]) -> None: +def test_correlation_id(timeout: float, publisher: EventPublisher[MyEvent]) -> None: event = MyEvent("a") correlation_id = "foobar" f: Future = Future() publisher.subscribe(lambda _, c: f.set_result(c)) publisher.publish(event, correlation_id) - assert f.result(timeout=_TIMEOUT) == correlation_id + assert f.result(timeout=timeout) == correlation_id def _drain(queue: Queue) -> Iterable: diff --git a/tests/messaging/test_stomptemplate.py b/tests/messaging/test_stomptemplate.py index 66e76ffbf..054d47072 100644 --- a/tests/messaging/test_stomptemplate.py +++ b/tests/messaging/test_stomptemplate.py @@ -9,7 +9,6 @@ from blueapi.config import StompConfig from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate -_TIMEOUT: float = 10.0 _COUNT = itertools.count() @@ -41,7 +40,7 @@ def test_topic(template: MessagingTemplate) -> str: @pytest.mark.stomp -def test_send(template: MessagingTemplate, test_queue: str) -> None: +def test_send(template: MessagingTemplate, timeout: float, test_queue: str) -> None: f: Future = Future() def callback(ctx: MessageContext, message: str) -> None: @@ -49,11 +48,13 @@ def callback(ctx: MessageContext, message: str) -> None: template.subscribe(test_queue, callback) template.send(test_queue, "test_message") - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_to_topic(template: MessagingTemplate, test_topic: str) -> None: +def test_send_to_topic( + template: MessagingTemplate, timeout: float, test_topic: str +) -> None: f: Future = Future() def callback(ctx: MessageContext, message: str) -> None: @@ -61,11 +62,13 @@ def callback(ctx: MessageContext, message: str) -> None: template.subscribe(test_topic, callback) template.send(test_topic, "test_message") - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_on_reply(template: MessagingTemplate, test_queue: str) -> None: +def test_send_on_reply( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) f: Future = Future() @@ -74,18 +77,20 @@ def callback(ctx: MessageContext, message: str) -> None: f.set_result(message) template.send(test_queue, "test_message", callback) - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_and_receive(template: MessagingTemplate, test_queue: str) -> None: +def test_send_and_receive( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @pytest.mark.stomp -def test_listener(template: MessagingTemplate, test_queue: str) -> None: +def test_listener(template: MessagingTemplate, timeout: float, test_queue: str) -> None: @template.listener(test_queue) def server(ctx: MessageContext, message: str) -> None: reply_queue = ctx.reply_destination @@ -93,7 +98,7 @@ def server(ctx: MessageContext, message: str) -> None: raise RuntimeError("reply queue is None") template.send(reply_queue, "ack") - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @@ -108,7 +113,11 @@ class Foo(BaseModel): [("test", str), (1, int), (Foo(a=1, b="test"), Foo)], ) def test_deserialization( - template: MessagingTemplate, test_queue: str, message: Any, message_type: Type + template: MessagingTemplate, + timeout: float, + test_queue: str, + message: Any, + message_type: Type, ) -> None: def server(ctx: MessageContext, message: message_type) -> None: # type: ignore reply_queue = ctx.reply_destination @@ -118,37 +127,39 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore template.subscribe(test_queue, server) reply = template.send_and_receive(test_queue, message, message_type).result( - timeout=_TIMEOUT + timeout=timeout ) assert reply == message @pytest.mark.stomp def test_subscribe_before_connect( - disconnected_template: MessagingTemplate, test_queue: str + disconnected_template: MessagingTemplate, timeout: float, test_queue: str ) -> None: acknowledge(disconnected_template, test_queue) disconnected_template.connect() reply = disconnected_template.send_and_receive(test_queue, "test", str).result( - timeout=_TIMEOUT + timeout=timeout ) assert reply == "ack" @pytest.mark.stomp -def test_reconnect(template: MessagingTemplate, test_queue: str) -> None: +def test_reconnect( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" template.disconnect() template.connect() - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @pytest.mark.stomp def test_correlation_id( - template: MessagingTemplate, test_queue: str, test_queue_2: str + template: MessagingTemplate, timeout: float, test_queue: str, test_queue_2: str ) -> None: correlation_id = "foobar" q: Queue = Queue() @@ -164,9 +175,9 @@ def client(ctx: MessageContext, msg: str) -> None: template.subscribe(test_queue_2, client) template.send(test_queue, "test", None, correlation_id) - ctx_req: MessageContext = q.get(timeout=_TIMEOUT) + ctx_req: MessageContext = q.get(timeout=timeout) assert ctx_req.correlation_id == correlation_id - ctx_ack: MessageContext = q.get(timeout=_TIMEOUT) + ctx_ack: MessageContext = q.get(timeout=timeout) assert ctx_ack.correlation_id == correlation_id diff --git a/tests/worker/test_reworker.py b/tests/worker/test_reworker.py index bdda025f3..1dfc101b6 100644 --- a/tests/worker/test_reworker.py +++ b/tests/worker/test_reworker.py @@ -3,6 +3,7 @@ from concurrent.futures import Future from typing import Callable, Iterable, List, Optional, TypeVar +import mock import pytest from blueapi.config import EnvironmentConfig, Source, SourceKind @@ -20,8 +21,14 @@ WorkerState, ) -_SIMPLE_TASK = RunPlan(name="sleep", params={"time": 0.0}) -_LONG_TASK = RunPlan(name="sleep", params={"time": 1.0}) + +class SleepMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super(SleepMock, self).__call__(*args, **kwargs) + + +_SIMPLE_TASK = RunPlan(name="sleep", params={"time": 10.0}) +_LONG_TASK = RunPlan(name="sleep", params={"time": 200.0}) _INDEFINITE_TASK = RunPlan( name="set_absolute", params={"movable": "fake_device", "value": 4.0}, @@ -50,14 +57,16 @@ def fake_device() -> FakeDevice: @pytest.fixture def context(fake_device: FakeDevice) -> BlueskyContext: - ctx = BlueskyContext() - ctx_config = EnvironmentConfig() - ctx_config.sources.append( - Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices") - ) - ctx.device(fake_device) - ctx.with_config(ctx_config) - return ctx + with mock.patch("bluesky.run_engine.asyncio.sleep", new_callable=SleepMock): + ctx = BlueskyContext() + + ctx_config = EnvironmentConfig() + ctx_config.sources.append( + Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices") + ) + ctx.device(fake_device) + ctx.with_config(ctx_config) + yield ctx @pytest.fixture @@ -77,6 +86,15 @@ def test_stop_doesnt_hang(inert_worker: Worker) -> None: inert_worker.stop() +def test_stop_doesnt_hang(inert_worker: Worker) -> None: + inert_worker.start() + inert_worker.stop() + + +def test_stop_is_idempontent_if_worker_not_started(inert_worker: Worker) -> None: + inert_worker.stop() + + def test_stop_is_idempontent_if_worker_not_started(inert_worker: Worker) -> None: inert_worker.stop() @@ -87,6 +105,12 @@ def test_multi_stop(inert_worker: Worker) -> None: inert_worker.stop() +def test_multi_stop(inert_worker: Worker) -> None: + inert_worker.start() + inert_worker.stop() + inert_worker.stop() + + def test_multi_start(inert_worker: Worker) -> None: inert_worker.start() with pytest.raises(Exception): @@ -173,12 +197,12 @@ def test_does_not_allow_simultaneous_running_tasks( @pytest.mark.parametrize("num_runs", [0, 1, 2]) -def test_produces_worker_events(worker: Worker, num_runs: int) -> None: +def test_produces_worker_events(worker: Worker, timeout: float, num_runs: int) -> None: task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)] event_sequences = [_sleep_events(task_id) for task_id in task_ids] for task_id, events in zip(task_ids, event_sequences): - assert_run_produces_worker_events(events, worker, task_id) + assert_run_produces_worker_events(events, worker, task_id, timeout) def _sleep_events(task_id: str) -> List[WorkerEvent]: @@ -210,7 +234,7 @@ def _sleep_events(task_id: str) -> List[WorkerEvent]: ] -def test_no_additional_progress_events_after_complete(worker: Worker): +def test_no_additional_progress_events_after_complete(worker: Worker, timeout: float): """ See https://github.com/bluesky/ophyd/issues/1115 """ @@ -222,7 +246,7 @@ def test_no_additional_progress_events_after_complete(worker: Worker): name="move", params={"moves": {"additional_status_device": 5.0}} ) task_id = worker.submit_task(task) - begin_task_and_wait_until_complete(worker, task_id) + begin_task_and_wait_until_complete(worker, task_id, timeout) # Extract all the display_name fields from the events list_of_dict_keys = [pe.statuses.values() for pe in progress_events] @@ -238,23 +262,27 @@ def test_no_additional_progress_events_after_complete(worker: Worker): def assert_run_produces_worker_events( - expected_events: List[WorkerEvent], - worker: Worker, - task_id: str, + expected_events: List[WorkerEvent], worker: Worker, task_id: str, timeout: float ) -> None: - assert begin_task_and_wait_until_complete(worker, task_id) == expected_events + assert ( + begin_task_and_wait_until_complete(worker, task_id, timeout) == expected_events + ) + + +# +# Worker helpers +# def begin_task_and_wait_until_complete( worker: Worker, task_id: str, - timeout: float = 5.0, + timeout: float, ) -> List[WorkerEvent]: events: "Future[List[WorkerEvent]]" = take_events( worker.worker_events, lambda event: event.is_complete(), ) - worker.begin_task(task_id) return events.result(timeout=timeout) From fe40a5f69ca4bb17df5f2381d1223f75a6e4761d Mon Sep 17 00:00:00 2001 From: Rose Yemelyanova Date: Tue, 23 May 2023 15:18:50 +0000 Subject: [PATCH 3/3] fixed linting --- tests/worker/test_reworker.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/tests/worker/test_reworker.py b/tests/worker/test_reworker.py index 1dfc101b6..b9db994c6 100644 --- a/tests/worker/test_reworker.py +++ b/tests/worker/test_reworker.py @@ -1,7 +1,7 @@ import itertools import threading from concurrent.futures import Future -from typing import Callable, Iterable, List, Optional, TypeVar +from typing import Callable, Generator, Iterable, List, Optional, TypeVar import mock import pytest @@ -56,7 +56,7 @@ def fake_device() -> FakeDevice: @pytest.fixture -def context(fake_device: FakeDevice) -> BlueskyContext: +def context(fake_device: FakeDevice) -> Generator[BlueskyContext, None, None]: with mock.patch("bluesky.run_engine.asyncio.sleep", new_callable=SleepMock): ctx = BlueskyContext() @@ -86,25 +86,10 @@ def test_stop_doesnt_hang(inert_worker: Worker) -> None: inert_worker.stop() -def test_stop_doesnt_hang(inert_worker: Worker) -> None: - inert_worker.start() - inert_worker.stop() - - def test_stop_is_idempontent_if_worker_not_started(inert_worker: Worker) -> None: inert_worker.stop() -def test_stop_is_idempontent_if_worker_not_started(inert_worker: Worker) -> None: - inert_worker.stop() - - -def test_multi_stop(inert_worker: Worker) -> None: - inert_worker.start() - inert_worker.stop() - inert_worker.stop() - - def test_multi_stop(inert_worker: Worker) -> None: inert_worker.start() inert_worker.stop()