From 581f6d9eaa3218b555f097fcbc8f766b9b510996 Mon Sep 17 00:00:00 2001 From: Jakub Wlodek Date: Tue, 10 Sep 2024 07:50:37 -0400 Subject: [PATCH 1/5] Update signature of __init__ for simdetector to match other detector classes (#563) --- src/ophyd_async/epics/adsimdetector/_sim.py | 9 ++++---- tests/core/test_protocol.py | 10 ++++----- tests/epics/adsimdetector/test_sim.py | 23 ++++++++------------- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/ophyd_async/epics/adsimdetector/_sim.py b/src/ophyd_async/epics/adsimdetector/_sim.py index b69937705f..c007c72ffc 100644 --- a/src/ophyd_async/epics/adsimdetector/_sim.py +++ b/src/ophyd_async/epics/adsimdetector/_sim.py @@ -12,14 +12,15 @@ class SimDetector(StandardDetector): def __init__( self, - drv: adcore.ADBaseIO, - hdf: adcore.NDFileHDFIO, + prefix: str, path_provider: PathProvider, + drv_suffix="cam1:", + hdf_suffix="HDF1:", name: str = "", config_sigs: Sequence[SignalR] = (), ): - self.drv = drv - self.hdf = hdf + self.drv = adcore.ADBaseIO(prefix + drv_suffix) + self.hdf = adcore.NDFileHDFIO(prefix + hdf_suffix) super().__init__( SimController(self.drv), diff --git a/tests/core/test_protocol.py b/tests/core/test_protocol.py index d71c4cce09..637d287213 100644 --- a/tests/core/test_protocol.py +++ b/tests/core/test_protocol.py @@ -9,7 +9,7 @@ StaticFilenameProvider, StaticPathProvider, ) -from ophyd_async.epics import adcore, adsimdetector +from ophyd_async.epics import adsimdetector from ophyd_async.sim.demo import SimMotor @@ -18,11 +18,9 @@ async def make_detector(prefix: str, name: str, tmp_path: Path): dp = StaticPathProvider(fp, tmp_path) async with DeviceCollector(mock=True): - drv = adcore.ADBaseIO(f"{prefix}DRV:") - hdf = adcore.NDFileHDFIO(f"{prefix}HDF:") - det = adsimdetector.SimDetector( - drv, hdf, dp, config_sigs=[drv.acquire_time, drv.acquire], name=name - ) + det = adsimdetector.SimDetector(prefix, dp, name=name) + + det._config_sigs = [det.drv.acquire_time, det.drv.acquire] return det diff --git a/tests/epics/adsimdetector/test_sim.py b/tests/epics/adsimdetector/test_sim.py index 17494e1bb5..5bbb240bef 100644 --- a/tests/epics/adsimdetector/test_sim.py +++ b/tests/epics/adsimdetector/test_sim.py @@ -31,16 +31,13 @@ async def make_detector(prefix: str, name: str, tmp_path: Path): dp = StaticPathProvider(fp, tmp_path) async with DeviceCollector(mock=True): - drv = adcore.ADBaseIO(f"{prefix}DRV:", name="drv") - hdf = adcore.NDFileHDFIO(f"{prefix}HDF:") - det = adsimdetector.SimDetector( - drv, hdf, dp, config_sigs=[drv.acquire_time, drv.acquire], name=name - ) + det = adsimdetector.SimDetector(prefix, dp, name=name) + det._config_sigs = [det.drv.acquire_time, det.drv.acquire] def _set_full_file_name(val, *args, **kwargs): - set_mock_value(hdf.full_file_name, str(tmp_path / val)) + set_mock_value(det.hdf.full_file_name, str(tmp_path / val)) - callback_on_mock_put(hdf.file_name, _set_full_file_name) + callback_on_mock_put(det.hdf.file_name, _set_full_file_name) return det @@ -284,13 +281,13 @@ async def test_read_and_describe_detector(single_detector: StandardDetector): read = await single_detector.read_configuration() assert describe == { "test-drv-acquire_time": { - "source": "mock+ca://TEST:DRV:AcquireTime_RBV", + "source": "mock+ca://TEST:cam1:AcquireTime_RBV", "dtype": "number", "dtype_numpy": " Date: Tue, 10 Sep 2024 21:13:05 +0100 Subject: [PATCH 2/5] Extend sleep time (#566) --- tests/epics/adsimdetector/test_sim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/epics/adsimdetector/test_sim.py b/tests/epics/adsimdetector/test_sim.py index 5bbb240bef..149835aa49 100644 --- a/tests/epics/adsimdetector/test_sim.py +++ b/tests/epics/adsimdetector/test_sim.py @@ -57,7 +57,7 @@ def count_sim(dets: List[StandardDetector], times: int = 1): for det in dets: yield from bps.trigger(det, wait=False, group="wait_for_trigger") - yield from bps.sleep(0.1) + yield from bps.sleep(0.2) [ set_mock_value( cast(adcore.ADHDFWriter, det.writer).hdf.num_captured, From 0cbfbe643352d7419d01f4fe958df4781736e21e Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Thu, 12 Sep 2024 11:48:04 +0100 Subject: [PATCH 3/5] Introduce PvaAbstractions and use them in SeqTable (#522) Introduced `Table` type which allows pva table structures to be represented as `BaseModel` --- src/ophyd_async/core/__init__.py | 8 +- src/ophyd_async/core/_device_save_loader.py | 12 + src/ophyd_async/core/_signal_backend.py | 15 +- src/ophyd_async/core/_soft_signal_backend.py | 32 ++- src/ophyd_async/core/_table.py | 58 +++++ src/ophyd_async/core/_utils.py | 2 +- src/ophyd_async/epics/signal/_aioca.py | 32 ++- src/ophyd_async/epics/signal/_p4p.py | 52 +++- src/ophyd_async/fastcs/panda/__init__.py | 6 - src/ophyd_async/fastcs/panda/_table.py | 255 +++++++++---------- src/ophyd_async/plan_stubs/_fly.py | 16 +- tests/core/test_device_save_loader.py | 16 ++ tests/core/test_signal.py | 25 ++ tests/core/test_soft_signal_backend.py | 2 +- tests/core/test_subset_enum.py | 8 +- tests/fastcs/panda/test_panda_connect.py | 1 + tests/fastcs/panda/test_panda_utils.py | 168 +++++++----- tests/fastcs/panda/test_table.py | 239 +++++++++++++++-- tests/fastcs/panda/test_trigger.py | 61 ++++- tests/test_data/test_yaml_save.yml | 4 + 20 files changed, 759 insertions(+), 253 deletions(-) create mode 100644 src/ophyd_async/core/_table.py diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 6b9c6ccac3..1928c7aba4 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -61,9 +61,14 @@ soft_signal_rw, wait_for_value, ) -from ._signal_backend import RuntimeSubsetEnum, SignalBackend, SubsetEnum +from ._signal_backend import ( + RuntimeSubsetEnum, + SignalBackend, + SubsetEnum, +) from ._soft_signal_backend import SignalMetadata, SoftSignalBackend from ._status import AsyncStatus, WatchableAsyncStatus, completed_status +from ._table import Table from ._utils import ( DEFAULT_TIMEOUT, CalculatableTimeout, @@ -152,6 +157,7 @@ "CalculateTimeout", "NotConnected", "ReadingValueCallback", + "Table", "T", "WatcherUpdate", "get_dtype", diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py index 5b81228264..d847caff69 100644 --- a/src/ophyd_async/core/_device_save_loader.py +++ b/src/ophyd_async/core/_device_save_loader.py @@ -7,6 +7,7 @@ from bluesky.plan_stubs import abs_set, wait from bluesky.protocols import Location from bluesky.utils import Msg +from pydantic import BaseModel from ._device import Device from ._signal import SignalRW @@ -18,6 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No ) +def pydantic_model_abstraction_representer( + dumper: yaml.Dumper, model: BaseModel +) -> yaml.Node: + return dumper.represent_data(model.model_dump(mode="python")) + + class OphydDumper(yaml.Dumper): def represent_data(self, data: Any) -> Any: if isinstance(data, Enum): @@ -134,6 +141,11 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None: """ yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper) + yaml.add_multi_representer( + BaseModel, + pydantic_model_abstraction_representer, + Dumper=yaml.Dumper, + ) with open(save_path, "w") as file: yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False) diff --git a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 41e9fbcbd3..594863ef2a 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -1,5 +1,13 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Optional, Tuple, Type +from typing import ( + TYPE_CHECKING, + ClassVar, + Generic, + Literal, + Optional, + Tuple, + Type, +) from ._protocol import DataKey, Reading from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T @@ -11,6 +19,11 @@ class SignalBackend(Generic[T]): #: Datatype of the signal value datatype: Optional[Type[T]] = None + @classmethod + @abstractmethod + def datatype_allowed(cls, dtype: type): + """Check if a given datatype is acceptable for this signal backend.""" + #: Like ca://PV_PREFIX:SIGNAL @abstractmethod def source(self, name: str) -> str: diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 62bafd5bb1..1e895e60cc 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -2,15 +2,20 @@ import inspect import time +from abc import ABCMeta from collections import abc from enum import Enum from typing import Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin import numpy as np from bluesky.protocols import DataKey, Dtype, Reading +from pydantic import BaseModel from typing_extensions import TypedDict -from ._signal_backend import RuntimeSubsetEnum, SignalBackend +from ._signal_backend import ( + RuntimeSubsetEnum, + SignalBackend, +) from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T, get_dtype primitive_dtypes: Dict[type, Dtype] = { @@ -94,7 +99,7 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: class SoftEnumConverter(SoftConverter): choices: Tuple[str, ...] - def __init__(self, datatype: Union[RuntimeSubsetEnum, Enum]): + def __init__(self, datatype: Union[RuntimeSubsetEnum, Type[Enum]]): if issubclass(datatype, Enum): self.choices = tuple(v.value for v in datatype) else: @@ -122,6 +127,16 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: return cast(T, self.choices[0]) +class SoftPydanticModelConverter(SoftConverter): + def __init__(self, datatype: Type[BaseModel]): + self.datatype = datatype + + def write_value(self, value): + if isinstance(value, dict): + return self.datatype(**value) + return value + + def make_converter(datatype): is_array = get_dtype(datatype) is not None is_sequence = get_origin(datatype) == abc.Sequence @@ -129,10 +144,19 @@ def make_converter(datatype): issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) ) + is_pydantic_model = ( + inspect.isclass(datatype) + # Necessary to avoid weirdness in ABCMeta.__subclasscheck__ + and isinstance(datatype, ABCMeta) + and issubclass(datatype, BaseModel) + ) + if is_array or is_sequence: return SoftArrayConverter() if is_enum: return SoftEnumConverter(datatype) + if is_pydantic_model: + return SoftPydanticModelConverter(datatype) return SoftConverter() @@ -145,6 +169,10 @@ class SoftSignalBackend(SignalBackend[T]): _timestamp: float _severity: int + @classmethod + def datatype_allowed(cls, datatype: Type) -> bool: + return True # Any value allowed in a soft signal + def __init__( self, datatype: Optional[Type[T]], diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py new file mode 100644 index 0000000000..bdb619a3b9 --- /dev/null +++ b/src/ophyd_async/core/_table.py @@ -0,0 +1,58 @@ +import numpy as np +from pydantic import BaseModel, ConfigDict, model_validator + + +class Table(BaseModel): + """An abstraction of a Table of str to numpy array.""" + + model_config = ConfigDict(validate_assignment=True, strict=False) + + @classmethod + def row(cls, sub_cls, **kwargs) -> "Table": + arrayified_kwargs = { + field_name: np.concatenate( + ( + (default_arr := field_value.default_factory()), + np.array([kwargs[field_name]], dtype=default_arr.dtype), + ) + ) + for field_name, field_value in sub_cls.model_fields.items() + } + return sub_cls(**arrayified_kwargs) + + def __add__(self, right: "Table") -> "Table": + """Concatenate the arrays in field values.""" + + assert isinstance(right, type(self)), ( + f"{right} is not a `Table`, or is not the same " + f"type of `Table` as {self}." + ) + + return type(self)( + **{ + field_name: np.concatenate( + (getattr(self, field_name), getattr(right, field_name)) + ) + for field_name in self.model_fields + } + ) + + @model_validator(mode="after") + def validate_arrays(self) -> "Table": + first_length = len(next(iter(self))[1]) + assert all( + len(field_value) == first_length for _, field_value in self + ), "Rows should all be of equal size." + + if not all( + np.issubdtype( + self.model_fields[field_name].default_factory().dtype, field_value.dtype + ) + for field_name, field_value in self + ): + raise ValueError( + f"Cannot construct a `{type(self).__name__}`, " + "some rows have incorrect types." + ) + + return self diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index f5098ce717..d081ed008f 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -145,7 +145,7 @@ def get_dtype(typ: Type) -> Optional[np.dtype]: def get_unique(values: Dict[str, T], types: str) -> T: - """If all values are the same, return that value, otherwise return TypeError + """If all values are the same, return that value, otherwise raise TypeError >>> get_unique({"a": 1, "b": 1}, "integers") 1 diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index 78052d448d..ef8a5693e2 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -1,9 +1,10 @@ +import inspect import logging import sys from dataclasses import dataclass from enum import Enum from math import isnan, nan -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin import numpy as np from aioca import ( @@ -24,6 +25,7 @@ DEFAULT_TIMEOUT, NotConnected, ReadingValueCallback, + RuntimeSubsetEnum, SignalBackend, T, get_dtype, @@ -211,7 +213,8 @@ def make_converter( raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") return CaArrayConverter(pv_dbr, None) elif pv_dbr == dbr.DBR_ENUM and datatype is bool: - # Database can't do bools, so are often representated as enums, CA can do int + # Database can't do bools, so are often representated as enums, + # CA can do int pv_choices_len = get_unique( {k: len(v.enums) for k, v in values.items()}, "number of choices" ) @@ -240,7 +243,7 @@ def make_converter( f"{pv} has type {type(value).__name__.replace('ca_', '')} " + f"not {datatype.__name__}" ) - return CaConverter(pv_dbr, None) + return CaConverter(pv_dbr, None) _tried_pyepics = False @@ -256,8 +259,31 @@ def _use_pyepics_context_if_imported(): class CaSignalBackend(SignalBackend[T]): + _ALLOWED_DATATYPES = ( + bool, + int, + float, + str, + Sequence, + Enum, + RuntimeSubsetEnum, + np.ndarray, + ) + + @classmethod + def datatype_allowed(cls, datatype: Optional[Type]) -> bool: + stripped_origin = get_origin(datatype) or datatype + if datatype is None: + return True + + return inspect.isclass(stripped_origin) and issubclass( + stripped_origin, cls._ALLOWED_DATATYPES + ) + def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.datatype = datatype + if not CaSignalBackend.datatype_allowed(self.datatype): + raise TypeError(f"Given datatype {self.datatype} unsupported in CA.") self.read_pv = read_pv self.write_pv = write_pv self.initial_values: Dict[str, AugmentedValue] = {} diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 28ec8fe6ab..c7d0b5240d 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -3,14 +3,17 @@ import inspect import logging import time +from abc import ABCMeta from dataclasses import dataclass from enum import Enum from math import isnan, nan -from typing import Any, Dict, List, Optional, Sequence, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin +import numpy as np from bluesky.protocols import DataKey, Dtype, Reading from p4p import Value from p4p.client.asyncio import Context, Subscription +from pydantic import BaseModel from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -253,6 +256,19 @@ def get_datakey(self, source: str, value) -> DataKey: return _data_key_from_value(source, value, dtype="object") +class PvaPydanticModelConverter(PvaConverter): + def __init__(self, datatype: BaseModel): + self.datatype = datatype + + def value(self, value: Value): + return self.datatype(**value.todict()) + + def write_value(self, value: Union[BaseModel, Dict[str, Any]]): + if isinstance(value, self.datatype): + return value.model_dump(mode="python") + return value + + class PvaDictConverter(PvaConverter): def reading(self, value): ts = time.time() @@ -348,6 +364,15 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") return PvaConverter() elif "NTTable" in typeid: + if ( + datatype + and inspect.isclass(datatype) + and + # Necessary to avoid weirdness in ABCMeta.__subclasscheck__ + isinstance(datatype, ABCMeta) + and issubclass(datatype, BaseModel) + ): + return PvaPydanticModelConverter(datatype) return PvaTableConverter() elif "structure" in typeid: return PvaDictConverter() @@ -358,8 +383,33 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve class PvaSignalBackend(SignalBackend[T]): _ctxt: Optional[Context] = None + _ALLOWED_DATATYPES = ( + bool, + int, + float, + str, + Sequence, + np.ndarray, + Enum, + RuntimeSubsetEnum, + BaseModel, + dict, + ) + + @classmethod + def datatype_allowed(cls, datatype: Optional[Type]) -> bool: + stripped_origin = get_origin(datatype) or datatype + if datatype is None: + return True + return inspect.isclass(stripped_origin) and issubclass( + stripped_origin, cls._ALLOWED_DATATYPES + ) + def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.datatype = datatype + if not PvaSignalBackend.datatype_allowed(self.datatype): + raise TypeError(f"Given datatype {self.datatype} unsupported in PVA.") + self.read_pv = read_pv self.write_pv = write_pv self.initial_values: Dict[str, Any] = {} diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index 9d1c1d429f..0dbe7222b0 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -15,10 +15,7 @@ DatasetTable, PandaHdf5DatasetType, SeqTable, - SeqTableRow, SeqTrigger, - seq_table_from_arrays, - seq_table_from_rows, ) from ._trigger import ( PcompInfo, @@ -45,10 +42,7 @@ "DatasetTable", "PandaHdf5DatasetType", "SeqTable", - "SeqTableRow", "SeqTrigger", - "seq_table_from_arrays", - "seq_table_from_rows", "PcompInfo", "SeqTableInfo", "StaticPcompTriggerLogic", diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index ec2c1a5b8b..ee6df7522f 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -1,11 +1,14 @@ -from dataclasses import dataclass +import inspect from enum import Enum -from typing import Optional, Sequence, Type, TypeVar +from typing import Annotated, Sequence import numpy as np import numpy.typing as npt -import pydantic_numpy.typing as pnd -from typing_extensions import NotRequired, TypedDict +from pydantic import Field, field_validator, model_validator +from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation +from typing_extensions import TypedDict + +from ophyd_async.core import Table class PandaHdf5DatasetType(str, Enum): @@ -34,137 +37,113 @@ class SeqTrigger(str, Enum): POSC_LT = "POSC<=POSITION" -@dataclass -class SeqTableRow: - repeats: int = 1 - trigger: SeqTrigger = SeqTrigger.IMMEDIATE - position: int = 0 - time1: int = 0 - outa1: bool = False - outb1: bool = False - outc1: bool = False - outd1: bool = False - oute1: bool = False - outf1: bool = False - time2: int = 0 - outa2: bool = False - outb2: bool = False - outc2: bool = False - outd2: bool = False - oute2: bool = False - outf2: bool = False - - -class SeqTable(TypedDict): - repeats: NotRequired[pnd.Np1DArrayUint16] - trigger: NotRequired[Sequence[SeqTrigger]] - position: NotRequired[pnd.Np1DArrayInt32] - time1: NotRequired[pnd.Np1DArrayUint32] - outa1: NotRequired[pnd.Np1DArrayBool] - outb1: NotRequired[pnd.Np1DArrayBool] - outc1: NotRequired[pnd.Np1DArrayBool] - outd1: NotRequired[pnd.Np1DArrayBool] - oute1: NotRequired[pnd.Np1DArrayBool] - outf1: NotRequired[pnd.Np1DArrayBool] - time2: NotRequired[pnd.Np1DArrayUint32] - outa2: NotRequired[pnd.Np1DArrayBool] - outb2: NotRequired[pnd.Np1DArrayBool] - outc2: NotRequired[pnd.Np1DArrayBool] - outd2: NotRequired[pnd.Np1DArrayBool] - oute2: NotRequired[pnd.Np1DArrayBool] - outf2: NotRequired[pnd.Np1DArrayBool] - - -def seq_table_from_rows(*rows: SeqTableRow): - """ - Constructs a sequence table from a series of rows. - """ - return seq_table_from_arrays( - repeats=np.array([row.repeats for row in rows], dtype=np.uint16), - trigger=[row.trigger for row in rows], - position=np.array([row.position for row in rows], dtype=np.int32), - time1=np.array([row.time1 for row in rows], dtype=np.uint32), - outa1=np.array([row.outa1 for row in rows], dtype=np.bool_), - outb1=np.array([row.outb1 for row in rows], dtype=np.bool_), - outc1=np.array([row.outc1 for row in rows], dtype=np.bool_), - outd1=np.array([row.outd1 for row in rows], dtype=np.bool_), - oute1=np.array([row.oute1 for row in rows], dtype=np.bool_), - outf1=np.array([row.outf1 for row in rows], dtype=np.bool_), - time2=np.array([row.time2 for row in rows], dtype=np.uint32), - outa2=np.array([row.outa2 for row in rows], dtype=np.bool_), - outb2=np.array([row.outb2 for row in rows], dtype=np.bool_), - outc2=np.array([row.outc2 for row in rows], dtype=np.bool_), - outd2=np.array([row.outd2 for row in rows], dtype=np.bool_), - oute2=np.array([row.oute2 for row in rows], dtype=np.bool_), - outf2=np.array([row.outf2 for row in rows], dtype=np.bool_), - ) - - -T = TypeVar("T", bound=np.generic) - - -def seq_table_from_arrays( - *, - repeats: Optional[npt.NDArray[np.uint16]] = None, - trigger: Optional[Sequence[SeqTrigger]] = None, - position: Optional[npt.NDArray[np.int32]] = None, - time1: Optional[npt.NDArray[np.uint32]] = None, - outa1: Optional[npt.NDArray[np.bool_]] = None, - outb1: Optional[npt.NDArray[np.bool_]] = None, - outc1: Optional[npt.NDArray[np.bool_]] = None, - outd1: Optional[npt.NDArray[np.bool_]] = None, - oute1: Optional[npt.NDArray[np.bool_]] = None, - outf1: Optional[npt.NDArray[np.bool_]] = None, - time2: npt.NDArray[np.uint32], - outa2: Optional[npt.NDArray[np.bool_]] = None, - outb2: Optional[npt.NDArray[np.bool_]] = None, - outc2: Optional[npt.NDArray[np.bool_]] = None, - outd2: Optional[npt.NDArray[np.bool_]] = None, - oute2: Optional[npt.NDArray[np.bool_]] = None, - outf2: Optional[npt.NDArray[np.bool_]] = None, -) -> SeqTable: - """ - Constructs a sequence table from a series of columns as arrays. - time2 is the only required argument and must not be None. - All other provided arguments must be of equal length to time2. - If any other argument is not given, or else given as None or empty, - an array of length len(time2) filled with the following is defaulted: - repeats: 1 - trigger: SeqTrigger.IMMEDIATE - all others: 0/False as appropriate - """ - assert time2 is not None, "time2 must be provided" - length = len(time2) - assert 0 < length < 4096, f"Length {length} not in range" - - def or_default( - value: Optional[npt.NDArray[T]], dtype: Type[T], default_value: int = 0 - ) -> npt.NDArray[T]: - if value is None or len(value) == 0: - return np.full(length, default_value, dtype=dtype) - return value - - table = SeqTable( - repeats=or_default(repeats, np.uint16, 1), - trigger=trigger or [SeqTrigger.IMMEDIATE] * length, - position=or_default(position, np.int32), - time1=or_default(time1, np.uint32), - outa1=or_default(outa1, np.bool_), - outb1=or_default(outb1, np.bool_), - outc1=or_default(outc1, np.bool_), - outd1=or_default(outd1, np.bool_), - oute1=or_default(oute1, np.bool_), - outf1=or_default(outf1, np.bool_), - time2=time2, - outa2=or_default(outa2, np.bool_), - outb2=or_default(outb2, np.bool_), - outc2=or_default(outc2, np.bool_), - outd2=or_default(outd2, np.bool_), - oute2=or_default(oute2, np.bool_), - outf2=or_default(outf2, np.bool_), - ) - for k, v in table.items(): - size = len(v) # type: ignore - if size != length: - raise ValueError(f"{k}: has length {size} not {length}") - return table +PydanticNp1DArrayInt32 = Annotated[ + np.ndarray[tuple[int], np.int32], + NpArrayPydanticAnnotation.factory( + data_type=np.int32, dimensions=1, strict_data_typing=False + ), + Field(default_factory=lambda: np.array([], np.int32)), +] +PydanticNp1DArrayBool = Annotated[ + np.ndarray[tuple[int], np.bool_], + NpArrayPydanticAnnotation.factory( + data_type=np.bool_, dimensions=1, strict_data_typing=False + ), + Field(default_factory=lambda: np.array([], dtype=np.bool_)), +] +TriggerStr = Annotated[ + np.ndarray[tuple[int], np.unicode_], + NpArrayPydanticAnnotation.factory( + data_type=np.unicode_, dimensions=1, strict_data_typing=False + ), + Field(default_factory=lambda: np.array([], dtype=np.dtype(" "SeqTable": + sig = inspect.signature(cls.row) + kwargs = {k: v for k, v in locals().items() if k in sig.parameters} + + if isinstance(kwargs["trigger"], SeqTrigger): + kwargs["trigger"] = kwargs["trigger"].value + elif isinstance(kwargs["trigger"], str): + SeqTrigger(kwargs["trigger"]) + + return Table.row(cls, **kwargs) + + @field_validator("trigger", mode="before") + @classmethod + def trigger_to_np_array(cls, trigger_column): + """ + The user can provide a list of SeqTrigger enum elements instead of a numpy str. + """ + if isinstance(trigger_column, Sequence) and all( + isinstance(trigger, SeqTrigger) for trigger in trigger_column + ): + trigger_column = np.array( + [trigger.value for trigger in trigger_column], dtype=np.dtype(" "SeqTable": + """ + Used to check max_length. Unfortunately trying the `max_length` arg in + the pydantic field doesn't work + """ + + first_length = len(next(iter(self))[1]) + assert 0 <= first_length < 4096, f"Length {first_length} not in range." + return self diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py index 087ec62dd1..daa686b477 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -15,8 +15,6 @@ PcompInfo, SeqTable, SeqTableInfo, - SeqTableRow, - seq_table_from_rows, ) @@ -74,24 +72,26 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( trigger_time = number_of_frames * (exposure + deadtime) pre_delay = max(period - 2 * shutter_time - trigger_time, 0) - table: SeqTable = seq_table_from_rows( + table = ( # Wait for pre-delay then open shutter - SeqTableRow( + SeqTable.row( time1=in_micros(pre_delay), time2=in_micros(shutter_time), outa2=True, - ), + ) + + # Keeping shutter open, do N triggers - SeqTableRow( + SeqTable.row( repeats=number_of_frames, time1=in_micros(exposure), outa1=True, outb1=True, time2=in_micros(deadtime), outa2=True, - ), + ) + + # Add the shutter close - SeqTableRow(time2=in_micros(shutter_time)), + SeqTable.row(time2=in_micros(shutter_time)) ) table_info = SeqTableInfo(sequence_table=table, repeats=repeats) diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index aa60be9802..b265b86137 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -8,6 +8,8 @@ import pytest import yaml from bluesky.run_engine import RunEngine +from pydantic import BaseModel, Field +from pydantic_numpy.typing import NpNDArrayFp16, NpNDArrayInt32 from ophyd_async.core import ( Device, @@ -54,6 +56,16 @@ class MyEnum(str, Enum): three = "three" +class SomePvaPydanticModel(BaseModel): + some_int_field: int = Field(default=1) + some_pydantic_numpy_field_float: NpNDArrayFp16 = Field( + default_factory=lambda: np.array([1, 2, 3]) + ) + some_pydantic_numpy_field_int: NpNDArrayInt32 = Field( + default_factory=lambda: np.array([1, 2, 3]) + ) + + class DummyDeviceGroupAllTypes(Device): def __init__(self, name: str): self.pv_int: SignalRW = epics_signal_rw(int, "PV1") @@ -73,6 +85,9 @@ def __init__(self, name: str): self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14") self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15") self.pv_array_str = epics_signal_rw(Sequence[str], "PV16") + self.pv_protocol_device_abstraction = epics_signal_rw( + SomePvaPydanticModel, "pva://PV17" + ) @pytest.fixture @@ -155,6 +170,7 @@ async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path): await device_all_types.pv_array_str.set( ["one", "two", "three"], ) + await device_all_types.pv_protocol_device_abstraction.set(SomePvaPydanticModel()) # Create save plan from utility functions def save_my_device(): diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 3b4c4934f4..ab5c02cffe 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -403,3 +403,28 @@ async def test_subscription_logs(caplog): assert "Making subscription" in caplog.text mock_signal_rw.clear_sub(cbs.append) assert "Closing subscription on source" in caplog.text + + +async def test_signal_unknown_datatype(): + class SomeClass: + def __init__(self): + self.some_attribute = "some_attribute" + + def some_function(self): + pass + + err_str = ( + "Given datatype .SomeClass'>" + " unsupported in %s." + ) + with pytest.raises(TypeError, match=err_str % ("PVA",)): + epics_signal_rw(SomeClass, "pva://mock_signal", name="mock_signal") + with pytest.raises(TypeError, match=err_str % ("CA",)): + epics_signal_rw(SomeClass, "ca://mock_signal", name="mock_signal") + + # Any dtype allowed in soft signal + signal = soft_signal_rw(SomeClass, SomeClass(), "soft_signal") + assert isinstance((await signal.get_value()), SomeClass) + await signal.set(1) + assert (await signal.get_value()) == 1 diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index 5e55507626..16bf23567e 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -94,7 +94,7 @@ async def test_soft_signal_backend_get_put_monitor( descriptor: Callable[[Any], dict], dtype_numpy: str, ): - backend = SoftSignalBackend(datatype) + backend = SoftSignalBackend(datatype=datatype) await backend.connect() q = MonitorQueue(backend) diff --git a/tests/core/test_subset_enum.py b/tests/core/test_subset_enum.py index 41af248aac..8c638d2770 100644 --- a/tests/core/test_subset_enum.py +++ b/tests/core/test_subset_enum.py @@ -7,8 +7,8 @@ from ophyd_async.epics.signal import epics_signal_rw # Allow these imports from private modules for tests -from ophyd_async.epics.signal._aioca import make_converter as aioca_make_converter -from ophyd_async.epics.signal._p4p import make_converter as p4p_make_converter +from ophyd_async.epics.signal._aioca import make_converter as ca_make_converter +from ophyd_async.epics.signal._p4p import make_converter as pva_make_converter async def test_runtime_enum_behaviour(): @@ -52,7 +52,7 @@ def __init__(self): epics_value = EpicsValue() rt_enum = SubsetEnum["A", "B"] - converter = aioca_make_converter( + converter = ca_make_converter( rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value} ) assert converter.choices == {"A": "A", "B": "B", "C": "C"} @@ -68,7 +68,7 @@ async def test_pva_runtime_enum_converter(): }, ) rt_enum = SubsetEnum["A", "B"] - converter = p4p_make_converter( + converter = pva_make_converter( rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value} ) assert {"A", "B"}.issubset(set(converter.choices)) diff --git a/tests/fastcs/panda/test_panda_connect.py b/tests/fastcs/panda/test_panda_connect.py index 2685f3c66c..b6dcbe2b00 100644 --- a/tests/fastcs/panda/test_panda_connect.py +++ b/tests/fastcs/panda/test_panda_connect.py @@ -115,6 +115,7 @@ async def test_panda_children_connected(mock_panda): oute2=np.array([1, 0, 1, 0]).astype(np.bool_), outf2=np.array([1, 0, 0, 0]).astype(np.bool_), ) + await mock_panda.pulse[1].delay.set(20.0) await mock_panda.seq[1].table.set(table) diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index f5d4e02600..d8b9a01269 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -1,22 +1,20 @@ -from unittest.mock import patch - -import pytest +import numpy as np +import yaml from bluesky import RunEngine -from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, save_device +from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, load_device, save_device from ophyd_async.epics.pvi import fill_pvi_entries from ophyd_async.epics.signal import epics_signal_rw from ophyd_async.fastcs.panda import ( CommonPandaBlocks, DataBlock, - PcompDirectionOptions, + SeqTable, TimeUnits, phase_sorter, ) -@pytest.fixture -async def mock_panda(): +async def get_mock_panda(): class Panda(CommonPandaBlocks): data: DataBlock @@ -33,56 +31,112 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async with DeviceCollector(mock=True): mock_panda = Panda("PANDA") mock_panda.phase_1_signal_units = epics_signal_rw(int, "") - assert mock_panda.name == "mock_panda" - yield mock_panda + return mock_panda + + +async def test_save_load_panda(tmp_path, RE: RunEngine): + mock_panda1 = await get_mock_panda() + await mock_panda1.seq[1].table.set(SeqTable.row(repeats=1)) + RE(save_device(mock_panda1, str(tmp_path / "panda.yaml"), sorter=phase_sorter)) -@patch("ophyd_async.core._device_save_loader.save_to_yaml") -async def test_save_panda(mock_save_to_yaml, mock_panda, RE: RunEngine): - RE(save_device(mock_panda, "path", sorter=phase_sorter)) - mock_save_to_yaml.assert_called_once() - assert mock_save_to_yaml.call_args[0] == ( - [ - { - "phase_1_signal_units": 0, - "seq.1.prescale_units": TimeUnits("min"), - "seq.2.prescale_units": TimeUnits("min"), - }, - { - "data.capture": False, - "data.create_directory": 0, - "data.flush_period": 0.0, - "data.hdf_directory": "", - "data.hdf_file_name": "", - "data.num_capture": 0, - "pcap.arm": False, - "pcomp.1.dir": PcompDirectionOptions.positive, - "pcomp.1.enable": "ZERO", - "pcomp.1.pulses": 0, - "pcomp.1.start": 0, - "pcomp.1.step": 0, - "pcomp.1.width": 0, - "pcomp.2.dir": PcompDirectionOptions.positive, - "pcomp.2.enable": "ZERO", - "pcomp.2.pulses": 0, - "pcomp.2.start": 0, - "pcomp.2.step": 0, - "pcomp.2.width": 0, - "pulse.1.delay": 0.0, - "pulse.1.width": 0.0, - "pulse.2.delay": 0.0, - "pulse.2.width": 0.0, - "seq.1.active": False, - "seq.1.table": {}, - "seq.1.repeats": 0, - "seq.1.prescale": 0.0, - "seq.1.enable": "ZERO", - "seq.2.table": {}, - "seq.2.active": False, - "seq.2.repeats": 0, - "seq.2.prescale": 0.0, - "seq.2.enable": "ZERO", - }, - ], - "path", + def check_equal_with_seq_tables(actual, expected): + assert actual.model_fields_set == expected.model_fields_set + for field_name, field_value1 in actual: + field_value2 = getattr(expected, field_name) + assert np.array_equal(field_value1, field_value2) + + mock_panda2 = await get_mock_panda() + check_equal_with_seq_tables( + (await mock_panda2.seq[1].table.get_value()), SeqTable() ) + RE(load_device(mock_panda2, str(tmp_path / "panda.yaml"))) + + check_equal_with_seq_tables( + await mock_panda2.seq[1].table.get_value(), + SeqTable.row(repeats=1), + ) + + # Load the YAML content as a string + with open(str(tmp_path / "panda.yaml"), "r") as file: + yaml_content = file.read() + + # Parse the YAML content + parsed_yaml = yaml.safe_load(yaml_content) + + assert parsed_yaml[0] == { + "phase_1_signal_units": 0, + "seq.1.prescale_units": TimeUnits("min"), + "seq.2.prescale_units": TimeUnits("min"), + } + assert parsed_yaml[1] == { + "data.capture": False, + "data.create_directory": 0, + "data.flush_period": 0.0, + "data.hdf_directory": "", + "data.hdf_file_name": "", + "data.num_capture": 0, + "pcap.arm": False, + "pcomp.1.dir": "Positive", + "pcomp.1.enable": "ZERO", + "pcomp.1.pulses": 0, + "pcomp.1.start": 0, + "pcomp.1.step": 0, + "pcomp.1.width": 0, + "pcomp.2.dir": "Positive", + "pcomp.2.enable": "ZERO", + "pcomp.2.pulses": 0, + "pcomp.2.start": 0, + "pcomp.2.step": 0, + "pcomp.2.width": 0, + "pulse.1.delay": 0.0, + "pulse.1.width": 0.0, + "pulse.2.delay": 0.0, + "pulse.2.width": 0.0, + "seq.1.active": False, + "seq.1.table": { + "outa1": [False], + "outa2": [False], + "outb1": [False], + "outb2": [False], + "outc1": [False], + "outc2": [False], + "outd1": [False], + "outd2": [False], + "oute1": [False], + "oute2": [False], + "outf1": [False], + "outf2": [False], + "position": [0], + "repeats": [1], + "time1": [0], + "time2": [0], + "trigger": ["Immediate"], + }, + "seq.1.repeats": 0, + "seq.1.prescale": 0.0, + "seq.1.enable": "ZERO", + "seq.2.table": { + "outa1": [], + "outa2": [], + "outb1": [], + "outb2": [], + "outc1": [], + "outc2": [], + "outd1": [], + "outd2": [], + "oute1": [], + "oute2": [], + "outf1": [], + "outf2": [], + "position": [], + "repeats": [], + "time1": [], + "time2": [], + "trigger": [], + }, + "seq.2.active": False, + "seq.2.repeats": 0, + "seq.2.prescale": 0.0, + "seq.2.enable": "ZERO", + } diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index ad92683bbd..c5f5abb846 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -1,31 +1,226 @@ +from functools import reduce + import numpy as np import pytest +from pydantic import ValidationError + +from ophyd_async.fastcs.panda import SeqTable +from ophyd_async.fastcs.panda._table import SeqTrigger + + +def test_seq_table_converts_lists(): + seq_table_dict_with_lists = {field_name: [] for field_name, _ in SeqTable()} + # Validation passes + seq_table = SeqTable(**seq_table_dict_with_lists) + assert isinstance(seq_table.trigger, np.ndarray) + assert seq_table.trigger.dtype == np.dtype("U32") + + +def test_seq_table_validation_errors(): + with pytest.raises(ValidationError, match="81 validation errors for SeqTable"): + SeqTable( + repeats=0, + trigger="Immediate", + position=0, + time1=0, + outa1=False, + outb1=False, + outc1=False, + outd1=False, + oute1=False, + outf1=False, + time2=0, + outa2=False, + outb2=False, + outc2=False, + outd2=False, + oute2=False, + outf2=False, + ) + + large_seq_table = SeqTable( + repeats=np.zeros(4095, dtype=np.int32), + trigger=np.array(["Immediate"] * 4095, dtype="U32"), + position=np.zeros(4095, dtype=np.int32), + time1=np.zeros(4095, dtype=np.int32), + outa1=np.zeros(4095, dtype=np.bool_), + outb1=np.zeros(4095, dtype=np.bool_), + outc1=np.zeros(4095, dtype=np.bool_), + outd1=np.zeros(4095, dtype=np.bool_), + oute1=np.zeros(4095, dtype=np.bool_), + outf1=np.zeros(4095, dtype=np.bool_), + time2=np.zeros(4095, dtype=np.int32), + outa2=np.zeros(4095, dtype=np.bool_), + outb2=np.zeros(4095, dtype=np.bool_), + outc2=np.zeros(4095, dtype=np.bool_), + outd2=np.zeros(4095, dtype=np.bool_), + oute2=np.zeros(4095, dtype=np.bool_), + outf2=np.zeros(4095, dtype=np.bool_), + ) + with pytest.raises( + ValidationError, + match=( + "1 validation error for SeqTable\n " + "Assertion failed, Length 4096 not in range." + ), + ): + large_seq_table + SeqTable.row() + with pytest.raises( + ValidationError, + match="12 validation errors for SeqTable", + ): + row_one = SeqTable.row() + wrong_types = { + field_name: field_value.astype(np.unicode_) + for field_name, field_value in row_one + } + SeqTable(**wrong_types) -from ophyd_async.fastcs.panda import seq_table_from_arrays +def test_seq_table_pva_conversion(): + pva_dict = { + "repeats": np.array([1, 2, 3, 4], dtype=np.int32), + "trigger": np.array( + ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype=np.dtype("U32") + ), + "position": np.array([1, 2, 3, 4], dtype=np.int32), + "time1": np.array([1, 0, 1, 0], dtype=np.int32), + "outa1": np.array([1, 0, 1, 0], dtype=np.bool_), + "outb1": np.array([1, 0, 1, 0], dtype=np.bool_), + "outc1": np.array([1, 0, 1, 0], dtype=np.bool_), + "outd1": np.array([1, 0, 1, 0], dtype=np.bool_), + "oute1": np.array([1, 0, 1, 0], dtype=np.bool_), + "outf1": np.array([1, 0, 1, 0], dtype=np.bool_), + "time2": np.array([1, 2, 3, 4], dtype=np.int32), + "outa2": np.array([1, 0, 1, 0], dtype=np.bool_), + "outb2": np.array([1, 0, 1, 0], dtype=np.bool_), + "outc2": np.array([1, 0, 1, 0], dtype=np.bool_), + "outd2": np.array([1, 0, 1, 0], dtype=np.bool_), + "oute2": np.array([1, 0, 1, 0], dtype=np.bool_), + "outf2": np.array([1, 0, 1, 0], dtype=np.bool_), + } + row_wise_dicts = [ + { + "repeats": 1, + "trigger": "Immediate", + "position": 1, + "time1": 1, + "outa1": 1, + "outb1": 1, + "outc1": 1, + "outd1": 1, + "oute1": 1, + "outf1": 1, + "time2": 1, + "outa2": 1, + "outb2": 1, + "outc2": 1, + "outd2": 1, + "oute2": 1, + "outf2": 1, + }, + { + "repeats": 2, + "trigger": "Immediate", + "position": 2, + "time1": 0, + "outa1": 0, + "outb1": 0, + "outc1": 0, + "outd1": 0, + "oute1": 0, + "outf1": 0, + "time2": 2, + "outa2": 0, + "outb2": 0, + "outc2": 0, + "outd2": 0, + "oute2": 0, + "outf2": 0, + }, + { + "repeats": 3, + "trigger": "BITC=0", + "position": 3, + "time1": 1, + "outa1": 1, + "outb1": 1, + "outc1": 1, + "outd1": 1, + "oute1": 1, + "outf1": 1, + "time2": 3, + "outa2": 1, + "outb2": 1, + "outc2": 1, + "outd2": 1, + "oute2": 1, + "outf2": 1, + }, + { + "repeats": 4, + "trigger": "Immediate", + "position": 4, + "time1": 0, + "outa1": 0, + "outb1": 0, + "outc1": 0, + "outd1": 0, + "oute1": 0, + "outf1": 0, + "time2": 4, + "outa2": 0, + "outb2": 0, + "outc2": 0, + "outd2": 0, + "oute2": 0, + "outf2": 0, + }, + ] -def test_from_arrays_inconsistent_lengths(): - length = 4 - time2 = np.zeros(length) - time1 = np.zeros(length + 1) - with pytest.raises(ValueError, match="time1: has length 5 not 4"): - seq_table_from_arrays(time2=time2, time1=time1) - time1 = np.zeros(length - 1) - with pytest.raises(ValueError, match="time1: has length 3 not 4"): - seq_table_from_arrays(time2=time2, time1=time1) + seq_table_from_pva_dict = SeqTable(**pva_dict) + for (_, column1), column2 in zip(seq_table_from_pva_dict, pva_dict.values()): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype + seq_table_from_rows = reduce( + lambda x, y: x + y, + [SeqTable.row(**row_kwargs) for row_kwargs in row_wise_dicts], + ) + for (_, column1), column2 in zip(seq_table_from_rows, pva_dict.values()): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype -def test_from_arrays_no_time(): - with pytest.raises(AssertionError, match="time2 must be provided"): - seq_table_from_arrays(time2=None) # type: ignore - with pytest.raises(TypeError, match="required keyword-only argument: 'time2'"): - seq_table_from_arrays() # type: ignore - time2 = np.zeros(0) - with pytest.raises(AssertionError, match="Length 0 not in range"): - seq_table_from_arrays(time2=time2) + # Idempotency + applied_twice_to_pva_dict = SeqTable(**pva_dict).model_dump(mode="python") + for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype -def test_from_arrays_too_long(): - time2 = np.zeros(4097) - with pytest.raises(AssertionError, match="Length 4097 not in range"): - seq_table_from_arrays(time2=time2) +def test_seq_table_takes_trigger_enum_row(): + for trigger in (SeqTrigger.BITA_0, "BITA=0"): + table = SeqTable.row(trigger=trigger) + assert table.trigger[0] == "BITA=0" + assert np.issubdtype(table.trigger.dtype, np.dtype(" Date: Fri, 13 Sep 2024 10:12:57 +0100 Subject: [PATCH 4/5] Use asyncio TimeoutError in wait_for_value (#573) --- src/ophyd_async/core/_signal.py | 2 +- tests/core/test_signal.py | 4 ++-- tests/epics/adpilatus/test_pilatus.py | 3 ++- tests/fastcs/panda/test_writer.py | 3 ++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ophyd_async/core/_signal.py b/src/ophyd_async/core/_signal.py index 503c1f9ca3..340298160b 100644 --- a/src/ophyd_async/core/_signal.py +++ b/src/ophyd_async/core/_signal.py @@ -505,7 +505,7 @@ async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]): try: await asyncio.wait_for(self._wait_for_value(signal), timeout) except asyncio.TimeoutError as e: - raise TimeoutError( + raise asyncio.TimeoutError( f"{signal.name} didn't match {self._matcher_name} in {timeout}s, " f"last value {self._last_value!r}" ) from e diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index ab5c02cffe..d498e67531 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -240,7 +240,7 @@ async def test_wait_for_value_with_value(): await signal.set("blah") with pytest.raises( - TimeoutError, + asyncio.TimeoutError, match="signal didn't match 'something' in 0.1s, last value 'blah'", ): await wait_for_value(signal, "something", timeout=0.1) @@ -263,7 +263,7 @@ def less_than_42(v): return v < 42 with pytest.raises( - TimeoutError, + asyncio.TimeoutError, match="signal didn't match less_than_42 in 0.1s, last value 45.8", ): await wait_for_value(signal, less_than_42, timeout=0.1) diff --git a/tests/epics/adpilatus/test_pilatus.py b/tests/epics/adpilatus/test_pilatus.py index 00b3c119a8..21ba42fb2c 100644 --- a/tests/epics/adpilatus/test_pilatus.py +++ b/tests/epics/adpilatus/test_pilatus.py @@ -1,3 +1,4 @@ +import asyncio from typing import Awaitable, Callable from unittest.mock import patch @@ -84,7 +85,7 @@ async def trigger_and_complete(): "ophyd_async.epics.adpilatus._pilatus_controller.DEFAULT_TIMEOUT", 0.1, ): - with pytest.raises(TimeoutError): + with pytest.raises(asyncio.TimeoutError): await _trigger( test_adpilatus, adpilatus.PilatusTriggerMode.internal, diff --git a/tests/fastcs/panda/test_writer.py b/tests/fastcs/panda/test_writer.py index 7bdca1a31b..dc26787cbf 100644 --- a/tests/fastcs/panda/test_writer.py +++ b/tests/fastcs/panda/test_writer.py @@ -1,3 +1,4 @@ +import asyncio import logging import os from pathlib import Path @@ -187,7 +188,7 @@ async def test_wait_for_index(mock_writer: PandaHDFWriter): set_mock_value(mock_writer.panda_data_block.num_captured, 3) await mock_writer.wait_for_index(3, timeout=1) set_mock_value(mock_writer.panda_data_block.num_captured, 2) - with pytest.raises(TimeoutError): + with pytest.raises(asyncio.TimeoutError): await mock_writer.wait_for_index(3, timeout=0.1) From 5b675810653bc5d8816811b0850dd5336d1851bb Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Fri, 13 Sep 2024 11:27:56 +0100 Subject: [PATCH 5/5] Interface change of StandardDetector and Standard Controller (#568) --- src/ophyd_async/core/_detector.py | 92 +++++++----- .../epics/adaravis/_aravis_controller.py | 27 ++-- .../epics/adkinetix/_kinetix_controller.py | 31 ++-- .../epics/adpilatus/_pilatus_controller.py | 32 +++-- .../epics/adsimdetector/_sim_controller.py | 32 +++-- .../epics/advimba/_vimba_controller.py | 35 +++-- .../epics/eiger/_eiger_controller.py | 30 ++-- src/ophyd_async/fastcs/panda/_control.py | 27 ++-- src/ophyd_async/plan_stubs/_fly.py | 2 + .../_pattern_detector_controller.py | 43 +++--- tests/core/test_flyer.py | 2 +- tests/epics/adaravis/test_aravis.py | 12 +- tests/epics/adcore/test_scans.py | 14 +- tests/epics/adkinetix/test_kinetix.py | 6 +- tests/epics/adpilatus/test_pilatus.py | 16 +-- .../adpilatus/test_pilatus_controller.py | 6 +- .../adsimdetector/test_adsim_controller.py | 8 +- tests/epics/adsimdetector/test_sim.py | 3 - tests/epics/advimba/test_vimba.py | 4 +- tests/epics/eiger/test_eiger_controller.py | 17 ++- tests/fastcs/panda/test_hdf_panda.py | 134 +++++++++++++++++- tests/fastcs/panda/test_panda_control.py | 16 ++- tests/plan_stubs/test_fly.py | 5 +- 23 files changed, 397 insertions(+), 197 deletions(-) diff --git a/src/ophyd_async/core/_detector.py b/src/ophyd_async/core/_detector.py index 7d22fa6ace..45d48fa43f 100644 --- a/src/ophyd_async/core/_detector.py +++ b/src/ophyd_async/core/_detector.py @@ -51,20 +51,24 @@ class TriggerInfo(BaseModel): """Minimal set of information required to setup triggering on a detector""" #: Number of triggers that will be sent, 0 means infinite - number: int = Field(gt=0) + number: int = Field(ge=0) #: Sort of triggers that will be sent - trigger: DetectorTrigger = Field() + trigger: DetectorTrigger = Field(default=DetectorTrigger.internal) #: What is the minimum deadtime between triggers - deadtime: float | None = Field(ge=0) + deadtime: float | None = Field(default=None, ge=0) #: What is the maximum high time of the triggers - livetime: float | None = Field(ge=0) + livetime: float | None = Field(default=None, ge=0) #: What is the maximum timeout on waiting for a frame - frame_timeout: float | None = Field(None, gt=0) + frame_timeout: float | None = Field(default=None, gt=0) #: How many triggers make up a single StreamDatum index, to allow multiple frames #: from a faster detector to be zipped with a single frame from a slow detector #: e.g. if num=10 and multiplier=5 then the detector will take 10 frames, #: but publish 2 indices, and describe() will show a shape of (5, h, w) multiplier: int = 1 + #: The number of times the detector can go through a complete cycle of kickoff and + #: complete without needing to re-arm. This is important for detectors where the + #: process of arming is expensive in terms of time + iteration: int = 1 class DetectorControl(ABC): @@ -78,27 +82,35 @@ def get_deadtime(self, exposure: float | None) -> float: """For a given exposure, how long should the time between exposures be""" @abstractmethod - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = None, - ) -> AsyncStatus: + async def prepare(self, trigger_info: TriggerInfo): """ - Arm detector, do all necessary steps to prepare detector for triggers. + Do all necessary steps to prepare the detector for triggers. Args: - num: Expected number of frames - trigger: Type of trigger for which to prepare the detector. Defaults to - DetectorTrigger.internal. - exposure: Exposure time with which to set up the detector. Defaults to None - if not applicable or the detector is expected to use its previously-set - exposure time. + trigger_info: This is a Pydantic model which contains + number Expected number of frames. + trigger Type of trigger for which to prepare the detector. Defaults + to DetectorTrigger.internal. + livetime Livetime / Exposure time with which to set up the detector. + Defaults to None + if not applicable or the detector is expected to use its previously-set + exposure time. + deadtime Defaults to None. This is the minimum deadtime between + triggers. + multiplier The number of triggers grouped into a single StreamDatum + index. + """ - Returns: - AsyncStatus: Status representing the arm operation. This function returning - represents the start of the arm. The returned status completing means - the detector is now armed. + @abstractmethod + async def arm(self) -> None: + """ + Arm the detector + """ + + @abstractmethod + async def wait_for_idle(self): + """ + This will wait on the internal _arm_status and wait for it to get disarmed/idle """ @abstractmethod @@ -186,7 +198,7 @@ def __init__( self._watchers: List[Callable] = [] self._fly_status: Optional[WatchableAsyncStatus] = None self._fly_start: float - + self._iterations_completed: int = 0 self._intial_frame: int self._last_frame: int super().__init__(name) @@ -248,15 +260,15 @@ async def trigger(self) -> None: trigger=DetectorTrigger.internal, deadtime=None, livetime=None, + frame_timeout=None, ) ) + assert self._trigger_info + assert self._trigger_info.trigger is DetectorTrigger.internal # Arm the detector and wait for it to finish. indices_written = await self.writer.get_indices_written() - written_status = await self.controller.arm( - num=self._trigger_info.number, - trigger=self._trigger_info.trigger, - ) - await written_status + await self.controller.arm() + await self.controller.wait_for_idle() end_observation = indices_written + 1 async for index in self.writer.observe_indices_written( @@ -283,35 +295,35 @@ async def prepare(self, value: TriggerInfo) -> None: Args: value: TriggerInfo describing how to trigger the detector """ - self._trigger_info = value if value.trigger != DetectorTrigger.internal: assert ( value.deadtime ), "Deadtime must be supplied when in externally triggered mode" if value.deadtime: - required = self.controller.get_deadtime(self._trigger_info.livetime) + required = self.controller.get_deadtime(value.livetime) assert required <= value.deadtime, ( f"Detector {self.controller} needs at least {required}s deadtime, " f"but trigger logic provides only {value.deadtime}s" ) + self._trigger_info = value self._initial_frame = await self.writer.get_indices_written() self._last_frame = self._initial_frame + self._trigger_info.number - self._arm_status = await self.controller.arm( - num=self._trigger_info.number, - trigger=self._trigger_info.trigger, - exposure=self._trigger_info.livetime, + self._describe, _ = await asyncio.gather( + self.writer.open(value.multiplier), self.controller.prepare(value) ) - self._fly_start = time.monotonic() - self._describe = await self.writer.open(value.multiplier) + if value.trigger != DetectorTrigger.internal: + await self.controller.arm() + self._fly_start = time.monotonic() @AsyncStatus.wrap async def kickoff(self): - if not self._arm_status: - raise Exception("Detector not armed!") + assert self._trigger_info, "Prepare must be called before kickoff!" + if self._iterations_completed >= self._trigger_info.iteration: + raise Exception(f"Kickoff called more than {self._trigger_info.iteration}") + self._iterations_completed += 1 @WatchableAsyncStatus.wrap async def complete(self): - assert self._arm_status, "Prepare not run" assert self._trigger_info async for index in self.writer.observe_indices_written( self._trigger_info.frame_timeout @@ -332,6 +344,8 @@ async def complete(self): ) if index >= self._trigger_info.number: break + if self._iterations_completed == self._trigger_info.iteration: + await self.controller.wait_for_idle() async def describe_collect(self) -> Dict[str, DataKey]: return self._describe diff --git a/src/ophyd_async/epics/adaravis/_aravis_controller.py b/src/ophyd_async/epics/adaravis/_aravis_controller.py index 6349d111b1..894a46c008 100644 --- a/src/ophyd_async/epics/adaravis/_aravis_controller.py +++ b/src/ophyd_async/epics/adaravis/_aravis_controller.py @@ -1,12 +1,13 @@ import asyncio -from typing import Literal, Optional, Tuple +from typing import Literal, Tuple from ophyd_async.core import ( - AsyncStatus, DetectorControl, DetectorTrigger, + TriggerInfo, set_and_wait_for_value, ) +from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore from ._aravis_io import AravisDriverIO, AravisTriggerMode, AravisTriggerSource @@ -23,24 +24,20 @@ class AravisController(DetectorControl): def __init__(self, driver: AravisDriverIO, gpio_number: GPIO_NUMBER) -> None: self._drv = driver self.gpio_number = gpio_number + self._arm_status: AsyncStatus | None = None def get_deadtime(self, exposure: float) -> float: return _HIGHEST_POSSIBLE_DEADTIME - async def arm( - self, - num: int = 0, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = None, - ) -> AsyncStatus: - if num == 0: + async def prepare(self, trigger_info: TriggerInfo): + if (num := trigger_info.number) == 0: image_mode = adcore.ImageMode.continuous else: image_mode = adcore.ImageMode.multiple - if exposure is not None: + if (exposure := trigger_info.livetime) is not None: await self._drv.acquire_time.set(exposure) - trigger_mode, trigger_source = self._get_trigger_info(trigger) + trigger_mode, trigger_source = self._get_trigger_info(trigger_info.trigger) # trigger mode must be set first and on it's own! await self._drv.trigger_mode.set(trigger_mode) @@ -50,8 +47,12 @@ async def arm( self._drv.image_mode.set(image_mode), ) - status = await set_and_wait_for_value(self._drv.acquire, True) - return status + async def arm(self): + self._arm_status = await set_and_wait_for_value(self._drv.acquire, True) + + async def wait_for_idle(self): + if self._arm_status: + await self._arm_status def _get_trigger_info( self, trigger: DetectorTrigger diff --git a/src/ophyd_async/epics/adkinetix/_kinetix_controller.py b/src/ophyd_async/epics/adkinetix/_kinetix_controller.py index 9469fda68a..70d32e6a78 100644 --- a/src/ophyd_async/epics/adkinetix/_kinetix_controller.py +++ b/src/ophyd_async/epics/adkinetix/_kinetix_controller.py @@ -1,7 +1,8 @@ import asyncio -from typing import Optional -from ophyd_async.core import AsyncStatus, DetectorControl, DetectorTrigger +from ophyd_async.core import DetectorControl, DetectorTrigger +from ophyd_async.core._detector import TriggerInfo +from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore from ._kinetix_io import KinetixDriverIO, KinetixTriggerMode @@ -20,27 +21,31 @@ def __init__( driver: KinetixDriverIO, ) -> None: self._drv = driver + self._arm_status: AsyncStatus | None = None def get_deadtime(self, exposure: float) -> float: return 0.001 - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = None, - ) -> AsyncStatus: + async def prepare(self, trigger_info: TriggerInfo): await asyncio.gather( - self._drv.trigger_mode.set(KINETIX_TRIGGER_MODE_MAP[trigger]), - self._drv.num_images.set(num), + self._drv.trigger_mode.set(KINETIX_TRIGGER_MODE_MAP[trigger_info.trigger]), + self._drv.num_images.set(trigger_info.number), self._drv.image_mode.set(adcore.ImageMode.multiple), ) - if exposure is not None and trigger not in [ + if trigger_info.livetime is not None and trigger_info.trigger not in [ DetectorTrigger.variable_gate, DetectorTrigger.constant_gate, ]: - await self._drv.acquire_time.set(exposure) - return await adcore.start_acquiring_driver_and_ensure_status(self._drv) + await self._drv.acquire_time.set(trigger_info.livetime) + + async def arm(self): + self._arm_status = await adcore.start_acquiring_driver_and_ensure_status( + self._drv + ) + + async def wait_for_idle(self): + if self._arm_status: + await self._arm_status async def disarm(self): await adcore.stop_busy_record(self._drv.acquire, False, timeout=1) diff --git a/src/ophyd_async/epics/adpilatus/_pilatus_controller.py b/src/ophyd_async/epics/adpilatus/_pilatus_controller.py index dd48eaf50c..54e0d41d5d 100644 --- a/src/ophyd_async/epics/adpilatus/_pilatus_controller.py +++ b/src/ophyd_async/epics/adpilatus/_pilatus_controller.py @@ -1,13 +1,13 @@ import asyncio -from typing import Optional from ophyd_async.core import ( DEFAULT_TIMEOUT, - AsyncStatus, DetectorControl, DetectorTrigger, wait_for_value, ) +from ophyd_async.core._detector import TriggerInfo +from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore from ._pilatus_io import PilatusDriverIO, PilatusTriggerMode @@ -27,29 +27,29 @@ def __init__( ) -> None: self._drv = driver self._readout_time = readout_time + self._arm_status: AsyncStatus | None = None def get_deadtime(self, exposure: float) -> float: return self._readout_time - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = None, - ) -> AsyncStatus: - if exposure is not None: + async def prepare(self, trigger_info: TriggerInfo): + if trigger_info.livetime is not None: await adcore.set_exposure_time_and_acquire_period_if_supplied( - self, self._drv, exposure + self, self._drv, trigger_info.livetime ) await asyncio.gather( - self._drv.trigger_mode.set(self._get_trigger_mode(trigger)), - self._drv.num_images.set(999_999 if num == 0 else num), + self._drv.trigger_mode.set(self._get_trigger_mode(trigger_info.trigger)), + self._drv.num_images.set( + 999_999 if trigger_info.number == 0 else trigger_info.number + ), self._drv.image_mode.set(adcore.ImageMode.multiple), ) + async def arm(self): # Standard arm the detector and wait for the acquire PV to be True - idle_status = await adcore.start_acquiring_driver_and_ensure_status(self._drv) - + self._arm_status = await adcore.start_acquiring_driver_and_ensure_status( + self._drv + ) # The pilatus has an additional PV that goes True when the camserver # is actually ready. Should wait for that too or we risk dropping # a frame @@ -59,7 +59,9 @@ async def arm( timeout=DEFAULT_TIMEOUT, ) - return idle_status + async def wait_for_idle(self): + if self._arm_status: + await self._arm_status @classmethod def _get_trigger_mode(cls, trigger: DetectorTrigger) -> PilatusTriggerMode: diff --git a/src/ophyd_async/epics/adsimdetector/_sim_controller.py b/src/ophyd_async/epics/adsimdetector/_sim_controller.py index 789f89701c..6561ee24f1 100644 --- a/src/ophyd_async/epics/adsimdetector/_sim_controller.py +++ b/src/ophyd_async/epics/adsimdetector/_sim_controller.py @@ -1,12 +1,13 @@ import asyncio -from typing import Optional, Set +from typing import Set from ophyd_async.core import ( DEFAULT_TIMEOUT, - AsyncStatus, DetectorControl, DetectorTrigger, ) +from ophyd_async.core._detector import TriggerInfo +from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore @@ -18,28 +19,33 @@ def __init__( ) -> None: self.driver = driver self.good_states = good_states + self.frame_timeout: float + self._arm_status: AsyncStatus | None = None def get_deadtime(self, exposure: float) -> float: return 0.002 - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = None, - ) -> AsyncStatus: + async def prepare(self, trigger_info: TriggerInfo): assert ( - trigger == DetectorTrigger.internal + trigger_info.trigger == DetectorTrigger.internal ), "fly scanning (i.e. external triggering) is not supported for this device" - frame_timeout = DEFAULT_TIMEOUT + await self.driver.acquire_time.get_value() + self.frame_timeout = ( + DEFAULT_TIMEOUT + await self.driver.acquire_time.get_value() + ) await asyncio.gather( - self.driver.num_images.set(num), + self.driver.num_images.set(trigger_info.number), self.driver.image_mode.set(adcore.ImageMode.multiple), ) - return await adcore.start_acquiring_driver_and_ensure_status( - self.driver, good_states=self.good_states, timeout=frame_timeout + + async def arm(self): + self._arm_status = await adcore.start_acquiring_driver_and_ensure_status( + self.driver, good_states=self.good_states, timeout=self.frame_timeout ) + async def wait_for_idle(self): + if self._arm_status: + await self._arm_status + async def disarm(self): # We can't use caput callback as we already used it in arm() and we can't have # 2 or they will deadlock diff --git a/src/ophyd_async/epics/advimba/_vimba_controller.py b/src/ophyd_async/epics/advimba/_vimba_controller.py index fa0232dd2a..9b87d37872 100644 --- a/src/ophyd_async/epics/advimba/_vimba_controller.py +++ b/src/ophyd_async/epics/advimba/_vimba_controller.py @@ -1,7 +1,8 @@ import asyncio -from typing import Optional -from ophyd_async.core import AsyncStatus, DetectorControl, DetectorTrigger +from ophyd_async.core import DetectorControl, DetectorTrigger +from ophyd_async.core._detector import TriggerInfo +from ophyd_async.core._status import AsyncStatus from ophyd_async.epics import adcore from ._vimba_io import VimbaDriverIO, VimbaExposeOutMode, VimbaOnOff, VimbaTriggerSource @@ -27,32 +28,36 @@ def __init__( driver: VimbaDriverIO, ) -> None: self._drv = driver + self._arm_status: AsyncStatus | None = None def get_deadtime(self, exposure: float) -> float: return 0.001 - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = None, - ) -> AsyncStatus: + async def prepare(self, trigger_info: TriggerInfo): await asyncio.gather( - self._drv.trigger_mode.set(TRIGGER_MODE[trigger]), - self._drv.exposure_mode.set(EXPOSE_OUT_MODE[trigger]), - self._drv.num_images.set(num), + self._drv.trigger_mode.set(TRIGGER_MODE[trigger_info.trigger]), + self._drv.exposure_mode.set(EXPOSE_OUT_MODE[trigger_info.trigger]), + self._drv.num_images.set(trigger_info.number), self._drv.image_mode.set(adcore.ImageMode.multiple), ) - if exposure is not None and trigger not in [ + if trigger_info.livetime is not None and trigger_info.trigger not in [ DetectorTrigger.variable_gate, DetectorTrigger.constant_gate, ]: - await self._drv.acquire_time.set(exposure) - if trigger != DetectorTrigger.internal: + await self._drv.acquire_time.set(trigger_info.livetime) + if trigger_info.trigger != DetectorTrigger.internal: self._drv.trigger_source.set(VimbaTriggerSource.line1) else: self._drv.trigger_source.set(VimbaTriggerSource.freerun) - return await adcore.start_acquiring_driver_and_ensure_status(self._drv) + + async def arm(self): + self._arm_status = await adcore.start_acquiring_driver_and_ensure_status( + self._drv + ) + + async def wait_for_idle(self): + if self._arm_status: + await self._arm_status async def disarm(self): await adcore.stop_busy_record(self._drv.acquire, False, timeout=1) diff --git a/src/ophyd_async/epics/eiger/_eiger_controller.py b/src/ophyd_async/epics/eiger/_eiger_controller.py index fa37bbebaf..c7542bc741 100644 --- a/src/ophyd_async/epics/eiger/_eiger_controller.py +++ b/src/ophyd_async/epics/eiger/_eiger_controller.py @@ -1,13 +1,12 @@ import asyncio -from typing import Optional from ophyd_async.core import ( DEFAULT_TIMEOUT, - AsyncStatus, DetectorControl, DetectorTrigger, set_and_wait_for_other_value, ) +from ophyd_async.core._detector import TriggerInfo from ._eiger_io import EigerDriverIO, EigerTriggerMode @@ -37,30 +36,31 @@ async def set_energy(self, energy: float, tolerance: float = 0.1): if abs(current_energy - energy) > tolerance: await self._drv.photon_energy.set(energy) - @AsyncStatus.wrap - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = None, - ): + async def prepare(self, trigger_info: TriggerInfo): coros = [ - self._drv.trigger_mode.set(EIGER_TRIGGER_MODE_MAP[trigger].value), - self._drv.num_images.set(num), + self._drv.trigger_mode.set( + EIGER_TRIGGER_MODE_MAP[trigger_info.trigger].value + ), + self._drv.num_images.set(trigger_info.number), ] - if exposure is not None: + if trigger_info.livetime is not None: coros.extend( [ - self._drv.acquire_time.set(exposure), - self._drv.acquire_period.set(exposure), + self._drv.acquire_time.set(trigger_info.livetime), + self._drv.acquire_period.set(trigger_info.livetime), ] ) await asyncio.gather(*coros) + async def arm(self): # TODO: Detector state should be an enum see https://github.com/DiamondLightSource/eiger-fastcs/issues/43 - await set_and_wait_for_other_value( + self._arm_status = set_and_wait_for_other_value( self._drv.arm, 1, self._drv.state, "ready", timeout=DEFAULT_TIMEOUT ) + async def wait_for_idle(self): + if self._arm_status: + await self._arm_status + async def disarm(self): await self._drv.disarm.set(1) diff --git a/src/ophyd_async/fastcs/panda/_control.py b/src/ophyd_async/fastcs/panda/_control.py index 08c17df1ef..aeb8e750cd 100644 --- a/src/ophyd_async/fastcs/panda/_control.py +++ b/src/ophyd_async/fastcs/panda/_control.py @@ -1,12 +1,12 @@ import asyncio -from typing import Optional from ophyd_async.core import ( - AsyncStatus, DetectorControl, DetectorTrigger, wait_for_value, ) +from ophyd_async.core._detector import TriggerInfo +from ophyd_async.core._status import AsyncStatus from ._block import PcapBlock @@ -14,25 +14,24 @@ class PandaPcapController(DetectorControl): def __init__(self, pcap: PcapBlock) -> None: self.pcap = pcap + self._arm_status: AsyncStatus | None = None def get_deadtime(self, exposure: float) -> float: return 0.000000008 - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.constant_gate, - exposure: Optional[float] = None, - ) -> AsyncStatus: - assert trigger in ( + async def prepare(self, trigger_info: TriggerInfo): + assert trigger_info.trigger in ( DetectorTrigger.constant_gate, - trigger == DetectorTrigger.variable_gate, + DetectorTrigger.variable_gate, ), "Only constant_gate and variable_gate triggering is supported on the PandA" - await asyncio.gather(self.pcap.arm.set(True)) + + async def arm(self): + self._arm_status = self.pcap.arm.set(True) await wait_for_value(self.pcap.active, True, timeout=1) - return AsyncStatus(wait_for_value(self.pcap.active, False, timeout=None)) - async def disarm(self) -> AsyncStatus: + async def wait_for_idle(self): + pass + + async def disarm(self): await asyncio.gather(self.pcap.arm.set(False)) await wait_for_value(self.pcap.active, False, timeout=1) - return AsyncStatus(wait_for_value(self.pcap.active, False, timeout=None)) diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py index daa686b477..2cf6f5499e 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -46,6 +46,7 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( repeats: int = 1, period: float = 0.0, frame_timeout: Optional[float] = None, + iteration: int = 1, ): """Prepare a hardware triggered flyable and one or more detectors. @@ -68,6 +69,7 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( deadtime=deadtime, livetime=exposure, frame_timeout=frame_timeout, + iteration=iteration, ) trigger_time = number_of_frames * (exposure + deadtime) pre_delay = max(period - 2 * shutter_time - trigger_time, 0) diff --git a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py index 064bea873f..039ddb066c 100644 --- a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py +++ b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py @@ -1,7 +1,10 @@ import asyncio from typing import Optional -from ophyd_async.core import AsyncStatus, DetectorControl, DetectorTrigger, PathProvider +from pydantic import Field + +from ophyd_async.core import DetectorControl, PathProvider +from ophyd_async.core._detector import TriggerInfo from ._pattern_generator import PatternGenerator @@ -11,30 +14,36 @@ def __init__( self, pattern_generator: PatternGenerator, path_provider: PathProvider, - exposure: float = 0.1, + exposure: float = Field(default=0.1), ) -> None: self.pattern_generator: PatternGenerator = pattern_generator - if exposure is None: - exposure = 0.1 self.pattern_generator.set_exposure(exposure) self.path_provider: PathProvider = path_provider self.task: Optional[asyncio.Task] = None super().__init__() - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = 0.01, - ) -> AsyncStatus: - if exposure is None: - exposure = 0.1 - period: float = exposure + self.get_deadtime(exposure) - task = asyncio.create_task( - self._coroutine_for_image_writing(exposure, period, num) + async def prepare( + self, trigger_info: TriggerInfo = TriggerInfo(number=1, livetime=0.01) + ): + self._trigger_info = trigger_info + if self._trigger_info.livetime is None: + self._trigger_info.livetime = 0.01 + self.period: float = self._trigger_info.livetime + self.get_deadtime( + trigger_info.livetime ) - self.task = task - return AsyncStatus(task) + + async def arm(self): + assert self._trigger_info.livetime + assert self.period + self.task = asyncio.create_task( + self._coroutine_for_image_writing( + self._trigger_info.livetime, self.period, self._trigger_info.number + ) + ) + + async def wait_for_idle(self): + if self.task: + await self.task async def disarm(self): if self.task: diff --git a/tests/core/test_flyer.py b/tests/core/test_flyer.py index 9d968526d9..6d9c9142aa 100644 --- a/tests/core/test_flyer.py +++ b/tests/core/test_flyer.py @@ -117,7 +117,7 @@ async def detectors(RE: RunEngine) -> tuple[StandardDetector, StandardDetector]: await writers[0].dummy_signal.connect(mock=True) await writers[1].dummy_signal.connect(mock=True) - async def dummy_arm_1(self=None, trigger=None, num=0, exposure=None): + def dummy_arm_1(self=None, trigger=None, num=0, exposure=None): return writers[0].dummy_signal.set(1) async def dummy_arm_2(self=None, trigger=None, num=0, exposure=None): diff --git a/tests/epics/adaravis/test_aravis.py b/tests/epics/adaravis/test_aravis.py index efefae2eb2..3c34fa49eb 100644 --- a/tests/epics/adaravis/test_aravis.py +++ b/tests/epics/adaravis/test_aravis.py @@ -36,7 +36,15 @@ async def test_trigger_source_set_to_gpio_line(test_adaravis: adaravis.AravisDet set_mock_value(test_adaravis.drv.trigger_source, "Freerun") async def trigger_and_complete(): - await test_adaravis.controller.arm(num=1, trigger=DetectorTrigger.edge_trigger) + await test_adaravis.controller.prepare( + TriggerInfo( + number=1, + trigger=DetectorTrigger.edge_trigger, + livetime=None, + deadtime=None, + frame_timeout=None, + ) + ) # Prevent timeouts set_mock_value(test_adaravis.drv.acquire, True) @@ -158,7 +166,7 @@ async def test_unsupported_trigger_excepts(test_adaravis: adaravis.AravisDetecto ): await test_adaravis.prepare( TriggerInfo( - number=1, + number=0, trigger=DetectorTrigger.variable_gate, deadtime=1, livetime=1, diff --git a/tests/epics/adcore/test_scans.py b/tests/epics/adcore/test_scans.py index 3936885083..6889fbbe3b 100644 --- a/tests/epics/adcore/test_scans.py +++ b/tests/epics/adcore/test_scans.py @@ -37,14 +37,14 @@ async def stop(self): ... class DummyController(DetectorControl): def __init__(self) -> None: ... + async def prepare(self, trigger_info: TriggerInfo): + return AsyncStatus(asyncio.sleep(0.01)) - async def arm( - self, - num: int, - trigger: DetectorTrigger = DetectorTrigger.internal, - exposure: Optional[float] = None, - ) -> AsyncStatus: - return AsyncStatus(asyncio.sleep(0.1)) + async def arm(self): + self._arm_status = AsyncStatus(asyncio.sleep(0.01)) + + async def wait_for_idle(self): + await self._arm_status async def disarm(self): ... diff --git a/tests/epics/adkinetix/test_kinetix.py b/tests/epics/adkinetix/test_kinetix.py index 3a53091f40..a17be5e5b3 100644 --- a/tests/epics/adkinetix/test_kinetix.py +++ b/tests/epics/adkinetix/test_kinetix.py @@ -33,7 +33,11 @@ async def test_trigger_modes(test_adkinetix: adkinetix.KinetixDetector): set_mock_value(test_adkinetix.drv.trigger_mode, "Internal") async def setup_trigger_mode(trig_mode: DetectorTrigger): - await test_adkinetix.controller.arm(num=1, trigger=trig_mode) + await test_adkinetix.controller.prepare( + TriggerInfo(number=1, trigger=trig_mode) + ) + await test_adkinetix.controller.arm() + await test_adkinetix.controller.wait_for_idle() # Prevent timeouts set_mock_value(test_adkinetix.drv.acquire, True) diff --git a/tests/epics/adpilatus/test_pilatus.py b/tests/epics/adpilatus/test_pilatus.py index 21ba42fb2c..72145ff0cc 100644 --- a/tests/epics/adpilatus/test_pilatus.py +++ b/tests/epics/adpilatus/test_pilatus.py @@ -62,11 +62,11 @@ async def test_trigger_mode_set( ): async def trigger_and_complete(): set_mock_value(test_adpilatus.drv.armed, True) - status = await test_adpilatus.controller.arm( - num=1, - trigger=detector_trigger, + await test_adpilatus.controller.prepare( + TriggerInfo(number=1, trigger=detector_trigger) ) - await status + await test_adpilatus.controller.arm() + await test_adpilatus.controller.wait_for_idle() await _trigger(test_adpilatus, expected_trigger_mode, trigger_and_complete) @@ -75,11 +75,11 @@ async def test_trigger_mode_set_without_armed_pv( test_adpilatus: adpilatus.PilatusDetector, ): async def trigger_and_complete(): - status = await test_adpilatus.controller.arm( - num=1, - trigger=DetectorTrigger.internal, + await test_adpilatus.controller.prepare( + TriggerInfo(number=1, trigger=DetectorTrigger.internal) ) - await status + await test_adpilatus.controller.arm() + await test_adpilatus.controller.wait_for_idle() with patch( "ophyd_async.epics.adpilatus._pilatus_controller.DEFAULT_TIMEOUT", diff --git a/tests/epics/adpilatus/test_pilatus_controller.py b/tests/epics/adpilatus/test_pilatus_controller.py index 00194fdb94..bb825b8744 100644 --- a/tests/epics/adpilatus/test_pilatus_controller.py +++ b/tests/epics/adpilatus/test_pilatus_controller.py @@ -1,6 +1,7 @@ import pytest from ophyd_async.core import DetectorTrigger, DeviceCollector, set_mock_value +from ophyd_async.core._detector import TriggerInfo from ophyd_async.epics import adcore, adpilatus @@ -28,8 +29,9 @@ async def test_pilatus_controller( pilatus_driver: adpilatus.PilatusDriverIO, ): set_mock_value(pilatus_driver.armed, True) - status = await pilatus.arm(num=1, trigger=DetectorTrigger.constant_gate) - await status + await pilatus.prepare(TriggerInfo(number=1, trigger=DetectorTrigger.constant_gate)) + await pilatus.arm() + await pilatus.wait_for_idle() assert await pilatus_driver.num_images.get_value() == 1 assert await pilatus_driver.image_mode.get_value() == adcore.ImageMode.multiple diff --git a/tests/epics/adsimdetector/test_adsim_controller.py b/tests/epics/adsimdetector/test_adsim_controller.py index 8a7c33516b..64d53ce590 100644 --- a/tests/epics/adsimdetector/test_adsim_controller.py +++ b/tests/epics/adsimdetector/test_adsim_controller.py @@ -1,6 +1,9 @@ +from unittest.mock import patch + import pytest from ophyd_async.core import DeviceCollector +from ophyd_async.core._detector import DetectorTrigger, TriggerInfo from ophyd_async.epics import adcore, adsimdetector @@ -14,7 +17,10 @@ async def ad(RE) -> adsimdetector.SimController: async def test_ad_controller(RE, ad: adsimdetector.SimController): - await ad.arm(num=1) + with patch("ophyd_async.core._signal.wait_for_value", return_value=None): + await ad.prepare(TriggerInfo(number=1, trigger=DetectorTrigger.internal)) + await ad.arm() + await ad.wait_for_idle() driver = ad.driver assert await driver.num_images.get_value() == 1 diff --git a/tests/epics/adsimdetector/test_sim.py b/tests/epics/adsimdetector/test_sim.py index 149835aa49..891d89c33c 100644 --- a/tests/epics/adsimdetector/test_sim.py +++ b/tests/epics/adsimdetector/test_sim.py @@ -114,9 +114,6 @@ async def test_two_detectors_fly_different_rate( trigger_info = TriggerInfo( number=15, trigger=DetectorTrigger.internal, - deadtime=None, - livetime=None, - frame_timeout=None, ) docs = defaultdict(list) diff --git a/tests/epics/advimba/test_vimba.py b/tests/epics/advimba/test_vimba.py index 0bc32b887b..ec93cc07d3 100644 --- a/tests/epics/advimba/test_vimba.py +++ b/tests/epics/advimba/test_vimba.py @@ -35,7 +35,9 @@ async def test_arming_trig_modes(test_advimba: advimba.VimbaDetector): set_mock_value(test_advimba.drv.exposure_mode, "Timed") async def setup_trigger_mode(trig_mode: DetectorTrigger): - await test_advimba.controller.arm(num=1, trigger=trig_mode) + await test_advimba.controller.prepare(TriggerInfo(number=1, trigger=trig_mode)) + await test_advimba.controller.arm() + await test_advimba.controller.wait_for_idle() # Prevent timeouts set_mock_value(test_advimba.drv.acquire, True) diff --git a/tests/epics/eiger/test_eiger_controller.py b/tests/epics/eiger/test_eiger_controller.py index 70f3078163..39204142e5 100644 --- a/tests/epics/eiger/test_eiger_controller.py +++ b/tests/epics/eiger/test_eiger_controller.py @@ -8,6 +8,7 @@ get_mock_put, set_mock_value, ) +from ophyd_async.core._detector import TriggerInfo from ophyd_async.epics.eiger._eiger_controller import EigerController from ophyd_async.epics.eiger._eiger_io import EigerDriverIO @@ -43,7 +44,9 @@ async def test_when_arm_with_exposure_then_time_and_period_set( ): driver, controller = eiger_driver_and_controller test_exposure = 0.002 - await controller.arm(10, exposure=test_exposure) + await controller.prepare(TriggerInfo(number=10, livetime=test_exposure)) + await controller.arm() + await controller.wait_for_idle() assert (await driver.acquire_period.get_value()) == test_exposure assert (await driver.acquire_time.get_value()) == test_exposure @@ -52,7 +55,9 @@ async def test_when_arm_with_no_exposure_then_arm_set_correctly( eiger_driver_and_controller: DriverAndController, ): driver, controller = eiger_driver_and_controller - await controller.arm(10, exposure=None) + await controller.prepare(TriggerInfo(number=10)) + await controller.arm() + await controller.wait_for_idle() get_mock_put(driver.arm).assert_called_once_with(1, wait=ANY, timeout=ANY) @@ -61,7 +66,9 @@ async def test_when_arm_with_number_of_images_then_number_of_images_set_correctl ): driver, controller = eiger_driver_and_controller test_number_of_images = 40 - await controller.arm(test_number_of_images, exposure=None) + await controller.prepare(TriggerInfo(number=test_number_of_images)) + await controller.arm() + await controller.wait_for_idle() get_mock_put(driver.num_images).assert_called_once_with( test_number_of_images, wait=ANY, timeout=ANY ) @@ -73,7 +80,9 @@ async def test_given_detector_fails_to_go_ready_when_arm_called_then_fails( ): driver, controller = eiger_driver_and_controller_no_arm with raises(TimeoutError): - await controller.arm(10) + await controller.prepare(TriggerInfo(number=10)) + await controller.arm() + await controller.wait_for_idle() async def test_when_disarm_called_on_controller_then_disarm_called_on_driver( diff --git a/tests/fastcs/panda/test_hdf_panda.py b/tests/fastcs/panda/test_hdf_panda.py index 43b969ab5e..ab5acd5c20 100644 --- a/tests/fastcs/panda/test_hdf_panda.py +++ b/tests/fastcs/panda/test_hdf_panda.py @@ -2,28 +2,28 @@ from typing import Dict from unittest.mock import ANY +import bluesky.plan_stubs as bps import numpy as np import pytest -from bluesky import plan_stubs as bps -from bluesky.run_engine import RunEngine +from bluesky import RunEngine from ophyd_async.core import ( Device, SignalR, - StandardFlyer, StaticFilenameProvider, StaticPathProvider, - assert_emitted, callback_on_mock_put, set_mock_value, ) +from ophyd_async.core._flyer import StandardFlyer +from ophyd_async.core._signal import assert_emitted from ophyd_async.fastcs.panda import ( DatasetTable, HDFPanda, PandaHdf5DatasetType, - StaticSeqTableTriggerLogic, ) -from ophyd_async.plan_stubs import ( +from ophyd_async.fastcs.panda._trigger import StaticSeqTableTriggerLogic +from ophyd_async.plan_stubs._fly import ( prepare_static_seq_table_flyer_and_detectors_with_same_trigger, ) @@ -111,7 +111,7 @@ def flying_plan(): set_mock_value(flyer.trigger_logic.seq.active, 1) yield from bps.kickoff(flyer, wait=True) - yield from bps.kickoff(mock_hdf_panda) + yield from bps.kickoff(mock_hdf_panda, wait=True) yield from bps.complete(flyer, wait=False, group="complete") yield from bps.complete(mock_hdf_panda, wait=False, group="complete") @@ -191,3 +191,123 @@ def assert_resource_document(): assert stream_datum["stream_resource"] in [ sd["uid"].split("/")[0] for sd in docs["stream_datum"] ] + + +async def test_hdf_panda_hardware_triggered_flyable_with_iterations( + RE: RunEngine, + mock_hdf_panda, + tmp_path, +): + docs = {} + + def append_and_print(name, doc): + if name not in docs: + docs[name] = [] + docs[name] += [doc] + + RE.subscribe(append_and_print) + + shutter_time = 0.004 + exposure = 1 + + trigger_logic = StaticSeqTableTriggerLogic(mock_hdf_panda.seq[1]) + flyer = StandardFlyer(trigger_logic, [], name="flyer") + + def flying_plan(): + iteration = 2 + yield from bps.stage_all(mock_hdf_panda, flyer) + yield from bps.open_run() + yield from prepare_static_seq_table_flyer_and_detectors_with_same_trigger( # noqa: E501 + flyer, + [mock_hdf_panda], + number_of_frames=1, + exposure=exposure, + shutter_time=shutter_time, + iteration=iteration, + ) + + yield from bps.declare_stream(mock_hdf_panda, name="main_stream", collect=True) + + for i in range(iteration): + set_mock_value(flyer.trigger_logic.seq.active, 1) + yield from bps.kickoff(flyer, wait=True) + yield from bps.kickoff(mock_hdf_panda) + + yield from bps.complete(flyer, wait=False, group="complete") + yield from bps.complete(mock_hdf_panda, wait=False, group="complete") + + # Manually incremenet the index as if a frame was taken + set_mock_value(mock_hdf_panda.data.num_captured, 1) + set_mock_value(flyer.trigger_logic.seq.active, 0) + + done = False + while not done: + try: + yield from bps.wait(group="complete", timeout=0.5) + except TimeoutError: + pass + else: + done = True + yield from bps.collect( + mock_hdf_panda, + return_payload=False, + name="main_stream", + ) + yield from bps.wait(group="complete") + yield from bps.close_run() + + yield from bps.unstage_all(flyer, mock_hdf_panda) + yield from bps.wait_for([lambda: mock_hdf_panda.controller.disarm()]) + + # fly scan + RE(flying_plan()) + + assert_emitted( + docs, start=1, descriptor=1, stream_resource=2, stream_datum=2, stop=1 + ) + + # test descriptor + data_key_names: Dict[str, str] = docs["descriptor"][0]["object_keys"]["panda"] + assert data_key_names == ["x", "y"] + for data_key_name in data_key_names: + assert ( + docs["descriptor"][0]["data_keys"][data_key_name]["source"] + == "mock+soft://panda-data-hdf_directory" + ) + + # test stream resources + for dataset_name, stream_resource, data_key_name in zip( + ("x", "y"), docs["stream_resource"], data_key_names + ): + + def assert_resource_document(): + assert stream_resource == { + "run_start": docs["start"][0]["uid"], + "uid": ANY, + "data_key": data_key_name, + "mimetype": "application/x-hdf5", + "uri": "file://localhost" + str(tmp_path / "test-panda.h5"), + "parameters": { + "dataset": f"/{dataset_name}", + "swmr": False, + "multiplier": 1, + }, + } + assert "test-panda.h5" in stream_resource["uri"] + + assert_resource_document() + + # test stream datum + for stream_datum in docs["stream_datum"]: + assert stream_datum["descriptor"] == docs["descriptor"][0]["uid"] + assert stream_datum["seq_nums"] == { + "start": 1, + "stop": 2, + } + assert stream_datum["indices"] == { + "start": 0, + "stop": 1, + } + assert stream_datum["stream_resource"] in [ + sd["uid"].split("/")[0] for sd in docs["stream_datum"] + ] diff --git a/tests/fastcs/panda/test_panda_control.py b/tests/fastcs/panda/test_panda_control.py index b73d907f89..6c0abf298f 100644 --- a/tests/fastcs/panda/test_panda_control.py +++ b/tests/fastcs/panda/test_panda_control.py @@ -5,6 +5,7 @@ import pytest from ophyd_async.core import DEFAULT_TIMEOUT, DetectorTrigger, Device, DeviceCollector +from ophyd_async.core._detector import TriggerInfo from ophyd_async.epics.pvi import fill_pvi_entries from ophyd_async.epics.signal import epics_signal_rw from ophyd_async.fastcs.panda import CommonPandaBlocks, PandaPcapController @@ -36,18 +37,27 @@ class PcapBlock(Device): pandaController = PandaPcapController(pcap=PcapBlock()) with patch("ophyd_async.fastcs.panda._control.wait_for_value", return_value=None): with pytest.raises(AttributeError) as exc: - await pandaController.arm(num=1, trigger=DetectorTrigger.constant_gate) + await pandaController.prepare( + TriggerInfo(number=1, trigger=DetectorTrigger.constant_gate) + ) + await pandaController.arm() assert ("'PcapBlock' object has no attribute 'arm'") in str(exc.value) async def test_panda_controller_arm_disarm(mock_panda): pandaController = PandaPcapController(mock_panda.pcap) with patch("ophyd_async.fastcs.panda._control.wait_for_value", return_value=None): - await pandaController.arm(num=1, trigger=DetectorTrigger.constant_gate) + await pandaController.prepare( + TriggerInfo(number=1, trigger=DetectorTrigger.constant_gate) + ) + await pandaController.arm() + await pandaController.wait_for_idle() await pandaController.disarm() async def test_panda_controller_wrong_trigger(): pandaController = PandaPcapController(None) with pytest.raises(AssertionError): - await pandaController.arm(num=1, trigger=DetectorTrigger.internal) + await pandaController.prepare( + TriggerInfo(number=1, trigger=DetectorTrigger.internal) + ) diff --git a/tests/plan_stubs/test_fly.py b/tests/plan_stubs/test_fly.py index 791935c111..06e40e29db 100644 --- a/tests/plan_stubs/test_fly.py +++ b/tests/plan_stubs/test_fly.py @@ -115,7 +115,6 @@ def __init__( @WatchableAsyncStatus.wrap async def complete(self): - assert self._arm_status, "Prepare not run" assert self._trigger_info self.writer.increment_index() async for index in self.writer.observe_indices_written( @@ -145,10 +144,10 @@ async def detectors(RE: RunEngine) -> tuple[MockDetector, MockDetector]: await writers[0].dummy_signal.connect(mock=True) await writers[1].dummy_signal.connect(mock=True) - async def dummy_arm_1(self=None, trigger=None, num=0, exposure=None): + def dummy_arm_1(self=None): return writers[0].dummy_signal.set(1) - async def dummy_arm_2(self=None, trigger=None, num=0, exposure=None): + def dummy_arm_2(self=None): return writers[1].dummy_signal.set(1) detector_1 = MockDetector(