From 0a634dc34b2e8683d449a3faba9b146431de1a77 Mon Sep 17 00:00:00 2001 From: DiamondJoseph <53935796+DiamondJoseph@users.noreply.github.com> Date: Mon, 29 Apr 2024 15:59:24 +0100 Subject: [PATCH 1/7] Remove DirectoryProvider and pre-processor handling to Dodal (#376) --- pyproject.toml | 2 +- src/blueapi/config.py | 7 - src/blueapi/core/context.py | 14 +- src/blueapi/data_management/__init__.py | 0 .../visit_directory_provider.py | 127 ------ src/blueapi/preprocessors/attach_metadata.py | 41 -- src/blueapi/service/handler.py | 40 -- src/blueapi/worker/task.py | 4 +- .../test_visit_directory_provider.py | 66 --- tests/preprocessors/__init__.py | 0 tests/preprocessors/test_attach_metadata.py | 399 ------------------ 11 files changed, 3 insertions(+), 697 deletions(-) delete mode 100644 src/blueapi/data_management/__init__.py delete mode 100644 src/blueapi/data_management/visit_directory_provider.py delete mode 100644 src/blueapi/preprocessors/attach_metadata.py delete mode 100644 tests/data_management/test_visit_directory_provider.py delete mode 100644 tests/preprocessors/__init__.py delete mode 100644 tests/preprocessors/test_attach_metadata.py diff --git a/pyproject.toml b/pyproject.toml index cc0f65aeb..393a0f459 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "uvicorn", "requests", "dls-bluesky-core", #requires ophyd-async - "dls-dodal<1.21", + "dls-dodal", ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/config.py b/src/blueapi/config.py index d4376fb4e..4196b8123 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -51,12 +51,6 @@ class StompConfig(BaseModel): auth: BasicAuthentication | None = None -class DataWritingConfig(BlueapiBaseModel): - visit_service_url: str | None = None # e.g. 
"http://localhost:8088/api" - visit_directory: Path = Path("/tmp/0-0") - group_name: str = "example" - - class WorkerEventConfig(BlueapiBaseModel): """ Config for event broadcasting via the message bus @@ -78,7 +72,6 @@ class EnvironmentConfig(BlueapiBaseModel): Source(kind=SourceKind.PLAN_FUNCTIONS, module="dls_bluesky_core.plans"), Source(kind=SourceKind.PLAN_FUNCTIONS, module="dls_bluesky_core.stubs"), ] - data_writing: DataWritingConfig = Field(default_factory=DataWritingConfig) events: WorkerEventConfig = Field(default_factory=WorkerEventConfig) diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index a6d5ba2ff..98493a913 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -1,6 +1,5 @@ -import functools import logging -from collections.abc import Callable, Sequence +from collections.abc import Callable from dataclasses import dataclass, field from importlib import import_module from inspect import Parameter, signature @@ -22,10 +21,8 @@ BLUESKY_PROTOCOLS, Device, HasName, - MsgGenerator, Plan, PlanGenerator, - PlanWrapper, is_bluesky_compatible_device, is_bluesky_plan_generator, ) @@ -45,7 +42,6 @@ class BlueskyContext: run_engine: RunEngine = field( default_factory=lambda: RunEngine(context_managers=[]) ) - plan_wrappers: Sequence[PlanWrapper] = field(default_factory=list) plans: dict[str, Plan] = field(default_factory=dict) devices: dict[str, Device] = field(default_factory=dict) plan_functions: dict[str, PlanGenerator] = field(default_factory=dict) @@ -53,14 +49,6 @@ class BlueskyContext: _reference_cache: dict[type, type] = field(default_factory=dict) - def wrap(self, plan: MsgGenerator) -> MsgGenerator: - wrapped_plan = functools.reduce( - lambda wrapped, next_wrapper: next_wrapper(wrapped), - self.plan_wrappers, - plan, - ) - yield from wrapped_plan - def find_device(self, addr: str | list[str]) -> Device | None: """ Find a device in this context, allows for recursive search. 
diff --git a/src/blueapi/data_management/__init__.py b/src/blueapi/data_management/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/blueapi/data_management/visit_directory_provider.py b/src/blueapi/data_management/visit_directory_provider.py deleted file mode 100644 index bc05040b7..000000000 --- a/src/blueapi/data_management/visit_directory_provider.py +++ /dev/null @@ -1,127 +0,0 @@ -import logging -from abc import ABC, abstractmethod -from pathlib import Path - -from aiohttp import ClientSession -from ophyd_async.core import DirectoryInfo, DirectoryProvider -from pydantic import BaseModel - - -class DataCollectionIdentifier(BaseModel): - collectionNumber: int - - -class VisitServiceClientBase(ABC): - """ - Object responsible for I/O in determining collection number - """ - - @abstractmethod - async def create_new_collection(self) -> DataCollectionIdentifier: - """Create new collection""" - - @abstractmethod - async def get_current_collection(self) -> DataCollectionIdentifier: - """Get current collection""" - - -class VisitServiceClient(VisitServiceClientBase): - _url: str - - def __init__(self, url: str) -> None: - self._url = url - - async def create_new_collection(self) -> DataCollectionIdentifier: - async with ClientSession() as session: - async with session.post(f"{self._url}/numtracker") as response: - if response.status == 200: - json = await response.json() - return DataCollectionIdentifier.parse_obj(json) - else: - raise Exception(response.status) - - async def get_current_collection(self) -> DataCollectionIdentifier: - async with ClientSession() as session: - async with session.get(f"{self._url}/numtracker") as response: - if response.status == 200: - json = await response.json() - return DataCollectionIdentifier.parse_obj(json) - else: - raise Exception(response.status) - - -class LocalVisitServiceClient(VisitServiceClientBase): - _count: int - - def __init__(self) -> None: - self._count = 0 - - async def create_new_collection(self) -> DataCollectionIdentifier: - self._count += 1 - return DataCollectionIdentifier(collectionNumber=self._count) - - async def get_current_collection(self) -> DataCollectionIdentifier: - return DataCollectionIdentifier(collectionNumber=self._count) - - -class VisitDirectoryProvider(DirectoryProvider): - """ - Gets information from a remote service to construct the path that detectors - should write to, and determine how their files should be named. - """ - - _data_group_name: str - _data_directory: Path - - _client: VisitServiceClientBase - _current_collection: DirectoryInfo | None - _session: ClientSession | None - - def __init__( - self, - data_group_name: str, - data_directory: Path, - client: VisitServiceClientBase, - ): - self._data_group_name = data_group_name - self._data_directory = data_directory - self._client = client - - self._current_collection = None - self._session = None - - async def update(self) -> None: - """ - Calls the visit service to create a new data collection in the current visit. 
- """ - # TODO: After visit service is more feature complete: - # TODO: Allow selecting visit as part of the request to BlueAPI - # TODO: Consume visit information from BlueAPI and pass down to this class - # TODO: Query visit service to get information about visit and data collection - # TODO: Use AuthN information as part of verification with visit service - - try: - collection_id_info = await self._client.create_new_collection() - self._current_collection = self._generate_directory_info(collection_id_info) - except Exception as ex: - # TODO: The catch all is needed because the RunEngine will not - # currently handle it, see - # https://github.com/bluesky/bluesky/pull/1623 - self._current_collection = None - logging.exception(ex) - - def _generate_directory_info( - self, - collection_id_info: DataCollectionIdentifier, - ) -> DirectoryInfo: - collection_id = collection_id_info.collectionNumber - file_prefix = f"{self._data_group_name}-{collection_id}" - return DirectoryInfo(str(self._data_directory), file_prefix) - - def __call__(self) -> DirectoryInfo: - if self._current_collection is not None: - return self._current_collection - else: - raise ValueError( - "No current collection, update() needs to be called at least once" - ) diff --git a/src/blueapi/preprocessors/attach_metadata.py b/src/blueapi/preprocessors/attach_metadata.py deleted file mode 100644 index 21d9ed8b4..000000000 --- a/src/blueapi/preprocessors/attach_metadata.py +++ /dev/null @@ -1,41 +0,0 @@ -import bluesky.plan_stubs as bps -import bluesky.preprocessors as bpp -from bluesky.utils import make_decorator - -from blueapi.core import MsgGenerator -from blueapi.data_management.visit_directory_provider import VisitDirectoryProvider - -DATA_SESSION = "data_session" -DATA_GROUPS = "data_groups" - - -def attach_metadata( - plan: MsgGenerator, - provider: VisitDirectoryProvider, -) -> MsgGenerator: - """ - Attach data session metadata to the runs within a plan and make it correlate - with an ophyd-async DirectoryProvider. - - This updates the directory provider (which in turn makes a call to to a service - to figure out which scan number we are using for such a scan), and ensures the - start document contains the correct data session. - - Args: - plan: The plan to preprocess - provider: The directory provider that participating detectors are aware of. 
- - Returns: - MsgGenerator: A plan - - Yields: - Iterator[Msg]: Plan messages - """ - yield from bps.wait_for([provider.update]) - directory_info = provider() - yield from bpp.inject_md_wrapper( - plan, md={DATA_SESSION: directory_info.filename_prefix} - ) - - -attach_metadata_decorator = make_decorator(attach_metadata) diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index 0280c5d3d..afa9818c1 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -5,15 +5,8 @@ from blueapi.config import ApplicationConfig from blueapi.core import BlueskyContext from blueapi.core.event import EventStream -from blueapi.data_management.visit_directory_provider import ( - LocalVisitServiceClient, - VisitDirectoryProvider, - VisitServiceClient, - VisitServiceClientBase, -) from blueapi.messaging import StompMessagingTemplate from blueapi.messaging.base import MessagingTemplate -from blueapi.preprocessors.attach_metadata import attach_metadata from blueapi.service.handler_base import BlueskyHandler from blueapi.service.model import DeviceModel, PlanModel, WorkerTask from blueapi.worker.event import WorkerState @@ -159,42 +152,9 @@ def setup_handler( ) -> None: global HANDLER - provider = None - plan_wrappers = [] - if config: - visit_service_client: VisitServiceClientBase - if config.env.data_writing.visit_service_url is not None: - visit_service_client = VisitServiceClient( - config.env.data_writing.visit_service_url - ) - else: - visit_service_client = LocalVisitServiceClient() - - provider = VisitDirectoryProvider( - data_group_name=config.env.data_writing.group_name, - data_directory=config.env.data_writing.visit_directory, - client=visit_service_client, - ) - - # Make all dodal devices created by the context use provider if they can - try: - from dodal.parameters.gda_directory_provider import ( - set_directory_provider_singleton, - ) - - set_directory_provider_singleton(provider) - except ImportError: - logging.error( - "Unable to set directory provider for ophyd-async devices, " - "a newer version of dodal is required" - ) - - plan_wrappers.append(lambda plan: attach_metadata(plan, provider)) - handler = Handler( config, context=BlueskyContext( - plan_wrappers=plan_wrappers, sim=False, ), ) diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index f080be875..1e48cb4d8 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -28,9 +28,7 @@ def do_task(self, ctx: BlueskyContext) -> None: func = ctx.plan_functions[self.name] prepared_params = self.prepare_params(ctx) - plan_generator = func(**prepared_params.dict()) - wrapped_plan_generator = ctx.wrap(plan_generator) - ctx.run_engine(wrapped_plan_generator) + ctx.run_engine(func(**prepared_params.dict())) def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel: diff --git a/tests/data_management/test_visit_directory_provider.py b/tests/data_management/test_visit_directory_provider.py deleted file mode 100644 index 57d93d0ef..000000000 --- a/tests/data_management/test_visit_directory_provider.py +++ /dev/null @@ -1,66 +0,0 @@ -from pathlib import Path - -import pytest -from ophyd_async.core import DirectoryInfo - -from blueapi.data_management.visit_directory_provider import ( - DataCollectionIdentifier, - LocalVisitServiceClient, - VisitDirectoryProvider, - VisitServiceClientBase, -) - - -@pytest.fixture -def visit_service_client() -> VisitServiceClientBase: - return LocalVisitServiceClient() - - -@pytest.fixture -def visit_directory_provider( - 
visit_service_client: VisitServiceClientBase, -) -> VisitDirectoryProvider: - return VisitDirectoryProvider("example", Path("/tmp"), visit_service_client) - - -@pytest.mark.asyncio -async def test_client_can_view_collection( - visit_service_client: VisitServiceClientBase, -) -> None: - collection = await visit_service_client.get_current_collection() - assert collection == DataCollectionIdentifier(collectionNumber=0) - - -@pytest.mark.asyncio -async def test_client_can_create_collection( - visit_service_client: VisitServiceClientBase, -) -> None: - collection = await visit_service_client.create_new_collection() - assert collection == DataCollectionIdentifier(collectionNumber=1) - - -@pytest.mark.asyncio -async def test_update_sets_collection_number( - visit_directory_provider: VisitDirectoryProvider, -) -> None: - await visit_directory_provider.update() - assert visit_directory_provider() == DirectoryInfo( - directory_path="/tmp", - filename_prefix="example-1", - ) - - -@pytest.mark.asyncio -async def test_update_sets_collection_number_multi( - visit_directory_provider: VisitDirectoryProvider, -) -> None: - await visit_directory_provider.update() - assert visit_directory_provider() == DirectoryInfo( - directory_path="/tmp", - filename_prefix="example-1", - ) - await visit_directory_provider.update() - assert visit_directory_provider() == DirectoryInfo( - directory_path="/tmp", - filename_prefix="example-2", - ) diff --git a/tests/preprocessors/__init__.py b/tests/preprocessors/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/preprocessors/test_attach_metadata.py b/tests/preprocessors/test_attach_metadata.py deleted file mode 100644 index 8f879fc26..000000000 --- a/tests/preprocessors/test_attach_metadata.py +++ /dev/null @@ -1,399 +0,0 @@ -from collections.abc import Callable, Mapping -from pathlib import Path -from typing import Any - -import bluesky.plan_stubs as bps -import bluesky.plans as bp -import pytest -from bluesky import RunEngine -from bluesky.preprocessors import ( - run_decorator, - run_wrapper, - set_run_key_decorator, - set_run_key_wrapper, - stage_wrapper, -) -from bluesky.protocols import HasName, Readable, Reading, Status, Triggerable -from event_model.documents.event_descriptor import DataKey -from ophyd.status import StatusBase -from ophyd_async.core import DirectoryProvider - -from blueapi.core import DataEvent, MsgGenerator -from blueapi.data_management.visit_directory_provider import ( - DataCollectionIdentifier, - VisitDirectoryProvider, - VisitServiceClient, -) -from blueapi.preprocessors.attach_metadata import DATA_SESSION, attach_metadata - -DATA_DIRECTORY = Path("/tmp") -DATA_GROUP_NAME = "test" - - -RUN_0 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-0" -RUN_1 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-1" -RUN_2 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-2" - - -class MockVisitServiceClient(VisitServiceClient): - _count: int - _fail: bool - - def __init__(self) -> None: - super().__init__("http://example.com") - self._count = 0 - self._fail = False - - def always_fail(self) -> None: - self._fail = True - - async def create_new_collection(self) -> DataCollectionIdentifier: - if self._fail: - raise ConnectionError() - - count = self._count - self._count += 1 - return DataCollectionIdentifier(collectionNumber=count) - - async def get_current_collection(self) -> DataCollectionIdentifier: - if self._fail: - raise ConnectionError() - - return DataCollectionIdentifier(collectionNumber=self._count) - - -@pytest.fixture -def client() -> 
VisitServiceClient: - return MockVisitServiceClient() - - -@pytest.fixture -def provider(client: VisitServiceClient) -> VisitDirectoryProvider: - return VisitDirectoryProvider( - data_directory=DATA_DIRECTORY, - data_group_name=DATA_GROUP_NAME, - client=client, - ) - - -@pytest.fixture -def run_engine() -> RunEngine: - return RunEngine() - - -class FakeDetector(Readable, HasName, Triggerable): - _name: str - _provider: DirectoryProvider - - def __init__( - self, - name: str, - provider: DirectoryProvider, - ) -> None: - self._name = name - self._provider = provider - - async def read(self) -> dict[str, Reading]: - return { - f"{self.name}_data": { - "value": "test", - "timestamp": 0.0, - }, - } - - async def describe(self) -> dict[str, DataKey]: - directory_info = self._provider() - path = f"{directory_info.directory_path}/{directory_info.filename_prefix}" - return { - f"{self.name}_data": { - "dtype": "string", - "shape": [1], - "source": path, - } - } - - def trigger(self) -> Status: - status = StatusBase() - status.set_finished() - return status - - @property - def name(self) -> str: - return self._name - - @property - def parent(self) -> None: - return None - - -@pytest.fixture(params=[1, 2]) -def detectors(request, provider: VisitDirectoryProvider) -> list[Readable]: - number_of_detectors = request.param - return [ - FakeDetector( - name=f"test_detector_{i}", - provider=provider, - ) - for i in range(number_of_detectors) - ] - - -def simple_run(detectors: list[Readable]) -> MsgGenerator: - yield from bp.count(detectors) - - -def multi_run(detectors: list[Readable]) -> MsgGenerator: - yield from bp.count(detectors) - yield from bp.count(detectors) - - -def multi_nested_plan(detectors: list[Readable]) -> MsgGenerator: - yield from simple_run(detectors) - yield from simple_run(detectors) - - -def multi_run_single_stage(detectors: list[Readable]) -> MsgGenerator: - def stageless_count() -> MsgGenerator: - return (yield from bps.one_shot(detectors)) - - def inner_plan() -> MsgGenerator: - yield from run_wrapper(stageless_count()) - yield from run_wrapper(stageless_count()) - - yield from stage_wrapper(inner_plan(), detectors) - - -def multi_run_single_stage_multi_group( - detectors: list[Readable], -) -> MsgGenerator: - def stageless_count() -> MsgGenerator: - return (yield from bps.one_shot(detectors)) - - def inner_plan() -> MsgGenerator: - yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1}) - yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1}) - yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2}) - yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2}) - - yield from stage_wrapper(inner_plan(), detectors) - - -@run_decorator(md={DATA_SESSION: 12345}) -@set_run_key_decorator("outer") -def nested_run_with_metadata(detectors: list[Readable]) -> MsgGenerator: - yield from set_run_key_wrapper(bp.count(detectors), "inner") - yield from set_run_key_wrapper(bp.count(detectors), "inner") - - -@run_decorator() -@set_run_key_decorator("outer") -def nested_run_without_metadata( - detectors: list[Readable], -) -> MsgGenerator: - yield from set_run_key_wrapper(bp.count(detectors), "inner") - yield from set_run_key_wrapper(bp.count(detectors), "inner") - - -def test_simple_run_gets_scan_number( - run_engine: RunEngine, - detectors: list[Readable], - provider: DirectoryProvider, -) -> None: - docs = collect_docs( - run_engine, - simple_run(detectors), - provider, - ) - assert docs[0].name == "start" - assert docs[0].doc[DATA_SESSION] == 
f"{DATA_GROUP_NAME}-0" - assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0]) - - -@pytest.mark.parametrize("plan", [multi_run, multi_nested_plan]) -def test_multi_run_gets_scan_numbers( - run_engine: RunEngine, - detectors: list[Readable], - plan: Callable[[list[Readable]], MsgGenerator], - provider: DirectoryProvider, -) -> None: - """Test is here to demonstrate that multi run plans will overwrite files.""" - docs = collect_docs( - run_engine, - plan(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 2 - assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_0]) - - -def test_multi_run_single_stage( - run_engine: RunEngine, - detectors: list[Readable], - provider: DirectoryProvider, -) -> None: - docs = collect_docs( - run_engine, - multi_run_single_stage(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 2 - assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert_all_detectors_used_collection_numbers( - docs, - detectors, - [ - RUN_0, - RUN_0, - ], - ) - - -def test_multi_run_single_stage_multi_group( - run_engine: RunEngine, - detectors: list[Readable], - provider: DirectoryProvider, -) -> None: - docs = collect_docs( - run_engine, - multi_run_single_stage_multi_group(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 4 - assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[3].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert_all_detectors_used_collection_numbers( - docs, - detectors, - [ - RUN_0, - RUN_0, - RUN_0, - RUN_0, - ], - ) - - -def test_nested_run_with_metadata( - run_engine: RunEngine, - detectors: list[Readable], - provider: DirectoryProvider, -) -> None: - """Test is here to demonstrate that nested runs will be treated as a single run. - - That means detectors in such runs will overwrite files. - """ - docs = collect_docs( - run_engine, - nested_run_with_metadata(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 3 - assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_0]) - - -def test_nested_run_without_metadata( - run_engine: RunEngine, - detectors: list[Readable], - provider: DirectoryProvider, -) -> None: - """Test is here to demonstrate that nested runs will be treated as a single run. - - That means detectors in such runs will overwrite files. 
- """ - docs = collect_docs( - run_engine, - nested_run_without_metadata(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 3 - assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" - assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_0]) - - -def test_visit_directory_provider_fails( - run_engine: RunEngine, - detectors: list[Readable], - provider: DirectoryProvider, - client: MockVisitServiceClient, -) -> None: - client.always_fail() - with pytest.raises(ValueError): - collect_docs( - run_engine, - simple_run(detectors), - provider, - ) - - -def test_visit_directory_provider_fails_after_one_sucess( - run_engine: RunEngine, - detectors: list[Readable], - provider: DirectoryProvider, - client: MockVisitServiceClient, -) -> None: - collect_docs( - run_engine, - simple_run(detectors), - provider, - ) - client.always_fail() - with pytest.raises(ValueError): - collect_docs( - run_engine, - simple_run(detectors), - provider, - ) - - -def collect_docs( - run_engine: RunEngine, - plan: MsgGenerator, - provider: DirectoryProvider, -) -> list[DataEvent]: - events = [] - - def on_event(name: str, doc: Mapping[str, Any]) -> None: - events.append(DataEvent(name=name, doc=doc)) - - wrapped_plan = attach_metadata(plan, provider) - run_engine(wrapped_plan, on_event) - return events - - -def assert_all_detectors_used_collection_numbers( - docs: list[DataEvent], - detectors: list[Readable], - source_history: list[Path], -) -> None: - descriptors = find_descriptor_docs(docs) - assert len(descriptors) == len(source_history) - - for descriptor, expected_source in zip(descriptors, source_history, strict=False): - for detector in detectors: - source = descriptor.doc.get("data_keys", {}).get(f"{detector.name}_data")[ - "source" - ] - assert Path(source) == expected_source - - -def find_start_docs(docs: list[DataEvent]) -> list[DataEvent]: - return list(filter(lambda event: event.name == "start", docs)) - - -def find_descriptor_docs(docs: list[DataEvent]) -> list[DataEvent]: - return list(filter(lambda event: event.name == "descriptor", docs)) From 25e90280951f28ebcf27681fcef144605856525b Mon Sep 17 00:00:00 2001 From: Joe Shannon Date: Wed, 1 May 2024 17:09:43 +0100 Subject: [PATCH 2/7] Install git in runtime container (#446) This is required for installing python packages at runtime. It was previously present in the non-slim python container used before but this was accidentally removed with the switch to the copier template. Fixes #445. --- Dockerfile | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Dockerfile b/Dockerfile index 94b4c3c67..2897aff71 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,6 +21,10 @@ RUN pip install . # The runtime stage copies the built venv into a slim runtime container FROM python:${PYTHON_VERSION}-slim as runtime # Add apt-get system dependecies for runtime here if needed +RUN apt-get update && apt-get install -y --no-install-recommends \ + # Git required for installing packages at runtime + git \ + && rm -rf /var/lib/apt/lists/* COPY --from=build /venv/ /venv/ COPY ./container-startup.sh /container-startup.sh ENV PATH=/venv/bin:$PATH From 9964a6527194567d188ba3875706230b3655c051 Mon Sep 17 00:00:00 2001 From: Joe Shannon Date: Thu, 2 May 2024 14:52:24 +0100 Subject: [PATCH 3/7] Auto restart deployment on config change (#450) Add new restartOnConfigChange property. 
This will cause the deployment to be restarted if the config is
changed, e.g. if the set of plans or device modules is updated, when
running helm upgrade.

Use a variation of the approach at:
https://helm.sh/docs/howto/charts_tips_and_tricks/#automatically-roll-deployments

Fixes #451.
---
 helm/blueapi/templates/deployment.yaml | 5 ++++-
 helm/blueapi/values.yaml               | 2 ++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/helm/blueapi/templates/deployment.yaml b/helm/blueapi/templates/deployment.yaml
index fdcbb7342..c211c26c6 100644
--- a/helm/blueapi/templates/deployment.yaml
+++ b/helm/blueapi/templates/deployment.yaml
@@ -11,8 +11,11 @@ spec:
     {{- include "blueapi.selectorLabels" . | nindent 6 }}
   template:
     metadata:
-      {{- with .Values.podAnnotations }}
       annotations:
+        {{- if .Values.restartOnConfigChange }}
+        checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
+        {{- end }}
+        {{- with .Values.podAnnotations }}
         {{- toYaml . | nindent 8 }}
         {{- end }}
       labels:
diff --git a/helm/blueapi/values.yaml b/helm/blueapi/values.yaml
index 240230b43..2e3788556 100644
--- a/helm/blueapi/values.yaml
+++ b/helm/blueapi/values.yaml
@@ -68,6 +68,8 @@ affinity: {}

 hostNetwork: false # May be needed for talking to arcane protocols such as EPICS

+restartOnConfigChange: true
+
 listener:
   enabled: true
   resources: {}

From 5b8d60a9da8c1ab32300fef9e328100988cb2a43 Mon Sep 17 00:00:00 2001
From: Dominic Oram
Date: Tue, 7 May 2024 13:13:57 +0100
Subject: [PATCH 4/7] Pin sphinx-autobuild to fix starlette clash (#456)

Fixes #454
---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 393a0f459..d6c89c959 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
     "aiohttp",
     "PyYAML",
     "click<8.1.4",
-    "fastapi[all]<0.99",
+    "fastapi[all]<0.99", # Later versions use a newer openapi schema, which is incompatible with swagger see https://github.com/swagger-api/swagger-codegen/issues/10446
     "uvicorn",
     "requests",
     "dls-bluesky-core", #requires ophyd-async
@@ -44,7 +44,7 @@ dev = [
     "pytest-cov",
     "pytest-asyncio",
     "ruff",
-    "sphinx-autobuild",
+    "sphinx-autobuild==2024.2.4", # Later versions have a clash with fastapi<0.99, remove pin when fastapi is a higher version
     "sphinx-copybutton",
     "sphinx-click",
     "sphinx-design",

From 22693e6b3e2fb3148d0659ba1809136d64e19f1d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Malinowski?= <56644812+stan-dot@users.noreply.github.com>
Date: Mon, 13 May 2024 11:04:40 +0100
Subject: [PATCH 5/7] rename amq client to event bus client (#466)

---
 src/blueapi/cli/cli.py                  | 20 +++++++++++---------
 .../cli/{amq.py => event_bus_client.py} |  2 +-
 src/blueapi/cli/rest.py                 |  2 +-
 3 files changed, 13 insertions(+), 11 deletions(-)
 rename src/blueapi/cli/{amq.py => event_bus_client.py} (99%)

diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py
index d8be05ad9..7e7c81710 100644
--- a/src/blueapi/cli/cli.py
+++ b/src/blueapi/cli/cli.py
@@ -9,7 +9,7 @@
 from requests.exceptions import ConnectionError

 from blueapi import __version__
-from blueapi.cli.amq import AmqClient
+from blueapi.cli.event_bus_client import EventBusClient
 from blueapi.config import ApplicationConfig, ConfigLoader
 from blueapi.core import DataEvent
 from blueapi.messaging import MessageContext
@@ -135,7 +135,9 @@ def listen_to_events(obj: dict) -> None:
     """Listen to events output by blueapi"""
     config: ApplicationConfig = obj["config"]
     if config.stomp is not None:
-        amq_client = 
AmqClient(StompMessagingTemplate.autoconfigured(config.stomp)) + event_bus_client = EventBusClient( + StompMessagingTemplate.autoconfigured(config.stomp) + ) else: raise RuntimeError("Message bus needs to be configured") @@ -150,8 +152,8 @@ def on_event( "Subscribing to all bluesky events from " f"{config.stomp.host}:{config.stomp.port}" ) - with amq_client: - amq_client.subscribe_to_all_events(on_event) + with event_bus_client: + event_bus_client.subscribe_to_all_events(on_event) input("Press enter to exit") @@ -181,7 +183,7 @@ def run_plan( raise RuntimeError( "Cannot run plans without Stomp configuration to track progress" ) - amq_client = AmqClient(_message_template) + event_bus_client = EventBusClient(_message_template) finished_event: deque[WorkerEvent] = deque() def store_finished_event(event: WorkerEvent) -> None: @@ -194,13 +196,13 @@ def store_finished_event(event: WorkerEvent) -> None: resp = client.create_task(task) task_id = resp.task_id - with amq_client: - amq_client.subscribe_to_topics(task_id, on_event=store_finished_event) + with event_bus_client: + event_bus_client.subscribe_to_topics(task_id, on_event=store_finished_event) updated = client.update_worker_task(WorkerTask(task_id=task_id)) - amq_client.wait_for_complete(timeout=timeout) + event_bus_client.wait_for_complete(timeout=timeout) - if amq_client.timed_out: + if event_bus_client.timed_out: logger.error(f"Plan did not complete within {timeout} seconds") return diff --git a/src/blueapi/cli/amq.py b/src/blueapi/cli/event_bus_client.py similarity index 99% rename from src/blueapi/cli/amq.py rename to src/blueapi/cli/event_bus_client.py index face01b4b..afa2e4416 100644 --- a/src/blueapi/cli/amq.py +++ b/src/blueapi/cli/event_bus_client.py @@ -18,7 +18,7 @@ def __init__(self, message: str) -> None: _Event = WorkerEvent | ProgressEvent | DataEvent -class AmqClient: +class EventBusClient: app: MessagingTemplate complete: threading.Event timed_out: bool | None diff --git a/src/blueapi/cli/rest.py b/src/blueapi/cli/rest.py index 5e363faee..0fe7abd6e 100644 --- a/src/blueapi/cli/rest.py +++ b/src/blueapi/cli/rest.py @@ -15,7 +15,7 @@ ) from blueapi.worker import Task, TrackableTask, WorkerState -from .amq import BlueskyRemoteError +from .event_bus_client import BlueskyRemoteError T = TypeVar("T") From 4f1f44cb45e63ce5ed81c1d722a50317044210b9 Mon Sep 17 00:00:00 2001 From: Joe Shannon Date: Mon, 13 May 2024 14:15:46 +0100 Subject: [PATCH 6/7] Remove ophyd_async_connect (#462) The device connection is now handled by device_instantiation in dodal. This function also provides the option on whether to wait for connection, so it is not needed here too. Additionally it can lead to undefined (currently) behaviour if the device is initially created with fake_with_ophyd_sim = True but then later connected again by blueapi with fake_with_ophyd_sim = False. This also leaves the sim property on BlueskyContext redundant so that is removed too. For full customisation and flexibility of lazy connect we need #440. Fixes #461. 
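For illustration, a dodal-style device factory delegates connection to
device_instantiation roughly as below. This is a sketch only: the exact
import path and helper signature vary between dodal versions, and the
Synchrotron device is just an example.

    from dodal.common.beamlines.beamline_utils import device_instantiation
    from dodal.devices.synchrotron import Synchrotron

    def synchrotron(
        wait_for_connection: bool = True, fake_with_ophyd_sim: bool = False
    ) -> Synchrotron:
        # device_instantiation creates (or returns a cached) device and
        # decides whether to wait for connection or simulate it, so blueapi
        # no longer needs its own connection pass over ophyd-async devices.
        return device_instantiation(
            Synchrotron,
            "synchrotron",
            "",
            wait_for_connection,
            fake_with_ophyd_sim,
        )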
--- src/blueapi/core/context.py | 11 +--- src/blueapi/service/handler.py | 4 +- src/blueapi/utils/__init__.py | 2 - src/blueapi/utils/ophyd_async_connect.py | 54 ----------------- tests/utils/test_ophyd_async_connect.py | 77 ------------------------ 5 files changed, 2 insertions(+), 146 deletions(-) delete mode 100644 src/blueapi/utils/ophyd_async_connect.py delete mode 100644 tests/utils/test_ophyd_async_connect.py diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 98493a913..50623967e 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -6,14 +6,13 @@ from types import ModuleType, UnionType from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints -from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop +from bluesky.run_engine import RunEngine from pydantic import create_model from pydantic.fields import FieldInfo, ModelField from blueapi.config import EnvironmentConfig, SourceKind from blueapi.utils import ( BlueapiPlanModelConfig, - connect_ophyd_async_devices, load_module_all, ) @@ -45,7 +44,6 @@ class BlueskyContext: plans: dict[str, Plan] = field(default_factory=dict) devices: dict[str, Device] = field(default_factory=dict) plan_functions: dict[str, PlanGenerator] = field(default_factory=dict) - sim: bool = field(default=False) _reference_cache: dict[type, type] = field(default_factory=dict) @@ -78,13 +76,6 @@ def with_config(self, config: EnvironmentConfig) -> None: elif source.kind is SourceKind.DODAL: self.with_dodal_module(mod) - call_in_bluesky_event_loop( - connect_ophyd_async_devices( - self.devices.values(), - self.sim, - ) - ) - def with_plan_module(self, module: ModuleType) -> None: """ Register all functions in the module supplied as plans. 
diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index afa9818c1..24cdd312a 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -154,9 +154,7 @@ def setup_handler( handler = Handler( config, - context=BlueskyContext( - sim=False, - ), + context=BlueskyContext(), ) handler.start() diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index b3c212a51..b871f842a 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,7 +1,6 @@ from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig from .invalid_config_error import InvalidConfigError from .modules import load_module_all -from .ophyd_async_connect import connect_ophyd_async_devices from .serialization import serialize from .thread_exception import handle_all_exceptions @@ -14,5 +13,4 @@ "BlueapiModelConfig", "BlueapiPlanModelConfig", "InvalidConfigError", - "connect_ophyd_async_devices", ] diff --git a/src/blueapi/utils/ophyd_async_connect.py b/src/blueapi/utils/ophyd_async_connect.py deleted file mode 100644 index 382b412bf..000000000 --- a/src/blueapi/utils/ophyd_async_connect.py +++ /dev/null @@ -1,54 +0,0 @@ -import asyncio -import logging -from collections.abc import Iterable -from contextlib import suppress -from typing import Any - -from ophyd_async.core import DEFAULT_TIMEOUT, NotConnected -from ophyd_async.core import Device as OphydAsyncDevice - - -async def connect_ophyd_async_devices( - devices: Iterable[Any], - sim: bool = False, - timeout: float = DEFAULT_TIMEOUT, -) -> None: - tasks: dict[asyncio.Task, str] = {} - for device in devices: - if isinstance(device, OphydAsyncDevice): - task = asyncio.create_task(device.connect(sim=sim)) - tasks[task] = device.name - if tasks: - await _wait_for_tasks(tasks, timeout=timeout) - - -async def _wait_for_tasks(tasks: dict[asyncio.Task, str], timeout: float): - done, pending = await asyncio.wait(tasks, timeout=timeout) - if pending: - msg = f"{len(pending)} Devices did not connect:" - for t in pending: - t.cancel() - with suppress(Exception): - await t - msg += _format_awaited_task_error_message(tasks, t) - logging.error(msg) - raised = [t for t in done if t.exception()] - if raised: - logging.error(f"{len(raised)} Devices raised an error:") - for t in raised: - logging.exception(f" {tasks[t]}:", exc_info=t.exception()) - if pending or raised: - raise NotConnected("Not all Devices connected") - - -def _format_awaited_task_error_message( - tasks: dict[asyncio.Task, str], t: asyncio.Task -) -> str: - e = t.exception() - part_one = f"\n {tasks[t]}: {type(e).__name__}" - lines = str(e).splitlines() - - part_two = ( - f": {e}" if len(lines) <= 1 else "".join(f"\n {line}" for line in lines) - ) - return part_one + part_two diff --git a/tests/utils/test_ophyd_async_connect.py b/tests/utils/test_ophyd_async_connect.py deleted file mode 100644 index f3dcba767..000000000 --- a/tests/utils/test_ophyd_async_connect.py +++ /dev/null @@ -1,77 +0,0 @@ -import asyncio -import unittest - -from blueapi.utils.ophyd_async_connect import _format_awaited_task_error_message -from blueapi.worker.task import Task - -_SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) -_LONG_TASK = Task(name="sleep", params={"time": 1.0}) - - -class TestFormatErrorMessage(unittest.TestCase): - def setUp(self): - # Setup the asyncio event loop for each test - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - def tearDown(self): - # Close the loop at the end of each 
test - self.loop.close() - - async def _create_task_with_exception(self, exception): - """Helper coroutine to create a task that raises an exception.""" - - async def raise_exception(): - raise exception - - task = self.loop.create_task(raise_exception()) - await asyncio.sleep(0.1) # Allow time for the task to raise the exception - return task - - def test_format_error_message_single_line(self): - # Test formatting with an exception that has a single-line message - exception = ValueError("A single-line error") - task = self.loop.run_until_complete(self._create_task_with_exception(exception)) - tasks = {task: "Task1"} - expected_output = "\n Task1: ValueError: A single-line error" - self.assertEqual( - _format_awaited_task_error_message(tasks, task), expected_output - ) - - def test_format_error_message_multi_line(self): - # Test formatting with an exception that has a multi-line message - exception = ValueError("A multi-line\nerror message") - task = self.loop.run_until_complete(self._create_task_with_exception(exception)) - tasks = {task: "Task2"} - expected_output = "\n Task2: ValueError\n A multi-line\n error message" - self.assertEqual( - _format_awaited_task_error_message(tasks, task), expected_output - ) - - def test_format_error_message_simple_task_failure(self): - # Test formatting with the _SIMPLE_TASK key and a failing asyncio task - exception = RuntimeError("Simple task error") - failing_task = self.loop.run_until_complete( - self._create_task_with_exception(exception) - ) - tasks = {failing_task: _SIMPLE_TASK.name} - expected_output = "\n sleep: RuntimeError: Simple task error" - self.assertEqual( - _format_awaited_task_error_message(tasks, failing_task), expected_output - ) - - def test_format_error_message_long_task_failure(self): - # Test formatting with the _LONG_TASK key and a failing asyncio task - exception = RuntimeError("Long task error") - failing_task = self.loop.run_until_complete( - self._create_task_with_exception(exception) - ) - tasks = {failing_task: _LONG_TASK.name} - expected_output = "\n sleep: RuntimeError: Long task error" - self.assertEqual( - _format_awaited_task_error_message(tasks, failing_task), expected_output - ) - - -if __name__ == "__main__": - unittest.main() From 59507723e7e8abddbd5e0607f656d3c5fb3a14ac Mon Sep 17 00:00:00 2001 From: Keith Ralphs Date: Fri, 17 May 2024 09:24:27 +0100 Subject: [PATCH 7/7] Add minimum dodal release (#469) Depend on at least Dodal 1.24.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d6c89c959..7304ec3b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "uvicorn", "requests", "dls-bluesky-core", #requires ophyd-async - "dls-dodal", + "dls-dodal>=1.24.0", ] dynamic = ["version"] license.file = "LICENSE"
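Taken together, these patches leave device creation, connection and data-session
handling to dodal; a deployment then points blueapi at the relevant modules
through its environment config. A hedged sketch using the Source and SourceKind
types from src/blueapi/config.py (the dodal module name is illustrative, and the
"sources" field name is assumed to match EnvironmentConfig's definition):

    from blueapi.config import EnvironmentConfig, Source, SourceKind

    # Plans still come from plan-function modules; devices now come from a
    # dodal beamline module, which handles instantiation and connection.
    config = EnvironmentConfig(
        sources=[
            Source(kind=SourceKind.PLAN_FUNCTIONS, module="dls_bluesky_core.plans"),
            Source(kind=SourceKind.DODAL, module="dodal.beamlines.i22"),
        ]
    )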