From 42cac054a299e27f0156cab38a98064d2910643d Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 26 Oct 2023 12:56:33 +0100 Subject: [PATCH] Make devices connect with a timeout (#321) Co-authored-by: Rose Yemelyanova --- pyproject.toml | 1 + src/blueapi/core/context.py | 25 ++++++------- src/blueapi/utils/__init__.py | 2 + src/blueapi/utils/ophyd_async_connect.py | 47 ++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 14 deletions(-) create mode 100644 src/blueapi/utils/ophyd_async_connect.py diff --git a/pyproject.toml b/pyproject.toml index 88de76627..23483b7e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ testpaths = "docs src tests" markers = [ "handler: marks tests that interact with the global handler object in handler.py", ] +asyncio_mode = "auto" [tool.coverage.run] data_file = "/tmp/blueapi.coverage" diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index ffad9b98a..7295cf8d5 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -22,14 +22,16 @@ ) from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop -from ophyd_async.core import Device as AsyncDevice -from ophyd_async.core import wait_for_connection from pydantic import create_model from pydantic.fields import FieldInfo, ModelField from blueapi.config import EnvironmentConfig, SourceKind from blueapi.data_management.gda_directory_provider import VisitDirectoryProvider -from blueapi.utils import BlueapiPlanModelConfig, load_module_all +from blueapi.utils import ( + BlueapiPlanModelConfig, + connect_ophyd_async_devices, + load_module_all, +) from .bluesky_types import ( BLUESKY_PROTOCOLS, @@ -104,17 +106,12 @@ def with_config(self, config: EnvironmentConfig) -> None: elif source.kind is SourceKind.DODAL: self.with_dodal_module(mod) - call_in_bluesky_event_loop(self.connect_devices(self.sim)) - - async def connect_devices(self, sim: bool = False) -> None: - coros = {} - for device_name, device in self.devices.items(): - if isinstance(device, AsyncDevice): - device.set_name(device_name) - coros[device_name] = device.connect(sim) - - if len(coros) > 0: - await wait_for_connection(**coros) + call_in_bluesky_event_loop( + connect_ophyd_async_devices( + self.devices.values(), + self.sim, + ) + ) def with_plan_module(self, module: ModuleType) -> None: """ diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index b871f842a..b3c212a51 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,6 +1,7 @@ from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig from .invalid_config_error import InvalidConfigError from .modules import load_module_all +from .ophyd_async_connect import connect_ophyd_async_devices from .serialization import serialize from .thread_exception import handle_all_exceptions @@ -13,4 +14,5 @@ "BlueapiModelConfig", "BlueapiPlanModelConfig", "InvalidConfigError", + "connect_ophyd_async_devices", ] diff --git a/src/blueapi/utils/ophyd_async_connect.py b/src/blueapi/utils/ophyd_async_connect.py new file mode 100644 index 000000000..45fdc5a11 --- /dev/null +++ b/src/blueapi/utils/ophyd_async_connect.py @@ -0,0 +1,47 @@ +import asyncio +import logging +from contextlib import suppress +from typing import Any, Dict, Iterable + +from ophyd_async.core import DEFAULT_TIMEOUT +from ophyd_async.core import Device as OphydAsyncDevice +from ophyd_async.core import NotConnected + + +async def connect_ophyd_async_devices( + devices: Iterable[Any], + sim: bool = False, + timeout: float = DEFAULT_TIMEOUT, +) -> None: + tasks: Dict[asyncio.Task, str] = {} + for device in devices: + if isinstance(device, OphydAsyncDevice): + task = asyncio.create_task(device.connect(sim=sim)) + tasks[task] = device.name + if tasks: + await _wait_for_tasks(tasks, timeout=timeout) + + +async def _wait_for_tasks(tasks: Dict[asyncio.Task, str], timeout: float): + done, pending = await asyncio.wait(tasks, timeout=timeout) + if pending: + msg = f"{len(pending)} Devices did not connect:" + for t in pending: + t.cancel() + with suppress(Exception): + await t + e = t.exception() + msg += f"\n {tasks[t]}: {type(e).__name__}" + lines = str(e).splitlines() + if len(lines) <= 1: + msg += f": {e}" + else: + msg += "".join(f"\n {line}" for line in lines) + logging.error(msg) + raised = [t for t in done if t.exception()] + if raised: + logging.error(f"{len(raised)} Devices raised an error:") + for t in raised: + logging.exception(f" {tasks[t]}:", exc_info=t.exception()) + if pending or raised: + raise NotConnected("Not all Devices connected")