diff --git a/.github/actions/install_requirements/action.yml b/.github/actions/install_requirements/action.yml
index 6036af752..b3525dd91 100644
--- a/.github/actions/install_requirements/action.yml
+++ b/.github/actions/install_requirements/action.yml
@@ -9,7 +9,7 @@ inputs:
     required: true
   python_version:
     description: Python version to install
-    default: "3.x"
+    default: "3.9"
 
 runs:
   using: composite
diff --git a/config/bl38p.yaml b/config/bl38p.yaml
index 34eeeeacb..b7dbf28f6 100644
--- a/config/bl38p.yaml
+++ b/config/bl38p.yaml
@@ -5,4 +5,7 @@ env:
     - kind: planFunctions
       module: dls_bluesky_core.plans
     - kind: planFunctions
-      module: dls_bluesky_core.stubs
\ No newline at end of file
+      module: dls_bluesky_core.stubs
+  data_writing:
+    visit_directory: /dls/p38/data/2023/cm33874-1
+    group_name: BL38P
diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py
index 7295cf8d5..6bf91075a 100644
--- a/src/blueapi/core/context.py
+++ b/src/blueapi/core/context.py
@@ -26,7 +26,6 @@
 from pydantic.fields import FieldInfo, ModelField
 
 from blueapi.config import EnvironmentConfig, SourceKind
-from blueapi.data_management.gda_directory_provider import VisitDirectoryProvider
 from blueapi.utils import (
     BlueapiPlanModelConfig,
     connect_ophyd_async_devices,
@@ -64,7 +63,6 @@ class BlueskyContext:
     plans: Dict[str, Plan] = field(default_factory=dict)
     devices: Dict[str, Device] = field(default_factory=dict)
     plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict)
-    directory_provider: Optional[VisitDirectoryProvider] = field(default=None)
     sim: bool = field(default=False)
 
     _reference_cache: Dict[Type, Type] = field(default_factory=dict)
diff --git a/src/blueapi/data_management/__init__.py b/src/blueapi/data_management/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/blueapi/data_management/gda_directory_provider.py b/src/blueapi/data_management/gda_directory_provider.py
deleted file mode 100644
index f966a3d7a..000000000
--- a/src/blueapi/data_management/gda_directory_provider.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import logging
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Optional
-
-from aiohttp import ClientSession
-from ophyd_async.core import DirectoryInfo, DirectoryProvider
-from pydantic import BaseModel
-
-
-class DataCollectionIdentifier(BaseModel):
-    collectionNumber: int
-
-
-class VisitServiceClientBase(ABC):
-    """
-    Object responsible for I/O in determining collection number
-    """
-
-    @abstractmethod
-    async def create_new_collection(self) -> DataCollectionIdentifier:
-        ...
-
-    @abstractmethod
-    async def get_current_collection(self) -> DataCollectionIdentifier:
-        ...
-
-
-class VisitServiceClient(VisitServiceClientBase):
-    _url: str
-
-    def __init__(self, url: str) -> None:
-        self._url = url
-
-    async def create_new_collection(self) -> DataCollectionIdentifier:
-        async with ClientSession() as session:
-            async with session.post(f"{self._url}/numtracker") as response:
-                if response.status == 200:
-                    json = await response.json()
-                    return DataCollectionIdentifier.parse_obj(json)
-                else:
-                    raise Exception(response.status)
-
-    async def get_current_collection(self) -> DataCollectionIdentifier:
-        async with ClientSession() as session:
-            async with session.get(f"{self._url}/numtracker") as response:
-                if response.status == 200:
-                    json = await response.json()
-                    return DataCollectionIdentifier.parse_obj(json)
-                else:
-                    raise Exception(response.status)
-
-
-class LocalVisitServiceClient(VisitServiceClientBase):
-    _count: int
-
-    def __init__(self) -> None:
-        self._count = 0
-
-    async def create_new_collection(self) -> DataCollectionIdentifier:
-        self._count += 1
-        return DataCollectionIdentifier(collectionNumber=self._count)
-
-    async def get_current_collection(self) -> DataCollectionIdentifier:
-        return DataCollectionIdentifier(collectionNumber=self._count)
-
-
-class VisitDirectoryProvider(DirectoryProvider):
-    """
-    Gets information from a remote service to construct the path that detectors
-    should write to, and determine how their files should be named.
-    """
-
-    _data_group_name: str
-    _data_directory: Path
-
-    _client: VisitServiceClientBase
-    _current_collection: Optional[DirectoryInfo]
-    _session: Optional[ClientSession]
-
-    def __init__(
-        self,
-        data_group_name: str,
-        data_directory: Path,
-        client: VisitServiceClientBase,
-    ):
-        self._data_group_name = data_group_name
-        self._data_directory = data_directory
-        self._client = client
-
-        self._current_collection = None
-        self._session = None
-
-    async def update(self) -> None:
-        """
-        Calls the visit service to create a new data collection in the current visit.
-        """
-        # TODO: After visit service is more feature complete:
-        # TODO: Allow selecting visit as part of the request to BlueAPI
-        # TODO: Consume visit information from BlueAPI and pass down to this class
-        # TODO: Query visit service to get information about visit and data collection
-        # TODO: Use AuthN information as part of verification with visit service
-
-        try:
-            collection_id_info = await self._client.create_new_collection()
-            self._current_collection = self._generate_directory_info(collection_id_info)
-        except Exception as ex:
-            # TODO: The catch all is needed because the RunEngine will not
-            # currently handle it, see
-            # https://github.com/bluesky/bluesky/pull/1623
-            self._current_collection = None
-            logging.exception(ex)
-
-    def _generate_directory_info(
-        self,
-        collection_id_info: DataCollectionIdentifier,
-    ) -> DirectoryInfo:
-        collection_id = collection_id_info.collectionNumber
-        file_prefix = f"{self._data_group_name}-{collection_id}"
-        return DirectoryInfo(str(self._data_directory), file_prefix)
-
-    def __call__(self) -> DirectoryInfo:
-        if self._current_collection is not None:
-            return self._current_collection
-        else:
-            raise ValueError(
-                "No current collection, update() needs to be called at least once"
-            )
diff --git a/src/blueapi/preprocessors/attach_metadata.py b/src/blueapi/preprocessors/attach_metadata.py
index 0b7a0306a..3de02201e 100644
--- a/src/blueapi/preprocessors/attach_metadata.py
+++ b/src/blueapi/preprocessors/attach_metadata.py
@@ -1,8 +1,8 @@
-import bluesky.plan_stubs as bps
+import bluesky.preprocessors as bpp
 from bluesky.utils import make_decorator
+from ophyd_async.core import DirectoryProvider
 
 from blueapi.core import MsgGenerator
-from blueapi.data_management.gda_directory_provider import VisitDirectoryProvider
 
 DATA_SESSION = "data_session"
 DATA_GROUPS = "data_groups"
@@ -10,26 +10,14 @@
 
 def attach_metadata(
     plan: MsgGenerator,
-    provider: VisitDirectoryProvider,
+    provider: DirectoryProvider,
 ) -> MsgGenerator:
     """
     Attach data session metadata to the runs within a plan and make it correlate
     with an ophyd-async DirectoryProvider.
 
-    This wrapper is meant to ensure (on a best-effort basis) that detectors write
-    their data to the same place for a given run, and that their writings are
-    tied together in the run via the data_session metadata keyword in the run
-    start document.
-
-    The wrapper groups data by staging and bundles it with runs as best it can.
-    Since staging is inherently decoupled from runs this is done on a best-effort
-    basis. In the following sequence of messages:
-
-    |stage|, stage, |open_run|, close_run, unstage, unstage, |stage|, stage,
-    |open_run|, close_run, unstage, unstage
-
-    A new group is created at each |stage| and bundled into the start document
-    at each |open_run|.
+    This calls the directory provider and ensures the start document contains
+    the correct data session.
 
     Args:
         plan: The plan to preprocess
@@ -41,32 +29,10 @@
     Yields:
         Iterator[Msg]: Plan messages
     """
-
-    group_in_progress = False
-
-    for message in plan:
-        # If the first stage in a series of stages is detected,
-        # update the directory provider and create a new group.
-        if (message.command == "stage") and (not group_in_progress):
-            yield from bps.wait_for([provider.update])
-            group_in_progress = True
-        # Mark if detectors are being unstaged so that the start
-        # of the next sequence of stages is detectable.
-        elif message.command == "unstage":
-            group_in_progress = False
-
-        # If a run is being opened, attempt to bundle the information
-        # on any existing group into the start document.
-        if message.command == "open_run":
-            # Handle the case where we're opening a run but no detectors
-            # have been staged yet. Common for nested runs.
-            if not group_in_progress:
-                yield from bps.wait_for([provider.update])
-            directory_info = provider()
-            message.kwargs[DATA_SESSION] = directory_info.filename_prefix
-
-        # This is a preprocessor so we yield the original message.
-        yield message
+    directory_info = provider()
+    yield from bpp.inject_md_wrapper(
+        plan, md={DATA_SESSION: directory_info.filename_prefix}
+    )
 
 
 attach_metadata_decorator = make_decorator(attach_metadata)
diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py
index ed63b0e60..3e3dbef8a 100644
--- a/src/blueapi/service/handler.py
+++ b/src/blueapi/service/handler.py
@@ -1,15 +1,11 @@
 import logging
 from typing import Mapping, Optional
 
+from ophyd_async.core import StaticDirectoryProvider
+
 from blueapi.config import ApplicationConfig
 from blueapi.core import BlueskyContext
 from blueapi.core.event import EventStream
-from blueapi.data_management.gda_directory_provider import (
-    LocalVisitServiceClient,
-    VisitDirectoryProvider,
-    VisitServiceClient,
-    VisitServiceClientBase,
-)
 from blueapi.messaging import StompMessagingTemplate
 from blueapi.messaging.base import MessagingTemplate
 from blueapi.preprocessors.attach_metadata import attach_metadata
@@ -92,18 +88,9 @@ def setup_handler(
 
     plan_wrappers = []
     if config:
-        visit_service_client: VisitServiceClientBase
-        if config.env.data_writing.visit_service_url is not None:
-            visit_service_client = VisitServiceClient(
-                config.env.data_writing.visit_service_url
-            )
-        else:
-            visit_service_client = LocalVisitServiceClient()
-
-        provider = VisitDirectoryProvider(
-            data_group_name=config.env.data_writing.group_name,
-            data_directory=config.env.data_writing.visit_directory,
-            client=visit_service_client,
+        provider = StaticDirectoryProvider(
+            filename_prefix=f"{config.env.data_writing.group_name}-blueapi",
+            directory_path=str(config.env.data_writing.visit_directory),
         )
 
         # Make all dodal devices created by the context use provider if they can
@@ -125,7 +112,6 @@
         config,
         context=BlueskyContext(
             plan_wrappers=plan_wrappers,
-            directory_provider=provider,
             sim=False,
         ),
     )
diff --git a/tests/data_writing/__init__.py b/tests/data_writing/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/data_writing/test_gda_directory_provider.py b/tests/data_writing/test_gda_directory_provider.py
deleted file mode 100644
index 10dd76d08..000000000
--- a/tests/data_writing/test_gda_directory_provider.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from pathlib import Path
-
-import pytest
-from ophyd_async.core import DirectoryInfo
-
-from blueapi.data_management.gda_directory_provider import (
-    DataCollectionIdentifier,
-    LocalVisitServiceClient,
-    VisitDirectoryProvider,
-    VisitServiceClientBase,
-)
-
-
-@pytest.fixture
-def visit_service_client() -> VisitServiceClientBase:
-    return LocalVisitServiceClient()
-
-
-@pytest.fixture
-def visit_directory_provider(
-    visit_service_client: VisitServiceClientBase,
-) -> VisitDirectoryProvider:
-    return VisitDirectoryProvider("example", Path("/tmp"), visit_service_client)
-
-
-@pytest.mark.asyncio
-async def test_client_can_view_collection(
-    visit_service_client: VisitServiceClientBase,
-) -> None:
-    collection = await visit_service_client.get_current_collection()
-    assert collection == DataCollectionIdentifier(collectionNumber=0)
-
-
-@pytest.mark.asyncio
-async def test_client_can_create_collection(
-    visit_service_client: VisitServiceClientBase,
-) -> None:
-    collection = await visit_service_client.create_new_collection()
-    assert collection == DataCollectionIdentifier(collectionNumber=1)
-
-
-@pytest.mark.asyncio
-async def test_update_sets_collection_number(
-    visit_directory_provider: VisitDirectoryProvider,
-) -> None:
-    await visit_directory_provider.update()
-    assert visit_directory_provider() == DirectoryInfo(
-        directory_path="/tmp",
-        filename_prefix="example-1",
-    )
-
-
-@pytest.mark.asyncio
-async def test_update_sets_collection_number_multi(
-    visit_directory_provider: VisitDirectoryProvider,
-) -> None:
-    await visit_directory_provider.update()
-    assert visit_directory_provider() == DirectoryInfo(
-        directory_path="/tmp",
-        filename_prefix="example-1",
-    )
-    await visit_directory_provider.update()
-    assert visit_directory_provider() == DirectoryInfo(
-        directory_path="/tmp",
-        filename_prefix="example-2",
-    )
diff --git a/tests/preprocessors/test_attach_metadata.py b/tests/preprocessors/test_attach_metadata.py
index a7ba1b1de..9f3a1f1d7 100644
--- a/tests/preprocessors/test_attach_metadata.py
+++ b/tests/preprocessors/test_attach_metadata.py
@@ -1,78 +1,29 @@
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping
+from typing import Any, Dict, List, Mapping
 
-import bluesky.plan_stubs as bps
 import bluesky.plans as bp
 import pytest
 from bluesky import RunEngine
-from bluesky.preprocessors import (
-    run_decorator,
-    run_wrapper,
-    set_run_key_decorator,
-    set_run_key_wrapper,
-    stage_wrapper,
-)
 from bluesky.protocols import HasName, Readable, Reading, Status, Triggerable
 from event_model.documents.event_descriptor import DataKey
 from ophyd.status import StatusBase
-from ophyd_async.core import DirectoryProvider
+from ophyd_async.core import DirectoryProvider, StaticDirectoryProvider
 
 from blueapi.core import DataEvent, MsgGenerator
-from blueapi.data_management.gda_directory_provider import (
-    DataCollectionIdentifier,
-    VisitDirectoryProvider,
-    VisitServiceClient,
-)
 from blueapi.preprocessors.attach_metadata import DATA_SESSION, attach_metadata
 
 DATA_DIRECTORY = Path("/tmp")
 DATA_GROUP_NAME = "test"
 
-RUN_0 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-0"
-RUN_1 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-1"
-RUN_2 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-2"
-
-
-class MockVisitServiceClient(VisitServiceClient):
-    _count: int
-    _fail: bool
-
-    def __init__(self) -> None:
-        super().__init__("http://example.com")
-        self._count = 0
-        self._fail = False
-
-    def always_fail(self) -> None:
-        self._fail = True
-
-    async def create_new_collection(self) -> DataCollectionIdentifier:
-        if self._fail:
-            raise ConnectionError()
-
-        count = self._count
-        self._count += 1
-        return DataCollectionIdentifier(collectionNumber=count)
-
-    async def get_current_collection(self) -> DataCollectionIdentifier:
-        if self._fail:
-            raise ConnectionError()
-
-        return DataCollectionIdentifier(collectionNumber=self._count)
-
-
-@pytest.fixture
-def client() -> VisitServiceClient:
-    return MockVisitServiceClient()
+RUN_0 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}"
+RUN_1 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}"
+RUN_2 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}"
 
 
 @pytest.fixture
-def provider(client: VisitServiceClient) -> VisitDirectoryProvider:
-    return VisitDirectoryProvider(
-        data_directory=DATA_DIRECTORY,
-        data_group_name=DATA_GROUP_NAME,
-        client=client,
-    )
+def provider() -> DirectoryProvider:
+    return StaticDirectoryProvider(str(DATA_DIRECTORY), DATA_GROUP_NAME)
 
 
 @pytest.fixture
@@ -126,7 +77,7 @@ def parent(self) -> None:
 
 
 @pytest.fixture(params=[1, 2])
-def detectors(request, provider: VisitDirectoryProvider) -> List[Readable]:
+def detectors(request, provider: DirectoryProvider) -> List[Readable]:
     number_of_detectors = request.param
     return [
         FakeDetector(
@@ -137,219 +88,6 @@ def detectors(request, provider: VisitDirectoryProvider) -> List[Readable]:
     ]
 
 
-def simple_run(detectors: List[Readable]) -> MsgGenerator:
-    yield from bp.count(detectors)
-
-
-def multi_run(detectors: List[Readable]) -> MsgGenerator:
-    yield from bp.count(detectors)
-    yield from bp.count(detectors)
-
-
-def multi_nested_plan(detectors: List[Readable]) -> MsgGenerator:
-    yield from simple_run(detectors)
-    yield from simple_run(detectors)
-
-
-def multi_run_single_stage(detectors: List[Readable]) -> MsgGenerator:
-    def stageless_count() -> MsgGenerator:
-        return (yield from bps.one_shot(detectors))
-
-    def inner_plan() -> MsgGenerator:
-        yield from run_wrapper(stageless_count())
-        yield from run_wrapper(stageless_count())
-
-    yield from stage_wrapper(inner_plan(), detectors)
-
-
-def multi_run_single_stage_multi_group(
-    detectors: List[Readable],
-) -> MsgGenerator:
-    def stageless_count() -> MsgGenerator:
-        return (yield from bps.one_shot(detectors))
-
-    def inner_plan() -> MsgGenerator:
-        yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1})
-        yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1})
-        yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2})
-        yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2})
-
-    yield from stage_wrapper(inner_plan(), detectors)
-
-
-@run_decorator(md={DATA_SESSION: 12345})
-@set_run_key_decorator("outer")
-def nested_run_with_metadata(detectors: List[Readable]) -> MsgGenerator:
-    yield from set_run_key_wrapper(bp.count(detectors), "inner")
-    yield from set_run_key_wrapper(bp.count(detectors), "inner")
-
-
-@run_decorator()
-@set_run_key_decorator("outer")
-def nested_run_without_metadata(
-    detectors: List[Readable],
-) -> MsgGenerator:
-    yield from set_run_key_wrapper(bp.count(detectors), "inner")
-    yield from set_run_key_wrapper(bp.count(detectors), "inner")
-
-
-def test_simple_run_gets_scan_number(
-    run_engine: RunEngine,
-    detectors: List[Readable],
-    provider: DirectoryProvider,
-) -> None:
-    docs = collect_docs(
-        run_engine,
-        simple_run(detectors),
-        provider,
-    )
-    assert docs[0].name == "start"
-    assert docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0])
-
-
-@pytest.mark.parametrize("plan", [multi_run, multi_nested_plan])
-def test_multi_run_gets_scan_numbers(
-    run_engine: RunEngine,
-    detectors: List[Readable],
-    plan: Callable[[List[Readable]], MsgGenerator],
-    provider: DirectoryProvider,
-) -> None:
-    docs = collect_docs(
-        run_engine,
-        plan(detectors),
-        provider,
-    )
-    start_docs = find_start_docs(docs)
-    assert len(start_docs) == 2
-    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-1"
-    assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_1])
-
-
-def test_multi_run_single_stage(
-    run_engine: RunEngine,
-    detectors: List[Readable],
-    provider: DirectoryProvider,
-) -> None:
-    docs = collect_docs(
-        run_engine,
-        multi_run_single_stage(detectors),
-        provider,
-    )
-    start_docs = find_start_docs(docs)
-    assert len(start_docs) == 2
-    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert_all_detectors_used_collection_numbers(
-        docs,
-        detectors,
-        [
-            RUN_0,
-            RUN_0,
-        ],
-    )
-
-
-def test_multi_run_single_stage_multi_group(
-    run_engine: RunEngine,
-    detectors: List[Readable],
-    provider: DirectoryProvider,
-) -> None:
-    docs = collect_docs(
-        run_engine,
-        multi_run_single_stage_multi_group(detectors),
-        provider,
-    )
-    start_docs = find_start_docs(docs)
-    assert len(start_docs) == 4
-    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert start_docs[3].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert_all_detectors_used_collection_numbers(
-        docs,
-        detectors,
-        [
-            RUN_0,
-            RUN_0,
-            RUN_0,
-            RUN_0,
-        ],
-    )
-
-
-def test_nested_run_with_metadata(
-    run_engine: RunEngine,
-    detectors: List[Readable],
-    provider: DirectoryProvider,
-) -> None:
-    docs = collect_docs(
-        run_engine,
-        nested_run_with_metadata(detectors),
-        provider,
-    )
-    start_docs = find_start_docs(docs)
-    assert len(start_docs) == 3
-    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-1"
-    assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-2"
-    assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_1, RUN_2])
-
-
-def test_nested_run_without_metadata(
-    run_engine: RunEngine,
-    detectors: List[Readable],
-    provider: DirectoryProvider,
-) -> None:
-    docs = collect_docs(
-        run_engine,
-        nested_run_without_metadata(detectors),
-        provider,
-    )
-    start_docs = find_start_docs(docs)
-    assert len(start_docs) == 3
-    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
-    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-1"
-    assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-2"
-    assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_1, RUN_2])
-
-
-def test_visit_directory_provider_fails(
-    run_engine: RunEngine,
-    detectors: List[Readable],
-    provider: DirectoryProvider,
-    client: MockVisitServiceClient,
-) -> None:
-    client.always_fail()
-    with pytest.raises(ValueError):
-        collect_docs(
-            run_engine,
-            simple_run(detectors),
-            provider,
-        )
-
-
-def test_visit_directory_provider_fails_after_one_sucess(
-    run_engine: RunEngine,
-    detectors: List[Readable],
-    provider: DirectoryProvider,
-    client: MockVisitServiceClient,
-) -> None:
-    collect_docs(
-        run_engine,
-        simple_run(detectors),
-        provider,
-    )
-    client.always_fail()
-    with pytest.raises(ValueError):
-        collect_docs(
-            run_engine,
-            simple_run(detectors),
-            provider,
-        )
-
-
 def collect_docs(
     run_engine: RunEngine,
     plan: MsgGenerator,
@@ -365,25 +103,13 @@ def on_event(name: str, doc: Mapping[str, Any]) -> None:
     return events
 
 
-def assert_all_detectors_used_collection_numbers(
-    docs: List[DataEvent],
-    detectors: List[Readable],
-    source_history: List[Path],
-) -> None:
-    descriptors = find_descriptor_docs(docs)
-    assert len(descriptors) == len(source_history)
-
-    for descriptor, expected_source in zip(descriptors, source_history):
-        for detector in detectors:
-            source = descriptor.doc.get("data_keys", {}).get(f"{detector.name}_data")[
-                "source"
-            ]
-            assert Path(source) == expected_source
-
-
-def find_start_docs(docs: List[DataEvent]) -> List[DataEvent]:
-    return list(filter(lambda event: event.name == "start", docs))
-
-
-def find_descriptor_docs(docs: List[DataEvent]) -> List[DataEvent]:
-    return list(filter(lambda event: event.name == "descriptor", docs))
+def test_attach_metadata_attaches_correct_data_session(
+    detectors: List[Readable], provider: DirectoryProvider, run_engine: RunEngine
+):
+    docs = collect_docs(
+        run_engine,
+        attach_metadata(bp.count(detectors), provider),
+        provider,
+    )
+    assert docs[0].name == "start"
+    assert docs[0].doc.get(DATA_SESSION) == DATA_GROUP_NAME
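
Note for reviewers unfamiliar with bpp.inject_md_wrapper: the rewritten preprocessor now resolves the DirectoryProvider once and injects the resulting filename_prefix as data_session metadata into every open_run message, instead of tracking stage/unstage groups. A minimal sketch of that behaviour follows; it is not part of the diff, the provider path and prefix are illustrative, and the trivial plan stands in for a real bp.count with detectors from the existing fixtures.

    from bluesky.utils import Msg
    from ophyd_async.core import StaticDirectoryProvider

    from blueapi.preprocessors.attach_metadata import DATA_SESSION, attach_metadata

    # Illustrative values only; any DirectoryProvider works here.
    provider = StaticDirectoryProvider("/tmp", "example")


    def trivial_plan():
        # Stand-in for a real plan such as bp.count(detectors).
        yield Msg("open_run")
        yield Msg("close_run")


    # attach_metadata is itself a plan (a generator of Msg objects), so the
    # injected metadata can be inspected directly without a RunEngine.
    messages = list(attach_metadata(trivial_plan(), provider))
    open_run = next(msg for msg in messages if msg.command == "open_run")
    assert open_run.kwargs[DATA_SESSION] == "example"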