diff --git a/tests/conftest.py b/tests/conftest.py index 101dd09c3..7f4df841c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,8 @@ from blueapi.service.main import app from src.blueapi.core import BlueskyContext +_TIMEOUT = 10.0 + def pytest_addoption(parser): parser.addoption( @@ -58,3 +60,8 @@ def no_op(): @pytest.fixture(scope="session") def client(handler: Handler) -> TestClient: return Client(handler).client + + +@pytest.fixture(scope="session") +def timeout() -> float: + return _TIMEOUT diff --git a/tests/core/test_event.py b/tests/core/test_event.py index 9f832de01..21335d916 100644 --- a/tests/core/test_event.py +++ b/tests/core/test_event.py @@ -7,8 +7,6 @@ from blueapi.core import EventPublisher -_TIMEOUT: float = 10.0 - @dataclass class MyEvent: @@ -20,22 +18,22 @@ def publisher() -> EventPublisher[MyEvent]: return EventPublisher() -def test_publishes_event(publisher: EventPublisher[MyEvent]) -> None: +def test_publishes_event(timeout: float, publisher: EventPublisher[MyEvent]) -> None: event = MyEvent("a") f: Future = Future() publisher.subscribe(lambda r, _: f.set_result(r)) publisher.publish(event) - assert f.result(timeout=_TIMEOUT) == event + assert f.result(timeout=timeout) == event -def test_multi_subscriber(publisher: EventPublisher[MyEvent]) -> None: +def test_multi_subscriber(timeout: float, publisher: EventPublisher[MyEvent]) -> None: event = MyEvent("a") f1: Future = Future() f2: Future = Future() publisher.subscribe(lambda r, _: f1.set_result(r)) publisher.subscribe(lambda r, _: f2.set_result(r)) publisher.publish(event) - assert f1.result(timeout=_TIMEOUT) == f2.result(timeout=_TIMEOUT) == event + assert f1.result(timeout=timeout) == f2.result(timeout=timeout) == event def test_can_unsubscribe(publisher: EventPublisher[MyEvent]) -> None: @@ -67,13 +65,13 @@ def test_can_unsubscribe_all(publisher: EventPublisher[MyEvent]) -> None: assert list(_drain(q)) == [event_a, event_a, event_c] -def test_correlation_id(publisher: EventPublisher[MyEvent]) -> None: +def test_correlation_id(timeout: float, publisher: EventPublisher[MyEvent]) -> None: event = MyEvent("a") correlation_id = "foobar" f: Future = Future() publisher.subscribe(lambda _, c: f.set_result(c)) publisher.publish(event, correlation_id) - assert f.result(timeout=_TIMEOUT) == correlation_id + assert f.result(timeout=timeout) == correlation_id def _drain(queue: Queue) -> Iterable: diff --git a/tests/messaging/test_stomptemplate.py b/tests/messaging/test_stomptemplate.py index 66e76ffbf..054d47072 100644 --- a/tests/messaging/test_stomptemplate.py +++ b/tests/messaging/test_stomptemplate.py @@ -9,7 +9,6 @@ from blueapi.config import StompConfig from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate -_TIMEOUT: float = 10.0 _COUNT = itertools.count() @@ -41,7 +40,7 @@ def test_topic(template: MessagingTemplate) -> str: @pytest.mark.stomp -def test_send(template: MessagingTemplate, test_queue: str) -> None: +def test_send(template: MessagingTemplate, timeout: float, test_queue: str) -> None: f: Future = Future() def callback(ctx: MessageContext, message: str) -> None: @@ -49,11 +48,13 @@ def callback(ctx: MessageContext, message: str) -> None: template.subscribe(test_queue, callback) template.send(test_queue, "test_message") - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_to_topic(template: MessagingTemplate, test_topic: str) -> None: +def test_send_to_topic( + template: MessagingTemplate, timeout: float, test_topic: str +) -> None: f: Future = Future() def callback(ctx: MessageContext, message: str) -> None: @@ -61,11 +62,13 @@ def callback(ctx: MessageContext, message: str) -> None: template.subscribe(test_topic, callback) template.send(test_topic, "test_message") - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_on_reply(template: MessagingTemplate, test_queue: str) -> None: +def test_send_on_reply( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) f: Future = Future() @@ -74,18 +77,20 @@ def callback(ctx: MessageContext, message: str) -> None: f.set_result(message) template.send(test_queue, "test_message", callback) - assert f.result(timeout=_TIMEOUT) + assert f.result(timeout=timeout) @pytest.mark.stomp -def test_send_and_receive(template: MessagingTemplate, test_queue: str) -> None: +def test_send_and_receive( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @pytest.mark.stomp -def test_listener(template: MessagingTemplate, test_queue: str) -> None: +def test_listener(template: MessagingTemplate, timeout: float, test_queue: str) -> None: @template.listener(test_queue) def server(ctx: MessageContext, message: str) -> None: reply_queue = ctx.reply_destination @@ -93,7 +98,7 @@ def server(ctx: MessageContext, message: str) -> None: raise RuntimeError("reply queue is None") template.send(reply_queue, "ack") - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @@ -108,7 +113,11 @@ class Foo(BaseModel): [("test", str), (1, int), (Foo(a=1, b="test"), Foo)], ) def test_deserialization( - template: MessagingTemplate, test_queue: str, message: Any, message_type: Type + template: MessagingTemplate, + timeout: float, + test_queue: str, + message: Any, + message_type: Type, ) -> None: def server(ctx: MessageContext, message: message_type) -> None: # type: ignore reply_queue = ctx.reply_destination @@ -118,37 +127,39 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore template.subscribe(test_queue, server) reply = template.send_and_receive(test_queue, message, message_type).result( - timeout=_TIMEOUT + timeout=timeout ) assert reply == message @pytest.mark.stomp def test_subscribe_before_connect( - disconnected_template: MessagingTemplate, test_queue: str + disconnected_template: MessagingTemplate, timeout: float, test_queue: str ) -> None: acknowledge(disconnected_template, test_queue) disconnected_template.connect() reply = disconnected_template.send_and_receive(test_queue, "test", str).result( - timeout=_TIMEOUT + timeout=timeout ) assert reply == "ack" @pytest.mark.stomp -def test_reconnect(template: MessagingTemplate, test_queue: str) -> None: +def test_reconnect( + template: MessagingTemplate, timeout: float, test_queue: str +) -> None: acknowledge(template, test_queue) - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" template.disconnect() template.connect() - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) + reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout) assert reply == "ack" @pytest.mark.stomp def test_correlation_id( - template: MessagingTemplate, test_queue: str, test_queue_2: str + template: MessagingTemplate, timeout: float, test_queue: str, test_queue_2: str ) -> None: correlation_id = "foobar" q: Queue = Queue() @@ -164,9 +175,9 @@ def client(ctx: MessageContext, msg: str) -> None: template.subscribe(test_queue_2, client) template.send(test_queue, "test", None, correlation_id) - ctx_req: MessageContext = q.get(timeout=_TIMEOUT) + ctx_req: MessageContext = q.get(timeout=timeout) assert ctx_req.correlation_id == correlation_id - ctx_ack: MessageContext = q.get(timeout=_TIMEOUT) + ctx_ack: MessageContext = q.get(timeout=timeout) assert ctx_ack.correlation_id == correlation_id diff --git a/tests/worker/test_reworker.py b/tests/worker/test_reworker.py index bdda025f3..b9db994c6 100644 --- a/tests/worker/test_reworker.py +++ b/tests/worker/test_reworker.py @@ -1,8 +1,9 @@ import itertools import threading from concurrent.futures import Future -from typing import Callable, Iterable, List, Optional, TypeVar +from typing import Callable, Generator, Iterable, List, Optional, TypeVar +import mock import pytest from blueapi.config import EnvironmentConfig, Source, SourceKind @@ -20,8 +21,14 @@ WorkerState, ) -_SIMPLE_TASK = RunPlan(name="sleep", params={"time": 0.0}) -_LONG_TASK = RunPlan(name="sleep", params={"time": 1.0}) + +class SleepMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super(SleepMock, self).__call__(*args, **kwargs) + + +_SIMPLE_TASK = RunPlan(name="sleep", params={"time": 10.0}) +_LONG_TASK = RunPlan(name="sleep", params={"time": 200.0}) _INDEFINITE_TASK = RunPlan( name="set_absolute", params={"movable": "fake_device", "value": 4.0}, @@ -49,15 +56,17 @@ def fake_device() -> FakeDevice: @pytest.fixture -def context(fake_device: FakeDevice) -> BlueskyContext: - ctx = BlueskyContext() - ctx_config = EnvironmentConfig() - ctx_config.sources.append( - Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices") - ) - ctx.device(fake_device) - ctx.with_config(ctx_config) - return ctx +def context(fake_device: FakeDevice) -> Generator[BlueskyContext, None, None]: + with mock.patch("bluesky.run_engine.asyncio.sleep", new_callable=SleepMock): + ctx = BlueskyContext() + + ctx_config = EnvironmentConfig() + ctx_config.sources.append( + Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices") + ) + ctx.device(fake_device) + ctx.with_config(ctx_config) + yield ctx @pytest.fixture @@ -173,12 +182,12 @@ def test_does_not_allow_simultaneous_running_tasks( @pytest.mark.parametrize("num_runs", [0, 1, 2]) -def test_produces_worker_events(worker: Worker, num_runs: int) -> None: +def test_produces_worker_events(worker: Worker, timeout: float, num_runs: int) -> None: task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)] event_sequences = [_sleep_events(task_id) for task_id in task_ids] for task_id, events in zip(task_ids, event_sequences): - assert_run_produces_worker_events(events, worker, task_id) + assert_run_produces_worker_events(events, worker, task_id, timeout) def _sleep_events(task_id: str) -> List[WorkerEvent]: @@ -210,7 +219,7 @@ def _sleep_events(task_id: str) -> List[WorkerEvent]: ] -def test_no_additional_progress_events_after_complete(worker: Worker): +def test_no_additional_progress_events_after_complete(worker: Worker, timeout: float): """ See https://github.com/bluesky/ophyd/issues/1115 """ @@ -222,7 +231,7 @@ def test_no_additional_progress_events_after_complete(worker: Worker): name="move", params={"moves": {"additional_status_device": 5.0}} ) task_id = worker.submit_task(task) - begin_task_and_wait_until_complete(worker, task_id) + begin_task_and_wait_until_complete(worker, task_id, timeout) # Extract all the display_name fields from the events list_of_dict_keys = [pe.statuses.values() for pe in progress_events] @@ -238,23 +247,27 @@ def test_no_additional_progress_events_after_complete(worker: Worker): def assert_run_produces_worker_events( - expected_events: List[WorkerEvent], - worker: Worker, - task_id: str, + expected_events: List[WorkerEvent], worker: Worker, task_id: str, timeout: float ) -> None: - assert begin_task_and_wait_until_complete(worker, task_id) == expected_events + assert ( + begin_task_and_wait_until_complete(worker, task_id, timeout) == expected_events + ) + + +# +# Worker helpers +# def begin_task_and_wait_until_complete( worker: Worker, task_id: str, - timeout: float = 5.0, + timeout: float, ) -> List[WorkerEvent]: events: "Future[List[WorkerEvent]]" = take_events( worker.worker_events, lambda event: event.is_complete(), ) - worker.begin_task(task_id) return events.result(timeout=timeout)