diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 9efe01979..8fc0c8ee3 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -24,19 +24,7 @@ jobs: run: runs-on: ${{ inputs.runs-on }} - services: - activemq: - image: rmohr/activemq:5.14.5-alpine - ports: - - 61613:61613 - steps: - - name: Start RabbitMQ - uses: namoshek/rabbitmq-github-action@v1 - with: - ports: "61614:61613" - plugins: rabbitmq_stomp - - name: Checkout uses: actions/checkout@v4 with: diff --git a/dev-requirements.txt b/dev-requirements.txt index a48d6d772..317ce9015 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -17,6 +17,7 @@ bidict==0.23.1 bluesky==1.13.0a4 bluesky-kafka==0.10.0 bluesky-live==0.0.8 +bluesky-stomp==0.1.0 boltons==24.0.0 cachetools==5.5.0 caproto==1.1.1 @@ -212,10 +213,10 @@ tzlocal==5.2 urllib3==2.2.2 uvicorn==0.30.6 virtualenv==20.26.3 -watchfiles==0.23.0 +watchfiles==0.24.0 wcwidth==0.2.13 websocket-client==1.8.0 -websockets==13.0 +websockets==13.0.1 widgetsnbextension==4.0.13 workflows==2.27 xarray==2024.7.0 diff --git a/docs/reference/openapi.yaml b/docs/reference/openapi.yaml index 09f4b256c..a9fdf63cc 100644 --- a/docs/reference/openapi.yaml +++ b/docs/reference/openapi.yaml @@ -104,12 +104,10 @@ components: description: Request to change the state of the worker. properties: defer: - anyOf: - - type: boolean - - type: 'null' default: false description: Should worker defer Pausing until the next checkpoint title: Defer + type: boolean new_state: $ref: '#/components/schemas/WorkerState' reason: diff --git a/helm/blueapi/templates/deployment.yaml b/helm/blueapi/templates/statefulset.yaml similarity index 97% rename from helm/blueapi/templates/deployment.yaml rename to helm/blueapi/templates/statefulset.yaml index 8d03274a5..e7847c73b 100644 --- a/helm/blueapi/templates/deployment.yaml +++ b/helm/blueapi/templates/statefulset.yaml @@ -50,7 +50,7 @@ spec: image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" imagePullPolicy: {{ .Values.image.pullPolicy }} resources: - {{- toYaml .Values.resources | nindent 12 }} + {{- .Values.initResources | default .Values.resources | toYaml | nindent 12 }} command: [/bin/sh, -c] args: - | diff --git a/helm/blueapi/values.yaml b/helm/blueapi/values.yaml index 569d26076..b7a6220f3 100644 --- a/helm/blueapi/values.yaml +++ b/helm/blueapi/values.yaml @@ -60,6 +60,11 @@ resources: # cpu: 100m # memory: 128Mi +initResources: + {} + # Can optionally specify separate resource constraints for the scratch setup container. + # If left empty this defaults to the same as resources above. + nodeSelector: {} tolerations: [] diff --git a/pyproject.toml b/pyproject.toml index 4057b93b4..545ce0ae4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "dls-dodal>=1.24.0", "super-state-machine", # See GH issue 553 "GitPython", + "bluesky-stomp>=0.1.0" ] dynamic = ["version"] license.file = "LICENSE" @@ -90,9 +91,6 @@ addopts = """ filterwarnings = ["error", "ignore::DeprecationWarning"] # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "docs src tests" -markers = [ - "handler: marks tests that interact with the global handler object in handler.py", -] asyncio_mode = "auto" [tool.coverage.run] diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 199fe33f5..3802e4d13 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -6,6 +6,8 @@ import click from bluesky.callbacks.best_effort import BestEffortCallback +from bluesky_stomp.messaging import MessageContext, MessagingTemplate +from bluesky_stomp.models import Broker from pydantic import ValidationError from requests.exceptions import ConnectionError @@ -16,8 +18,6 @@ from blueapi.client.rest import BlueskyRemoteControlError from blueapi.config import ApplicationConfig, ConfigLoader from blueapi.core import DataEvent -from blueapi.messaging import MessageContext -from blueapi.messaging.stomptemplate import StompMessagingTemplate from blueapi.service.main import start from blueapi.service.openapi import ( DOCS_SCHEMA_LOCATION, @@ -147,14 +147,20 @@ def listen_to_events(obj: dict) -> None: config: ApplicationConfig = obj["config"] if config.stomp is not None: event_bus_client = EventBusClient( - StompMessagingTemplate.autoconfigured(config.stomp) + MessagingTemplate.for_broker( + broker=Broker( + host=config.stomp.host, + port=config.stomp.port, + auth=config.stomp.auth, + ) + ) ) else: raise RuntimeError("Message bus needs to be configured") def on_event( - context: MessageContext, event: WorkerEvent | ProgressEvent | DataEvent, + context: MessageContext, ) -> None: converted = json.dumps(event.dict(), indent=2) print(converted) diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 859491bd2..176f21350 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,9 +1,11 @@ import time from concurrent.futures import Future +from bluesky_stomp.messaging import MessageContext, MessagingTemplate +from bluesky_stomp.models import Broker + from blueapi.config import ApplicationConfig from blueapi.core.bluesky_types import DataEvent -from blueapi.messaging import MessageContext, StompMessagingTemplate from blueapi.service.model import ( DeviceModel, DeviceResponse, @@ -38,7 +40,13 @@ def __init__( def from_config(cls, config: ApplicationConfig) -> "BlueapiClient": rest = BlueapiRestClient(config.api) if config.stomp is not None: - template = StompMessagingTemplate.autoconfigured(config.stomp) + template = MessagingTemplate.for_broker( + broker=Broker( + host=config.stomp.host, + port=config.stomp.port, + auth=config.stomp.auth, + ) + ) events = EventBusClient(template) else: events = None @@ -178,7 +186,7 @@ def run_task( complete: Future[WorkerEvent] = Future() - def inner_on_event(ctx: MessageContext, event: AnyEvent) -> None: + def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: match event: case WorkerEvent(task_status=TaskStatus(task_id=test_id)): relates_to_task = test_id == task_id diff --git a/src/blueapi/client/event_bus.py b/src/blueapi/client/event_bus.py index bfd0afd18..94e374d4b 100644 --- a/src/blueapi/client/event_bus.py +++ b/src/blueapi/client/event_bus.py @@ -1,7 +1,9 @@ from collections.abc import Callable +from bluesky_stomp.messaging import MessageContext, MessagingTemplate +from bluesky_stomp.models import MessageTopic + from blueapi.core import DataEvent -from blueapi.messaging import MessageContext, MessagingTemplate from blueapi.worker import ProgressEvent, WorkerEvent @@ -28,11 +30,11 @@ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: def subscribe_to_all_events( self, - on_event: Callable[[MessageContext, AnyEvent], None], + on_event: Callable[[AnyEvent, MessageContext], None], ) -> None: try: self.app.subscribe( - self.app.destinations.topic("public.worker.event"), + MessageTopic(name="public.worker.event"), on_event, ) except Exception as err: diff --git a/src/blueapi/messaging/__init__.py b/src/blueapi/messaging/__init__.py deleted file mode 100644 index 0aeb5eb6d..000000000 --- a/src/blueapi/messaging/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from .base import DestinationProvider, MessageListener, MessagingTemplate -from .context import MessageContext -from .stomptemplate import StompDestinationProvider, StompMessagingTemplate - -__all__ = [ - "MessageListener", - "MessagingTemplate", - "MessageContext", - "StompMessagingTemplate", - "DestinationProvider", - "StompDestinationProvider", -] diff --git a/src/blueapi/messaging/base.py b/src/blueapi/messaging/base.py deleted file mode 100644 index 6c350639a..000000000 --- a/src/blueapi/messaging/base.py +++ /dev/null @@ -1,196 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Callable -from concurrent.futures import Future -from typing import Any - -from .context import MessageContext - -MessageListener = Callable[[MessageContext, Any], None] - - -class DestinationProvider(ABC): - """ - Class that provides destinations for specific types of message bus. - Implementation may be eager or lazy. - """ - - @abstractmethod - def default(self, name: str) -> str: - """ - A default type of destination with a given name. - For example, the provider could default to using queues if no - preference is specified. - - Args: - name (str): The name of the destination - - Returns: - str: Identifier for the destination - """ - - @abstractmethod - def queue(self, name: str) -> str: - """ - A queue with the given name - - Args: - name (str): Name of the queue - - Returns: - str: Identifier for the queue - """ - - @abstractmethod - def topic(self, name: str) -> str: - """ - A topic with the given name - - Args: - name (str): Name of the topic - - Returns: - str: Identifier for the topic - """ - - @abstractmethod - def temporary_queue(self, name: str) -> str: - """ - A temporary queue with the given name - - Args: - name (str): Name of the queue - - Returns: - str: Identifier for the queue - """ - - -class MessagingTemplate(ABC): - """ - Class meant for quickly building message-based applications. - Includes helpers for asynchronous production/consumption and - synchronous send/receive model - """ - - @property - @abstractmethod - def destinations(self) -> DestinationProvider: - """ - Get a destination provider that can create destination - identifiers for this particular template - - Returns: - DestinationProvider: Destination provider - """ - - def send_and_receive( - self, - destination: str, - obj: Any, - reply_type: type = str, - correlation_id: str | None = None, - ) -> Future: - """ - Send a message expecting a single reply. - - Args: - destination (str): Destination to send the message - obj (Any): Message to send, must be serializable - reply_type (Type, optional): Expected type of reply, used - in deserialization. Defaults to str. - correlation_id (Optional[str]): An id which correlates this request with - requests it spawns or the request which - spawned it etc. - Returns: - Future: Future representing the reply - """ - - future: Future = Future() - - def callback(_: MessageContext, reply: Any) -> None: - future.set_result(reply) - - callback.__annotations__["reply"] = reply_type - self.send(destination, obj, callback, correlation_id) - return future - - @abstractmethod - def send( - self, - destination: str, - obj: Any, - on_reply: MessageListener | None = None, - correlation_id: str | None = None, - ) -> None: - """ - Send a message to a destination - - Args: - destination (str): Destination to send the message - obj (Any): Message to send, must be serializable - on_reply (Optional[MessageListener], optional): Callback function for - a reply. Defaults to None. - correlation_id (Optional[str]): An id which correlates this request with - requests it spawns or the request which - spawned it etc. - """ - - def listener(self, destination: str): - """ - Decorator for subscribing to a topic: - - @my_app.listener("my-destination") - def callback(context: MessageContext, message: ???) -> None: - ... - - Args: - destination (str): Destination to subscribe to - """ - - def decorator(callback: MessageListener) -> MessageListener: - self.subscribe(destination, callback) - return callback - - return decorator - - @abstractmethod - def subscribe( - self, - destination: str, - callback: MessageListener, - ) -> None: - """ - Subscribe to messages from a particular destination. Requires - a callback of the form: - - def callback(context: MessageContext, message: ???) -> None: - ... - - The type annotation of the message will be inspected and used in - deserialization. - - Args: - destination (str): Destination to subscribe to - callback (MessageListener): What to do with each message - """ - - @abstractmethod - def connect(self) -> None: - """ - Connect the app to transport - """ - - @abstractmethod - def disconnect(self) -> None: - """ - Disconnect the app from transport - """ - - @abstractmethod - def is_connected(self) -> bool: - """ - Returns status of the connection between the app and the transport. - - Returns: - status (bool): Returns True if connected, False otherwise - """ diff --git a/src/blueapi/messaging/context.py b/src/blueapi/messaging/context.py deleted file mode 100644 index d202b700e..000000000 --- a/src/blueapi/messaging/context.py +++ /dev/null @@ -1,12 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class MessageContext: - """ - Context that comes with a message, provides useful information such as how to reply - """ - - destination: str - reply_destination: str | None - correlation_id: str | None diff --git a/src/blueapi/messaging/stomptemplate.py b/src/blueapi/messaging/stomptemplate.py deleted file mode 100644 index d535c0089..000000000 --- a/src/blueapi/messaging/stomptemplate.py +++ /dev/null @@ -1,249 +0,0 @@ -import itertools -import json -import logging -import time -import uuid -from collections.abc import Callable -from dataclasses import dataclass -from threading import Event -from typing import Any - -import orjson -import stomp -from pydantic import parse_obj_as -from stomp.exception import ConnectFailedException -from stomp.utils import Frame - -from blueapi.config import BasicAuthentication, StompConfig -from blueapi.utils import handle_all_exceptions, serialize - -from .base import DestinationProvider, MessageListener, MessagingTemplate -from .context import MessageContext -from .utils import determine_deserialization_type - -LOGGER = logging.getLogger(__name__) - -CORRELATION_ID_HEADER = "correlation-id" - - -class StompDestinationProvider(DestinationProvider): - """ - Destination provider for stomp, stateless so just - uses naming conventions - """ - - def queue(self, name: str) -> str: - return f"/queue/{name}" - - def topic(self, name: str) -> str: - return f"/topic/{name}" - - def temporary_queue(self, name: str) -> str: - return f"/temp-queue/{name}" - - default = queue - - -@dataclass -class StompReconnectPolicy: - """ - Details of how often stomp will try to reconnect if connection is unexpectedly lost - """ - - initial_delay: float = 0.0 - attempt_period: float = 10.0 - - -@dataclass -class Subscription: - """ - Details of a subscription, the template needs its own representation to - defer subscriptions until after connection - """ - - destination: str - callback: Callable[[Frame], None] - - -class StompMessagingTemplate(MessagingTemplate): - """ - MessagingTemplate that uses the stomp protocol, meant for use - with ActiveMQ. - """ - - _conn: stomp.Connection - _reconnect_policy: StompReconnectPolicy - _authentication: BasicAuthentication - _sub_num: itertools.count - _listener: stomp.ConnectionListener - _subscriptions: dict[str, Subscription] - _pending_subscriptions: set[str] - _disconnected: Event - - # Stateless implementation means attribute can be static - _destination_provider: DestinationProvider = StompDestinationProvider() - - def __init__( - self, - conn: stomp.Connection, - reconnect_policy: StompReconnectPolicy | None = None, - authentication: BasicAuthentication | None = None, - ) -> None: - self._conn = conn - self._reconnect_policy = reconnect_policy or StompReconnectPolicy() - self._authentication = authentication or BasicAuthentication() - - self._sub_num = itertools.count() - self._listener = stomp.ConnectionListener() - - self._listener.on_message = self._on_message - self._conn.set_listener("", self._listener) - - self._subscriptions = {} - - @classmethod - def autoconfigured(cls, config: StompConfig) -> MessagingTemplate: - return cls( - stomp.Connection( - [(config.host, config.port)], - auto_content_length=False, - ), - authentication=config.auth, - ) - - @property - def destinations(self) -> DestinationProvider: - return self._destination_provider - - def send( - self, - destination: str, - obj: Any, - on_reply: MessageListener | None = None, - correlation_id: str | None = None, - ) -> None: - self._send_str( - destination, - orjson.dumps(serialize(obj), option=orjson.OPT_SERIALIZE_NUMPY), - on_reply, - correlation_id, - ) - - def _send_str( - self, - destination: str, - message: bytes, - on_reply: MessageListener | None = None, - correlation_id: str | None = None, - ) -> None: - LOGGER.info(f"SENDING {message!r} to {destination}") - - headers: dict[str, Any] = {"JMSType": "TextMessage"} - if on_reply is not None: - reply_queue_name = self.destinations.temporary_queue(str(uuid.uuid1())) - headers = {**headers, "reply-to": reply_queue_name} - self.subscribe(reply_queue_name, on_reply) - if correlation_id: - headers = {**headers, CORRELATION_ID_HEADER: correlation_id} - self._conn.send(headers=headers, body=message, destination=destination) - - def subscribe(self, destination: str, callback: MessageListener) -> None: - LOGGER.debug(f"New subscription to {destination}") - obj_type = determine_deserialization_type(callback, default=str) - - def wrapper(frame: Frame) -> None: - as_dict = json.loads(frame.body) - value: Any = parse_obj_as(obj_type, as_dict) - - context = MessageContext( - frame.headers["destination"], - frame.headers.get("reply-to"), - frame.headers.get(CORRELATION_ID_HEADER), - ) - callback(context, value) - - sub_id = ( - destination - if destination.startswith("/temp-queue/") - else str(next(self._sub_num)) - ) - self._subscriptions[sub_id] = Subscription(destination, wrapper) - # If we're connected, subscribe immediately, otherwise the subscription is - # deferred until connection. - self._ensure_subscribed([sub_id]) - - def connect(self) -> None: - if self._conn.is_connected(): - return - - connected: Event = Event() - - def finished_connecting(_: Frame): - connected.set() - - self._listener.on_connected = finished_connecting - self._listener.on_disconnected = self._on_disconnected - - LOGGER.info("Connecting...") - - try: - self._conn.connect( - username=self._authentication.username, - passcode=self._authentication.passcode, - wait=True, - ) - connected.wait() - except ConnectFailedException as ex: - LOGGER.exception(msg="Failed to connect to message bus", exc_info=ex) - - self._ensure_subscribed() - - def _ensure_subscribed(self, sub_ids: list[str] | None = None) -> None: - # We must defer subscription until after connection, because stomp literally - # sends a SUB to the broker. But it still nice to be able to call subscribe - # on template before it connects, then just run the subscribes after connection. - if self._conn.is_connected(): - for sub_id in sub_ids or self._subscriptions.keys(): - sub = self._subscriptions[sub_id] - LOGGER.info(f"Subscribing to {sub.destination}") - self._conn.subscribe(destination=sub.destination, id=sub_id, ack="auto") - - def disconnect(self) -> None: - LOGGER.info("Disconnecting...") - if not self.is_connected(): - LOGGER.info("Already disconnected") - return - # We need to synchronise the disconnect on an event because the stomp Connection - # object doesn't do it for us - disconnected = Event() - self._listener.on_disconnected = disconnected.set - self._conn.disconnect() - disconnected.wait() - self._listener.on_disconnected = None - - @handle_all_exceptions - def _on_disconnected(self) -> None: - LOGGER.warn( - "Stomp connection lost, will attempt reconnection with " - f"policy {self._reconnect_policy}" - ) - time.sleep(self._reconnect_policy.initial_delay) - while not self._conn.is_connected(): - try: - self.connect() - except ConnectFailedException: - LOGGER.exception("Reconnect failed") - time.sleep(self._reconnect_policy.attempt_period) - - @handle_all_exceptions - def _on_message(self, frame: Frame) -> None: - LOGGER.info(f"Received {frame}") - sub_id = frame.headers.get("subscription") - sub = self._subscriptions.get(sub_id) - if sub is not None: - sub.callback(frame) - else: - LOGGER.warn(f"No subscription active for id: {sub_id}") - - def is_connected(self) -> bool: - return self._conn.is_connected() diff --git a/src/blueapi/messaging/utils.py b/src/blueapi/messaging/utils.py deleted file mode 100644 index 175005c65..000000000 --- a/src/blueapi/messaging/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import inspect - -from .base import MessageListener - - -def determine_deserialization_type( - listener: MessageListener, default: type = str -) -> type: - """ - Inspect a message listener function to determine the type to deserialize - a message to - - Args: - listener (MessageListener): The function that takes a deserialized message - default (Type, optional): If the type cannot be determined, what default - should we fall back on? Defaults to str. - - Returns: - Type: _description_ - """ - - _, message = inspect.signature(listener).parameters.values() - a_type = message.annotation - if a_type is not inspect.Parameter.empty: - return a_type - else: - return default diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 55cf6018f..72d546831 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -3,11 +3,12 @@ from functools import lru_cache from typing import Any +from bluesky_stomp.messaging import MessagingTemplate +from bluesky_stomp.models import Broker, DestinationBase, MessageTopic + from blueapi.config import ApplicationConfig from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream -from blueapi.messaging.base import MessagingTemplate -from blueapi.messaging.stomptemplate import StompMessagingTemplate from blueapi.service.model import DeviceModel, PlanModel, WorkerTask from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task @@ -51,10 +52,14 @@ def worker() -> TaskWorker: def messaging_template() -> MessagingTemplate | None: stomp_config = config().stomp if stomp_config is not None: - template = StompMessagingTemplate.autoconfigured(stomp_config) + template = MessagingTemplate.for_broker( + broker=Broker( + host=stomp_config.host, port=stomp_config.port, auth=stomp_config.auth + ) + ) task_worker = worker() - event_topic = template.destinations.topic("public.worker.event") + event_topic = MessageTopic(name="public.worker.event") _publish_event_streams( { @@ -90,15 +95,17 @@ def teardown() -> None: messaging_template.cache_clear() -def _publish_event_streams(streams_to_destinations: Mapping[EventStream, str]) -> None: +def _publish_event_streams( + streams_to_destinations: Mapping[EventStream, DestinationBase], +) -> None: for stream, destination in streams_to_destinations.items(): _publish_event_stream(stream, destination) -def _publish_event_stream(stream: EventStream, destination: str) -> None: +def _publish_event_stream(stream: EventStream, destination: DestinationBase) -> None: def forward_message(event: Any, correlation_id: str | None) -> None: if (template := messaging_template()) is not None: - template.send(destination, event, None, correlation_id) + template.send(destination, event, None, correlation_id=correlation_id) stream.subscribe(forward_message) @@ -157,7 +164,7 @@ def get_worker_state() -> WorkerState: def pause_worker(defer: bool | None) -> None: """Command the worker to pause""" - worker().pause(defer) + worker().pause(defer or False) def resume_worker() -> None: diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index 8a7ffb899..8193a9d48 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -128,7 +128,7 @@ class StateChangeRequest(BlueapiBaseModel): """ new_state: WorkerState = Field() - defer: bool | None = Field( + defer: bool = Field( description="Should worker defer Pausing until the next checkpoint", default=False, ) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 433e7115e..c2f23a626 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock, Mock, call import pytest +from bluesky_stomp.messaging import MessageContext from blueapi.client.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError from blueapi.core import DataEvent -from blueapi.messaging.context import MessageContext from blueapi.service.model import ( DeviceModel, DeviceResponse, @@ -307,9 +307,11 @@ def test_run_task_sets_up_control( ): mock_rest.create_task.return_value = TaskResponse(task_id="foo") mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + ctx = Mock() + ctx.correlation_id = "foo" + mock_events.subscribe_to_all_events = lambda on_event: on_event(COMPLETE_EVENT, ctx) client_with_events.run_task(Task(name="foo")) - mock_rest.create_task.assert_called_once_with(Task(name="foo")) mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="foo")) @@ -324,7 +326,7 @@ def test_run_task_fails_on_failing_event( ctx = Mock() ctx.correlation_id = "foo" - mock_events.subscribe_to_all_events = lambda on_event: on_event(ctx, FAILED_EVENT) + mock_events.subscribe_to_all_events = lambda on_event: on_event(FAILED_EVENT, ctx) on_event = Mock() with pytest.raises(BlueskyStreamingError): @@ -360,9 +362,9 @@ def test_run_task_calls_event_callback( ctx = Mock() ctx.correlation_id = "foo" - def callback(on_event: Callable[[MessageContext, AnyEvent], None]): - on_event(ctx, test_event) - on_event(ctx, COMPLETE_EVENT) + def callback(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(test_event, ctx) + on_event(COMPLETE_EVENT, ctx) mock_events.subscribe_to_all_events = callback # type: ignore @@ -399,9 +401,9 @@ def test_run_task_ignores_non_matching_events( ctx = Mock() ctx.correlation_id = "foo" - def callback(on_event: Callable[[MessageContext, AnyEvent], None]): - on_event(ctx, test_event) - on_event(ctx, COMPLETE_EVENT) + def callback(on_event: Callable[[AnyEvent, MessageContext], None]): + on_event(test_event, ctx) # type: ignore + on_event(COMPLETE_EVENT, ctx) mock_events.subscribe_to_all_events = callback diff --git a/tests/client/test_event_bus.py b/tests/client/test_event_bus.py index 45aab501b..ff5eb3101 100644 --- a/tests/client/test_event_bus.py +++ b/tests/client/test_event_bus.py @@ -1,9 +1,9 @@ from unittest.mock import ANY, Mock import pytest +from bluesky_stomp.messaging import MessagingTemplate from blueapi.client.event_bus import BlueskyStreamingError, EventBusClient -from blueapi.messaging import MessagingTemplate @pytest.fixture diff --git a/tests/conftest.py b/tests/conftest.py index ed2caf1a8..838d4b219 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,27 +6,6 @@ from bluesky.run_engine import TransitionError -def pytest_addoption(parser): - parser.addoption( - "--skip-stomp", - action="store_true", - default=False, - help="skip stomp tests (e.g. because a server is unavailable)", - ) - - -def pytest_configure(config): - config.addinivalue_line("markers", "stomp: mark test as requiring stomp broker") - - -def pytest_collection_modifyitems(config, items): - if config.getoption("--skip-stomp"): - skip_stomp = pytest.mark.skip(reason="skipping stomp tests at user request") - for item in items: - if "stomp" in item.keywords: - item.add_marker(skip_stomp) - - @pytest.fixture(scope="function") def RE(request): loop = asyncio.new_event_loop() diff --git a/tests/messaging/__init__.py b/tests/messaging/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/messaging/test_stomptemplate.py b/tests/messaging/test_stomptemplate.py deleted file mode 100644 index a8c9750b3..000000000 --- a/tests/messaging/test_stomptemplate.py +++ /dev/null @@ -1,257 +0,0 @@ -import itertools -from collections.abc import Iterable -from concurrent.futures import Future -from queue import Queue -from typing import Any -from unittest.mock import ANY, MagicMock, call, patch - -import numpy as np -import pytest -from pydantic import BaseModel, Field -from pydantic_settings import BaseSettings -from stomp import Connection -from stomp.exception import ConnectFailedException, NotConnectedException - -from blueapi.config import StompConfig -from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate - -_TIMEOUT: float = 10.0 -_COUNT = itertools.count() - - -class StompTestingSettings(BaseSettings): - blueapi_test_stomp_ports: list[int] = Field(default=[61613]) - - def test_stomp_configs(self) -> Iterable[StompConfig]: - for port in self.blueapi_test_stomp_ports: - yield StompConfig(port=port) - - -@pytest.fixture(params=StompTestingSettings().test_stomp_configs()) -def disconnected_template(request: pytest.FixtureRequest) -> MessagingTemplate: - stomp_config = request.param - template = StompMessagingTemplate.autoconfigured(stomp_config) - assert template is not None - return template - - -@pytest.fixture(params=StompTestingSettings().test_stomp_configs()) -def template(request: pytest.FixtureRequest) -> Iterable[MessagingTemplate]: - stomp_config = request.param - template = StompMessagingTemplate.autoconfigured(stomp_config) - assert template is not None - template.connect() - yield template - template.disconnect() - - -@pytest.fixture -def test_queue(template: MessagingTemplate) -> str: - return template.destinations.queue(f"test-{next(_COUNT)}") - - -@pytest.fixture -def test_queue_2(template: MessagingTemplate) -> str: - return template.destinations.queue(f"test-{next(_COUNT)}") - - -@pytest.fixture -def test_topic(template: MessagingTemplate) -> str: - return template.destinations.topic(f"test-{next(_COUNT)}") - - -def test_disconnected_error(template: MessagingTemplate, test_queue: str) -> None: - acknowledge(template, test_queue) - - f: Future = Future() - - def callback(ctx: MessageContext, message: str) -> None: - f.set_result(message) - - if template.is_connected(): - template.disconnect() - with pytest.raises(NotConnectedException): - template.send(test_queue, "test_message", callback) - - with patch( - "blueapi.messaging.stomptemplate.LOGGER.info", autospec=True - ) as mock_logger: - template.disconnect() - assert not template.is_connected() - expected_calls = [ - call("Disconnecting..."), - call("Already disconnected"), - ] - mock_logger.assert_has_calls(expected_calls) - - -@pytest.mark.stomp -def test_send(template: MessagingTemplate, test_queue: str) -> None: - f: Future = Future() - - def callback(ctx: MessageContext, message: str) -> None: - f.set_result(message) - - template.subscribe(test_queue, callback) - template.send(test_queue, "test_message") - assert f.result(timeout=_TIMEOUT) - - -@pytest.mark.stomp -def test_send_to_topic(template: MessagingTemplate, test_topic: str) -> None: - f: Future = Future() - - def callback(ctx: MessageContext, message: str) -> None: - f.set_result(message) - - template.subscribe(test_topic, callback) - template.send(test_topic, "test_message") - assert f.result(timeout=_TIMEOUT) - - -@pytest.mark.stomp -def test_send_on_reply(template: MessagingTemplate, test_queue: str) -> None: - acknowledge(template, test_queue) - - f: Future = Future() - - def callback(ctx: MessageContext, message: str) -> None: - f.set_result(message) - - template.send(test_queue, "test_message", callback) - assert f.result(timeout=_TIMEOUT) - - -@pytest.mark.stomp -def test_send_and_receive(template: MessagingTemplate, test_queue: str) -> None: - acknowledge(template, test_queue) - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) - assert reply == "ack" - - -@pytest.mark.stomp -def test_listener(template: MessagingTemplate, test_queue: str) -> None: - @template.listener(test_queue) - def server(ctx: MessageContext, message: str) -> None: - reply_queue = ctx.reply_destination - if reply_queue is None: - raise RuntimeError("reply queue is None") - template.send(reply_queue, "ack", correlation_id=ctx.correlation_id) - - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) - assert reply == "ack" - - -class Foo(BaseModel): - a: int - b: str - - -@pytest.mark.stomp -@pytest.mark.parametrize( - "message,message_type", - [ - ("test", str), - (1, int), - (Foo(a=1, b="test"), Foo), - (np.array([1, 2, 3]), list), - ], -) -def test_deserialization( - template: MessagingTemplate, test_queue: str, message: Any, message_type: type -) -> None: - def server(ctx: MessageContext, message: message_type) -> None: # type: ignore - reply_queue = ctx.reply_destination - if reply_queue is None: - raise RuntimeError("reply queue is None") - template.send(reply_queue, message, correlation_id=ctx.correlation_id) - - template.subscribe(test_queue, server) - reply = template.send_and_receive(test_queue, message, message_type).result( - timeout=_TIMEOUT - ) - if type(message) is np.ndarray: - message = message.tolist() - assert reply == message - - -@pytest.mark.stomp -def test_subscribe_before_connect( - disconnected_template: MessagingTemplate, test_queue: str -) -> None: - acknowledge(disconnected_template, test_queue) - disconnected_template.connect() - reply = disconnected_template.send_and_receive(test_queue, "test", str).result( - timeout=_TIMEOUT - ) - assert reply == "ack" - - -@pytest.mark.stomp -def test_reconnect(template: MessagingTemplate, test_queue: str) -> None: - acknowledge(template, test_queue) - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) - assert reply == "ack" - template.disconnect() - assert not template.is_connected() - template.connect() - assert template.is_connected() - reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT) - assert reply == "ack" - - -@pytest.fixture() -def failing_template() -> MessagingTemplate: - def connection_exception(*args, **kwargs): - raise ConnectFailedException - - connection = Connection() - connection.connect = MagicMock(side_effect=connection_exception) - return StompMessagingTemplate(connection) - - -@pytest.mark.stomp -def test_failed_connect(failing_template: MessagingTemplate, test_queue: str) -> None: - assert not failing_template.is_connected() - with patch( - "blueapi.messaging.stomptemplate.LOGGER.error", autospec=True - ) as mock_logger: - failing_template.connect() - assert not failing_template.is_connected() - mock_logger.assert_called_once_with( - "Failed to connect to message bus", exc_info=ANY - ) - - -@pytest.mark.stomp -def test_correlation_id( - template: MessagingTemplate, test_queue: str, test_queue_2: str -) -> None: - correlation_id = "foobar" - q: Queue = Queue() - - def server(ctx: MessageContext, msg: str) -> None: - q.put(ctx) - template.send(test_queue_2, msg, correlation_id=ctx.correlation_id) - - def client(ctx: MessageContext, msg: str) -> None: - q.put(ctx) - - template.subscribe(test_queue, server) - template.subscribe(test_queue_2, client) - template.send(test_queue, "test", None, correlation_id) - - ctx_req: MessageContext = q.get(timeout=_TIMEOUT) - assert ctx_req.correlation_id == correlation_id - ctx_ack: MessageContext = q.get(timeout=_TIMEOUT) - assert ctx_ack.correlation_id == correlation_id - - -def acknowledge(template: MessagingTemplate, destination: str) -> None: - def server(ctx: MessageContext, message: str) -> None: - reply_queue = ctx.reply_destination - if reply_queue is None: - raise RuntimeError("reply queue is None") - template.send(reply_queue, "ack", correlation_id=ctx.correlation_id) - - template.subscribe(destination, server) diff --git a/tests/messaging/test_utils.py b/tests/messaging/test_utils.py deleted file mode 100644 index feafd2afa..000000000 --- a/tests/messaging/test_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -from collections.abc import Mapping -from dataclasses import dataclass -from typing import Any - -import pytest - -from blueapi.messaging.utils import determine_deserialization_type - - -@dataclass -class Foo: - bar: int - baz: str - - -def test_determine_deserialization_type() -> None: - def on_message(headers: Mapping[str, Any], message: Foo) -> None: ... - - deserialization_type = determine_deserialization_type(on_message) # type: ignore - assert deserialization_type is Foo - - -def test_determine_deserialization_type_with_no_type() -> None: - def on_message(headers: Mapping[str, Any], message) -> None: ... - - deserialization_type = determine_deserialization_type(on_message) # type: ignore - assert deserialization_type is str - - -def test_determine_deserialization_type_with_wrong_signature() -> None: - def on_message(message: Foo) -> None: ... - - with pytest.raises(ValueError): - determine_deserialization_type(on_message) # type: ignore diff --git a/tests/service/test_interface.py b/tests/service/test_interface.py index 6a08eafc2..002d2d308 100644 --- a/tests/service/test_interface.py +++ b/tests/service/test_interface.py @@ -1,9 +1,11 @@ import uuid from dataclasses import dataclass -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from bluesky_stomp.messaging import MessagingTemplate from ophyd.sim import SynAxis +from stomp.connect import StompConnection11 as Connection from blueapi.config import ApplicationConfig, StompConfig from blueapi.core import MsgGenerator @@ -15,6 +17,18 @@ from blueapi.worker.task_worker import TrackableTask +@pytest.fixture +def mock_connection() -> Mock: + return Mock(spec=Connection) + + +@pytest.fixture +def template(mock_connection: Mock) -> MessagingTemplate: + template = MessagingTemplate(conn=mock_connection) + template.disconnect = MagicMock() + return template + + @pytest.fixture(autouse=True) def ensure_worker_stopped(): """This saves every test having to call this at the end. @@ -269,7 +283,9 @@ def test_get_task_by_id(context_mock: MagicMock): ) -@pytest.mark.stomp -def test_stomp_config(): - interface.set_config(ApplicationConfig(stomp=StompConfig())) - assert interface.messaging_template() is not None +def test_stomp_config(template: MessagingTemplate): + with patch( + "blueapi.service.interface.MessagingTemplate.for_broker", return_value=template + ): + interface.set_config(ApplicationConfig(stomp=StompConfig())) + assert interface.messaging_template() is not None diff --git a/tests/test_cli.py b/tests/test_cli.py index ab725c005..8c75cced7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -8,10 +8,12 @@ import pytest import responses +from bluesky_stomp.messaging import MessagingTemplate from click.testing import CliRunner from pydantic import BaseModel, ValidationError from requests.exceptions import ConnectionError from responses import matchers +from stomp.connect import StompConnection11 as Connection from blueapi import __version__ from blueapi.cli.cli import main @@ -28,6 +30,16 @@ ) +@pytest.fixture +def mock_connection() -> Mock: + return Mock(spec=Connection) + + +@pytest.fixture +def template(mock_connection: Mock) -> MessagingTemplate: + return MessagingTemplate(conn=mock_connection) + + @pytest.fixture def runner(): return CliRunner() @@ -136,8 +148,13 @@ def test_cannot_run_plans_without_stomp_config(runner: CliRunner): ) -@pytest.mark.stomp -def test_valid_stomp_config_for_listener(runner: CliRunner): +@patch("blueapi.cli.cli.MessagingTemplate") +def test_valid_stomp_config_for_listener( + template: MessagingTemplate, + runner: CliRunner, + mock_connection: Mock, +): + mock_connection.is_connected.return_value = True result = runner.invoke( main, [