diff --git a/src/ophyd_async/core/__init__.py b/src/ophyd_async/core/__init__.py index 6b9c6ccac3..1928c7aba4 100644 --- a/src/ophyd_async/core/__init__.py +++ b/src/ophyd_async/core/__init__.py @@ -61,9 +61,14 @@ soft_signal_rw, wait_for_value, ) -from ._signal_backend import RuntimeSubsetEnum, SignalBackend, SubsetEnum +from ._signal_backend import ( + RuntimeSubsetEnum, + SignalBackend, + SubsetEnum, +) from ._soft_signal_backend import SignalMetadata, SoftSignalBackend from ._status import AsyncStatus, WatchableAsyncStatus, completed_status +from ._table import Table from ._utils import ( DEFAULT_TIMEOUT, CalculatableTimeout, @@ -152,6 +157,7 @@ "CalculateTimeout", "NotConnected", "ReadingValueCallback", + "Table", "T", "WatcherUpdate", "get_dtype", diff --git a/src/ophyd_async/core/_device_save_loader.py b/src/ophyd_async/core/_device_save_loader.py index 5b81228264..d847caff69 100644 --- a/src/ophyd_async/core/_device_save_loader.py +++ b/src/ophyd_async/core/_device_save_loader.py @@ -7,6 +7,7 @@ from bluesky.plan_stubs import abs_set, wait from bluesky.protocols import Location from bluesky.utils import Msg +from pydantic import BaseModel from ._device import Device from ._signal import SignalRW @@ -18,6 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No ) +def pydantic_model_abstraction_representer( + dumper: yaml.Dumper, model: BaseModel +) -> yaml.Node: + return dumper.represent_data(model.model_dump(mode="python")) + + class OphydDumper(yaml.Dumper): def represent_data(self, data: Any) -> Any: if isinstance(data, Enum): @@ -134,6 +141,11 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None: """ yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper) + yaml.add_multi_representer( + BaseModel, + pydantic_model_abstraction_representer, + Dumper=yaml.Dumper, + ) with open(save_path, "w") as file: yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False) diff --git 
a/src/ophyd_async/core/_signal_backend.py b/src/ophyd_async/core/_signal_backend.py index 41e9fbcbd3..594863ef2a 100644 --- a/src/ophyd_async/core/_signal_backend.py +++ b/src/ophyd_async/core/_signal_backend.py @@ -1,5 +1,13 @@ from abc import abstractmethod -from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Optional, Tuple, Type +from typing import ( + TYPE_CHECKING, + ClassVar, + Generic, + Literal, + Optional, + Tuple, + Type, +) from ._protocol import DataKey, Reading from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T @@ -11,6 +19,11 @@ class SignalBackend(Generic[T]): #: Datatype of the signal value datatype: Optional[Type[T]] = None + @classmethod + @abstractmethod + def datatype_allowed(cls, dtype: type): + """Check if a given datatype is acceptable for this signal backend.""" + #: Like ca://PV_PREFIX:SIGNAL @abstractmethod def source(self, name: str) -> str: diff --git a/src/ophyd_async/core/_soft_signal_backend.py b/src/ophyd_async/core/_soft_signal_backend.py index 62bafd5bb1..1e895e60cc 100644 --- a/src/ophyd_async/core/_soft_signal_backend.py +++ b/src/ophyd_async/core/_soft_signal_backend.py @@ -2,15 +2,20 @@ import inspect import time +from abc import ABCMeta from collections import abc from enum import Enum from typing import Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin import numpy as np from bluesky.protocols import DataKey, Dtype, Reading +from pydantic import BaseModel from typing_extensions import TypedDict -from ._signal_backend import RuntimeSubsetEnum, SignalBackend +from ._signal_backend import ( + RuntimeSubsetEnum, + SignalBackend, +) from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T, get_dtype primitive_dtypes: Dict[type, Dtype] = { @@ -94,7 +99,7 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: class SoftEnumConverter(SoftConverter): choices: Tuple[str, ...] 
- def __init__(self, datatype: Union[RuntimeSubsetEnum, Enum]): + def __init__(self, datatype: Union[RuntimeSubsetEnum, Type[Enum]]): if issubclass(datatype, Enum): self.choices = tuple(v.value for v in datatype) else: @@ -122,6 +127,16 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T: return cast(T, self.choices[0]) +class SoftPydanticModelConverter(SoftConverter): + def __init__(self, datatype: Type[BaseModel]): + self.datatype = datatype + + def write_value(self, value): + if isinstance(value, dict): + return self.datatype(**value) + return value + + def make_converter(datatype): is_array = get_dtype(datatype) is not None is_sequence = get_origin(datatype) == abc.Sequence @@ -129,10 +144,19 @@ def make_converter(datatype): issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum) ) + is_pydantic_model = ( + inspect.isclass(datatype) + # Necessary to avoid weirdness in ABCMeta.__subclasscheck__ + and isinstance(datatype, ABCMeta) + and issubclass(datatype, BaseModel) + ) + if is_array or is_sequence: return SoftArrayConverter() if is_enum: return SoftEnumConverter(datatype) + if is_pydantic_model: + return SoftPydanticModelConverter(datatype) return SoftConverter() @@ -145,6 +169,10 @@ class SoftSignalBackend(SignalBackend[T]): _timestamp: float _severity: int + @classmethod + def datatype_allowed(cls, datatype: Type) -> bool: + return True # Any value allowed in a soft signal + def __init__( self, datatype: Optional[Type[T]], diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py new file mode 100644 index 0000000000..bdb619a3b9 --- /dev/null +++ b/src/ophyd_async/core/_table.py @@ -0,0 +1,58 @@ +import numpy as np +from pydantic import BaseModel, ConfigDict, model_validator + + +class Table(BaseModel): + """An abstraction of a Table of str to numpy array.""" + + model_config = ConfigDict(validate_assignment=True, strict=False) + + @classmethod + def row(cls, sub_cls, **kwargs) -> "Table": + arrayified_kwargs = 
{ + field_name: np.concatenate( + ( + (default_arr := field_value.default_factory()), + np.array([kwargs[field_name]], dtype=default_arr.dtype), + ) + ) + for field_name, field_value in sub_cls.model_fields.items() + } + return sub_cls(**arrayified_kwargs) + + def __add__(self, right: "Table") -> "Table": + """Concatenate the arrays in field values.""" + + assert isinstance(right, type(self)), ( + f"{right} is not a `Table`, or is not the same " + f"type of `Table` as {self}." + ) + + return type(self)( + **{ + field_name: np.concatenate( + (getattr(self, field_name), getattr(right, field_name)) + ) + for field_name in self.model_fields + } + ) + + @model_validator(mode="after") + def validate_arrays(self) -> "Table": + first_length = len(next(iter(self))[1]) + assert all( + len(field_value) == first_length for _, field_value in self + ), "Rows should all be of equal size." + + if not all( + np.issubdtype( + self.model_fields[field_name].default_factory().dtype, field_value.dtype + ) + for field_name, field_value in self + ): + raise ValueError( + f"Cannot construct a `{type(self).__name__}`, " + "some rows have incorrect types." 
+ ) + + return self diff --git a/src/ophyd_async/core/_utils.py b/src/ophyd_async/core/_utils.py index f5098ce717..d081ed008f 100644 --- a/src/ophyd_async/core/_utils.py +++ b/src/ophyd_async/core/_utils.py @@ -145,7 +145,7 @@ def get_dtype(typ: Type) -> Optional[np.dtype]: def get_unique(values: Dict[str, T], types: str) -> T: - """If all values are the same, return that value, otherwise return TypeError + """If all values are the same, return that value, otherwise raise TypeError >>> get_unique({"a": 1, "b": 1}, "integers") 1 diff --git a/src/ophyd_async/epics/signal/_aioca.py b/src/ophyd_async/epics/signal/_aioca.py index 78052d448d..ef8a5693e2 100644 --- a/src/ophyd_async/epics/signal/_aioca.py +++ b/src/ophyd_async/epics/signal/_aioca.py @@ -1,9 +1,10 @@ +import inspect import logging import sys from dataclasses import dataclass from enum import Enum from math import isnan, nan -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin import numpy as np from aioca import ( @@ -24,6 +25,7 @@ DEFAULT_TIMEOUT, NotConnected, ReadingValueCallback, + RuntimeSubsetEnum, SignalBackend, T, get_dtype, @@ -211,7 +213,8 @@ def make_converter( raise TypeError(f"{pv} has type [{pv_dtype}] not [{dtype}]") return CaArrayConverter(pv_dbr, None) elif pv_dbr == dbr.DBR_ENUM and datatype is bool: - # Database can't do bools, so are often representated as enums, CA can do int + # Database can't do bools, so are often representated as enums, + # CA can do int pv_choices_len = get_unique( {k: len(v.enums) for k, v in values.items()}, "number of choices" ) @@ -240,7 +243,7 @@ def make_converter( f"{pv} has type {type(value).__name__.replace('ca_', '')} " + f"not {datatype.__name__}" ) - return CaConverter(pv_dbr, None) + return CaConverter(pv_dbr, None) _tried_pyepics = False @@ -256,8 +259,31 @@ def _use_pyepics_context_if_imported(): class CaSignalBackend(SignalBackend[T]): + _ALLOWED_DATATYPES = ( + 
bool, + int, + float, + str, + Sequence, + Enum, + RuntimeSubsetEnum, + np.ndarray, + ) + + @classmethod + def datatype_allowed(cls, datatype: Optional[Type]) -> bool: + stripped_origin = get_origin(datatype) or datatype + if datatype is None: + return True + + return inspect.isclass(stripped_origin) and issubclass( + stripped_origin, cls._ALLOWED_DATATYPES + ) + def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.datatype = datatype + if not CaSignalBackend.datatype_allowed(self.datatype): + raise TypeError(f"Given datatype {self.datatype} unsupported in CA.") self.read_pv = read_pv self.write_pv = write_pv self.initial_values: Dict[str, AugmentedValue] = {} diff --git a/src/ophyd_async/epics/signal/_p4p.py b/src/ophyd_async/epics/signal/_p4p.py index 28ec8fe6ab..c7d0b5240d 100644 --- a/src/ophyd_async/epics/signal/_p4p.py +++ b/src/ophyd_async/epics/signal/_p4p.py @@ -3,14 +3,17 @@ import inspect import logging import time +from abc import ABCMeta from dataclasses import dataclass from enum import Enum from math import isnan, nan -from typing import Any, Dict, List, Optional, Sequence, Type, Union +from typing import Any, Dict, List, Optional, Sequence, Type, Union, get_origin +import numpy as np from bluesky.protocols import DataKey, Dtype, Reading from p4p import Value from p4p.client.asyncio import Context, Subscription +from pydantic import BaseModel from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -253,6 +256,19 @@ def get_datakey(self, source: str, value) -> DataKey: return _data_key_from_value(source, value, dtype="object") +class PvaPydanticModelConverter(PvaConverter): + def __init__(self, datatype: BaseModel): + self.datatype = datatype + + def value(self, value: Value): + return self.datatype(**value.todict()) + + def write_value(self, value: Union[BaseModel, Dict[str, Any]]): + if isinstance(value, self.datatype): + return value.model_dump(mode="python") + return value + + class PvaDictConverter(PvaConverter): def 
reading(self, value): ts = time.time() @@ -348,6 +364,15 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}") return PvaConverter() elif "NTTable" in typeid: + if ( + datatype + and inspect.isclass(datatype) + and + # Necessary to avoid weirdness in ABCMeta.__subclasscheck__ + isinstance(datatype, ABCMeta) + and issubclass(datatype, BaseModel) + ): + return PvaPydanticModelConverter(datatype) return PvaTableConverter() elif "structure" in typeid: return PvaDictConverter() @@ -358,8 +383,33 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve class PvaSignalBackend(SignalBackend[T]): _ctxt: Optional[Context] = None + _ALLOWED_DATATYPES = ( + bool, + int, + float, + str, + Sequence, + np.ndarray, + Enum, + RuntimeSubsetEnum, + BaseModel, + dict, + ) + + @classmethod + def datatype_allowed(cls, datatype: Optional[Type]) -> bool: + stripped_origin = get_origin(datatype) or datatype + if datatype is None: + return True + return inspect.isclass(stripped_origin) and issubclass( + stripped_origin, cls._ALLOWED_DATATYPES + ) + def __init__(self, datatype: Optional[Type[T]], read_pv: str, write_pv: str): self.datatype = datatype + if not PvaSignalBackend.datatype_allowed(self.datatype): + raise TypeError(f"Given datatype {self.datatype} unsupported in PVA.") + self.read_pv = read_pv self.write_pv = write_pv self.initial_values: Dict[str, Any] = {} diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index 9d1c1d429f..0dbe7222b0 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -15,10 +15,7 @@ DatasetTable, PandaHdf5DatasetType, SeqTable, - SeqTableRow, SeqTrigger, - seq_table_from_arrays, - seq_table_from_rows, ) from ._trigger import ( PcompInfo, @@ -45,10 +42,7 @@ "DatasetTable", "PandaHdf5DatasetType", "SeqTable", - "SeqTableRow", 
"SeqTrigger", - "seq_table_from_arrays", - "seq_table_from_rows", "PcompInfo", "SeqTableInfo", "StaticPcompTriggerLogic", diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index ec2c1a5b8b..ee6df7522f 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -1,11 +1,14 @@ -from dataclasses import dataclass +import inspect from enum import Enum -from typing import Optional, Sequence, Type, TypeVar +from typing import Annotated, Sequence import numpy as np import numpy.typing as npt -import pydantic_numpy.typing as pnd -from typing_extensions import NotRequired, TypedDict +from pydantic import Field, field_validator, model_validator +from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation +from typing_extensions import TypedDict + +from ophyd_async.core import Table class PandaHdf5DatasetType(str, Enum): @@ -34,137 +37,113 @@ class SeqTrigger(str, Enum): POSC_LT = "POSC<=POSITION" -@dataclass -class SeqTableRow: - repeats: int = 1 - trigger: SeqTrigger = SeqTrigger.IMMEDIATE - position: int = 0 - time1: int = 0 - outa1: bool = False - outb1: bool = False - outc1: bool = False - outd1: bool = False - oute1: bool = False - outf1: bool = False - time2: int = 0 - outa2: bool = False - outb2: bool = False - outc2: bool = False - outd2: bool = False - oute2: bool = False - outf2: bool = False - - -class SeqTable(TypedDict): - repeats: NotRequired[pnd.Np1DArrayUint16] - trigger: NotRequired[Sequence[SeqTrigger]] - position: NotRequired[pnd.Np1DArrayInt32] - time1: NotRequired[pnd.Np1DArrayUint32] - outa1: NotRequired[pnd.Np1DArrayBool] - outb1: NotRequired[pnd.Np1DArrayBool] - outc1: NotRequired[pnd.Np1DArrayBool] - outd1: NotRequired[pnd.Np1DArrayBool] - oute1: NotRequired[pnd.Np1DArrayBool] - outf1: NotRequired[pnd.Np1DArrayBool] - time2: NotRequired[pnd.Np1DArrayUint32] - outa2: NotRequired[pnd.Np1DArrayBool] - outb2: NotRequired[pnd.Np1DArrayBool] - outc2: 
NotRequired[pnd.Np1DArrayBool] - outd2: NotRequired[pnd.Np1DArrayBool] - oute2: NotRequired[pnd.Np1DArrayBool] - outf2: NotRequired[pnd.Np1DArrayBool] - - -def seq_table_from_rows(*rows: SeqTableRow): - """ - Constructs a sequence table from a series of rows. - """ - return seq_table_from_arrays( - repeats=np.array([row.repeats for row in rows], dtype=np.uint16), - trigger=[row.trigger for row in rows], - position=np.array([row.position for row in rows], dtype=np.int32), - time1=np.array([row.time1 for row in rows], dtype=np.uint32), - outa1=np.array([row.outa1 for row in rows], dtype=np.bool_), - outb1=np.array([row.outb1 for row in rows], dtype=np.bool_), - outc1=np.array([row.outc1 for row in rows], dtype=np.bool_), - outd1=np.array([row.outd1 for row in rows], dtype=np.bool_), - oute1=np.array([row.oute1 for row in rows], dtype=np.bool_), - outf1=np.array([row.outf1 for row in rows], dtype=np.bool_), - time2=np.array([row.time2 for row in rows], dtype=np.uint32), - outa2=np.array([row.outa2 for row in rows], dtype=np.bool_), - outb2=np.array([row.outb2 for row in rows], dtype=np.bool_), - outc2=np.array([row.outc2 for row in rows], dtype=np.bool_), - outd2=np.array([row.outd2 for row in rows], dtype=np.bool_), - oute2=np.array([row.oute2 for row in rows], dtype=np.bool_), - outf2=np.array([row.outf2 for row in rows], dtype=np.bool_), - ) - - -T = TypeVar("T", bound=np.generic) - - -def seq_table_from_arrays( - *, - repeats: Optional[npt.NDArray[np.uint16]] = None, - trigger: Optional[Sequence[SeqTrigger]] = None, - position: Optional[npt.NDArray[np.int32]] = None, - time1: Optional[npt.NDArray[np.uint32]] = None, - outa1: Optional[npt.NDArray[np.bool_]] = None, - outb1: Optional[npt.NDArray[np.bool_]] = None, - outc1: Optional[npt.NDArray[np.bool_]] = None, - outd1: Optional[npt.NDArray[np.bool_]] = None, - oute1: Optional[npt.NDArray[np.bool_]] = None, - outf1: Optional[npt.NDArray[np.bool_]] = None, - time2: npt.NDArray[np.uint32], - outa2: 
Optional[npt.NDArray[np.bool_]] = None, - outb2: Optional[npt.NDArray[np.bool_]] = None, - outc2: Optional[npt.NDArray[np.bool_]] = None, - outd2: Optional[npt.NDArray[np.bool_]] = None, - oute2: Optional[npt.NDArray[np.bool_]] = None, - outf2: Optional[npt.NDArray[np.bool_]] = None, -) -> SeqTable: - """ - Constructs a sequence table from a series of columns as arrays. - time2 is the only required argument and must not be None. - All other provided arguments must be of equal length to time2. - If any other argument is not given, or else given as None or empty, - an array of length len(time2) filled with the following is defaulted: - repeats: 1 - trigger: SeqTrigger.IMMEDIATE - all others: 0/False as appropriate - """ - assert time2 is not None, "time2 must be provided" - length = len(time2) - assert 0 < length < 4096, f"Length {length} not in range" - - def or_default( - value: Optional[npt.NDArray[T]], dtype: Type[T], default_value: int = 0 - ) -> npt.NDArray[T]: - if value is None or len(value) == 0: - return np.full(length, default_value, dtype=dtype) - return value - - table = SeqTable( - repeats=or_default(repeats, np.uint16, 1), - trigger=trigger or [SeqTrigger.IMMEDIATE] * length, - position=or_default(position, np.int32), - time1=or_default(time1, np.uint32), - outa1=or_default(outa1, np.bool_), - outb1=or_default(outb1, np.bool_), - outc1=or_default(outc1, np.bool_), - outd1=or_default(outd1, np.bool_), - oute1=or_default(oute1, np.bool_), - outf1=or_default(outf1, np.bool_), - time2=time2, - outa2=or_default(outa2, np.bool_), - outb2=or_default(outb2, np.bool_), - outc2=or_default(outc2, np.bool_), - outd2=or_default(outd2, np.bool_), - oute2=or_default(oute2, np.bool_), - outf2=or_default(outf2, np.bool_), - ) - for k, v in table.items(): - size = len(v) # type: ignore - if size != length: - raise ValueError(f"{k}: has length {size} not {length}") - return table +PydanticNp1DArrayInt32 = Annotated[ + np.ndarray[tuple[int], np.int32], + 
NpArrayPydanticAnnotation.factory( + data_type=np.int32, dimensions=1, strict_data_typing=False + ), + Field(default_factory=lambda: np.array([], np.int32)), +] +PydanticNp1DArrayBool = Annotated[ + np.ndarray[tuple[int], np.bool_], + NpArrayPydanticAnnotation.factory( + data_type=np.bool_, dimensions=1, strict_data_typing=False + ), + Field(default_factory=lambda: np.array([], dtype=np.bool_)), +] +TriggerStr = Annotated[ + np.ndarray[tuple[int], np.unicode_], + NpArrayPydanticAnnotation.factory( + data_type=np.unicode_, dimensions=1, strict_data_typing=False + ), + Field(default_factory=lambda: np.array([], dtype=np.dtype(" "SeqTable": + sig = inspect.signature(cls.row) + kwargs = {k: v for k, v in locals().items() if k in sig.parameters} + + if isinstance(kwargs["trigger"], SeqTrigger): + kwargs["trigger"] = kwargs["trigger"].value + elif isinstance(kwargs["trigger"], str): + SeqTrigger(kwargs["trigger"]) + + return Table.row(cls, **kwargs) + + @field_validator("trigger", mode="before") + @classmethod + def trigger_to_np_array(cls, trigger_column): + """ + The user can provide a list of SeqTrigger enum elements instead of a numpy str. + """ + if isinstance(trigger_column, Sequence) and all( + isinstance(trigger, SeqTrigger) for trigger in trigger_column + ): + trigger_column = np.array( + [trigger.value for trigger in trigger_column], dtype=np.dtype(" "SeqTable": + """ + Used to check max_length. Unfortunately trying the `max_length` arg in + the pydantic field doesn't work + """ + + first_length = len(next(iter(self))[1]) + assert 0 <= first_length < 4096, f"Length {first_length} not in range." 
+ return self diff --git a/src/ophyd_async/plan_stubs/_fly.py b/src/ophyd_async/plan_stubs/_fly.py index 087ec62dd1..daa686b477 100644 --- a/src/ophyd_async/plan_stubs/_fly.py +++ b/src/ophyd_async/plan_stubs/_fly.py @@ -15,8 +15,6 @@ PcompInfo, SeqTable, SeqTableInfo, - SeqTableRow, - seq_table_from_rows, ) @@ -74,24 +72,26 @@ def prepare_static_seq_table_flyer_and_detectors_with_same_trigger( trigger_time = number_of_frames * (exposure + deadtime) pre_delay = max(period - 2 * shutter_time - trigger_time, 0) - table: SeqTable = seq_table_from_rows( + table = ( # Wait for pre-delay then open shutter - SeqTableRow( + SeqTable.row( time1=in_micros(pre_delay), time2=in_micros(shutter_time), outa2=True, - ), + ) + + # Keeping shutter open, do N triggers - SeqTableRow( + SeqTable.row( repeats=number_of_frames, time1=in_micros(exposure), outa1=True, outb1=True, time2=in_micros(deadtime), outa2=True, - ), + ) + + # Add the shutter close - SeqTableRow(time2=in_micros(shutter_time)), + SeqTable.row(time2=in_micros(shutter_time)) ) table_info = SeqTableInfo(sequence_table=table, repeats=repeats) diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index aa60be9802..b265b86137 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -8,6 +8,8 @@ import pytest import yaml from bluesky.run_engine import RunEngine +from pydantic import BaseModel, Field +from pydantic_numpy.typing import NpNDArrayFp16, NpNDArrayInt32 from ophyd_async.core import ( Device, @@ -54,6 +56,16 @@ class MyEnum(str, Enum): three = "three" +class SomePvaPydanticModel(BaseModel): + some_int_field: int = Field(default=1) + some_pydantic_numpy_field_float: NpNDArrayFp16 = Field( + default_factory=lambda: np.array([1, 2, 3]) + ) + some_pydantic_numpy_field_int: NpNDArrayInt32 = Field( + default_factory=lambda: np.array([1, 2, 3]) + ) + + class DummyDeviceGroupAllTypes(Device): def __init__(self, name: str): self.pv_int: SignalRW = 
epics_signal_rw(int, "PV1") @@ -73,6 +85,9 @@ def __init__(self, name: str): self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14") self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15") self.pv_array_str = epics_signal_rw(Sequence[str], "PV16") + self.pv_protocol_device_abstraction = epics_signal_rw( + SomePvaPydanticModel, "pva://PV17" + ) @pytest.fixture @@ -155,6 +170,7 @@ async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path): await device_all_types.pv_array_str.set( ["one", "two", "three"], ) + await device_all_types.pv_protocol_device_abstraction.set(SomePvaPydanticModel()) # Create save plan from utility functions def save_my_device(): diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 3b4c4934f4..ab5c02cffe 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -403,3 +403,28 @@ async def test_subscription_logs(caplog): assert "Making subscription" in caplog.text mock_signal_rw.clear_sub(cbs.append) assert "Closing subscription on source" in caplog.text + + +async def test_signal_unknown_datatype(): + class SomeClass: + def __init__(self): + self.some_attribute = "some_attribute" + + def some_function(self): + pass + + err_str = ( + "Given datatype .SomeClass'>" + " unsupported in %s." 
+ ) + with pytest.raises(TypeError, match=err_str % ("PVA",)): + epics_signal_rw(SomeClass, "pva://mock_signal", name="mock_signal") + with pytest.raises(TypeError, match=err_str % ("CA",)): + epics_signal_rw(SomeClass, "ca://mock_signal", name="mock_signal") + + # Any dtype allowed in soft signal + signal = soft_signal_rw(SomeClass, SomeClass(), "soft_signal") + assert isinstance((await signal.get_value()), SomeClass) + await signal.set(1) + assert (await signal.get_value()) == 1 diff --git a/tests/core/test_soft_signal_backend.py b/tests/core/test_soft_signal_backend.py index 5e55507626..16bf23567e 100644 --- a/tests/core/test_soft_signal_backend.py +++ b/tests/core/test_soft_signal_backend.py @@ -94,7 +94,7 @@ async def test_soft_signal_backend_get_put_monitor( descriptor: Callable[[Any], dict], dtype_numpy: str, ): - backend = SoftSignalBackend(datatype) + backend = SoftSignalBackend(datatype=datatype) await backend.connect() q = MonitorQueue(backend) diff --git a/tests/core/test_subset_enum.py b/tests/core/test_subset_enum.py index 41af248aac..8c638d2770 100644 --- a/tests/core/test_subset_enum.py +++ b/tests/core/test_subset_enum.py @@ -7,8 +7,8 @@ from ophyd_async.epics.signal import epics_signal_rw # Allow these imports from private modules for tests -from ophyd_async.epics.signal._aioca import make_converter as aioca_make_converter -from ophyd_async.epics.signal._p4p import make_converter as p4p_make_converter +from ophyd_async.epics.signal._aioca import make_converter as ca_make_converter +from ophyd_async.epics.signal._p4p import make_converter as pva_make_converter async def test_runtime_enum_behaviour(): @@ -52,7 +52,7 @@ def __init__(self): epics_value = EpicsValue() rt_enum = SubsetEnum["A", "B"] - converter = aioca_make_converter( + converter = ca_make_converter( rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value} ) assert converter.choices == {"A": "A", "B": "B", "C": "C"} @@ -68,7 +68,7 @@ async def test_pva_runtime_enum_converter(): 
}, ) rt_enum = SubsetEnum["A", "B"] - converter = p4p_make_converter( + converter = pva_make_converter( rt_enum, values={"READ_PV": epics_value, "WRITE_PV": epics_value} ) assert {"A", "B"}.issubset(set(converter.choices)) diff --git a/tests/fastcs/panda/test_panda_connect.py b/tests/fastcs/panda/test_panda_connect.py index 2685f3c66c..b6dcbe2b00 100644 --- a/tests/fastcs/panda/test_panda_connect.py +++ b/tests/fastcs/panda/test_panda_connect.py @@ -115,6 +115,7 @@ async def test_panda_children_connected(mock_panda): oute2=np.array([1, 0, 1, 0]).astype(np.bool_), outf2=np.array([1, 0, 0, 0]).astype(np.bool_), ) + await mock_panda.pulse[1].delay.set(20.0) await mock_panda.seq[1].table.set(table) diff --git a/tests/fastcs/panda/test_panda_utils.py b/tests/fastcs/panda/test_panda_utils.py index f5d4e02600..d8b9a01269 100644 --- a/tests/fastcs/panda/test_panda_utils.py +++ b/tests/fastcs/panda/test_panda_utils.py @@ -1,22 +1,20 @@ -from unittest.mock import patch - -import pytest +import numpy as np +import yaml from bluesky import RunEngine -from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, save_device +from ophyd_async.core import DEFAULT_TIMEOUT, DeviceCollector, load_device, save_device from ophyd_async.epics.pvi import fill_pvi_entries from ophyd_async.epics.signal import epics_signal_rw from ophyd_async.fastcs.panda import ( CommonPandaBlocks, DataBlock, - PcompDirectionOptions, + SeqTable, TimeUnits, phase_sorter, ) -@pytest.fixture -async def mock_panda(): +async def get_mock_panda(): class Panda(CommonPandaBlocks): data: DataBlock @@ -33,56 +31,112 @@ async def connect(self, mock: bool = False, timeout: float = DEFAULT_TIMEOUT): async with DeviceCollector(mock=True): mock_panda = Panda("PANDA") mock_panda.phase_1_signal_units = epics_signal_rw(int, "") - assert mock_panda.name == "mock_panda" - yield mock_panda + return mock_panda + + +async def test_save_load_panda(tmp_path, RE: RunEngine): + mock_panda1 = await get_mock_panda() + await 
mock_panda1.seq[1].table.set(SeqTable.row(repeats=1)) + RE(save_device(mock_panda1, str(tmp_path / "panda.yaml"), sorter=phase_sorter)) -@patch("ophyd_async.core._device_save_loader.save_to_yaml") -async def test_save_panda(mock_save_to_yaml, mock_panda, RE: RunEngine): - RE(save_device(mock_panda, "path", sorter=phase_sorter)) - mock_save_to_yaml.assert_called_once() - assert mock_save_to_yaml.call_args[0] == ( - [ - { - "phase_1_signal_units": 0, - "seq.1.prescale_units": TimeUnits("min"), - "seq.2.prescale_units": TimeUnits("min"), - }, - { - "data.capture": False, - "data.create_directory": 0, - "data.flush_period": 0.0, - "data.hdf_directory": "", - "data.hdf_file_name": "", - "data.num_capture": 0, - "pcap.arm": False, - "pcomp.1.dir": PcompDirectionOptions.positive, - "pcomp.1.enable": "ZERO", - "pcomp.1.pulses": 0, - "pcomp.1.start": 0, - "pcomp.1.step": 0, - "pcomp.1.width": 0, - "pcomp.2.dir": PcompDirectionOptions.positive, - "pcomp.2.enable": "ZERO", - "pcomp.2.pulses": 0, - "pcomp.2.start": 0, - "pcomp.2.step": 0, - "pcomp.2.width": 0, - "pulse.1.delay": 0.0, - "pulse.1.width": 0.0, - "pulse.2.delay": 0.0, - "pulse.2.width": 0.0, - "seq.1.active": False, - "seq.1.table": {}, - "seq.1.repeats": 0, - "seq.1.prescale": 0.0, - "seq.1.enable": "ZERO", - "seq.2.table": {}, - "seq.2.active": False, - "seq.2.repeats": 0, - "seq.2.prescale": 0.0, - "seq.2.enable": "ZERO", - }, - ], - "path", + def check_equal_with_seq_tables(actual, expected): + assert actual.model_fields_set == expected.model_fields_set + for field_name, field_value1 in actual: + field_value2 = getattr(expected, field_name) + assert np.array_equal(field_value1, field_value2) + + mock_panda2 = await get_mock_panda() + check_equal_with_seq_tables( + (await mock_panda2.seq[1].table.get_value()), SeqTable() ) + RE(load_device(mock_panda2, str(tmp_path / "panda.yaml"))) + + check_equal_with_seq_tables( + await mock_panda2.seq[1].table.get_value(), + SeqTable.row(repeats=1), + ) + + # Load the YAML 
content as a string + with open(str(tmp_path / "panda.yaml"), "r") as file: + yaml_content = file.read() + + # Parse the YAML content + parsed_yaml = yaml.safe_load(yaml_content) + + assert parsed_yaml[0] == { + "phase_1_signal_units": 0, + "seq.1.prescale_units": TimeUnits("min"), + "seq.2.prescale_units": TimeUnits("min"), + } + assert parsed_yaml[1] == { + "data.capture": False, + "data.create_directory": 0, + "data.flush_period": 0.0, + "data.hdf_directory": "", + "data.hdf_file_name": "", + "data.num_capture": 0, + "pcap.arm": False, + "pcomp.1.dir": "Positive", + "pcomp.1.enable": "ZERO", + "pcomp.1.pulses": 0, + "pcomp.1.start": 0, + "pcomp.1.step": 0, + "pcomp.1.width": 0, + "pcomp.2.dir": "Positive", + "pcomp.2.enable": "ZERO", + "pcomp.2.pulses": 0, + "pcomp.2.start": 0, + "pcomp.2.step": 0, + "pcomp.2.width": 0, + "pulse.1.delay": 0.0, + "pulse.1.width": 0.0, + "pulse.2.delay": 0.0, + "pulse.2.width": 0.0, + "seq.1.active": False, + "seq.1.table": { + "outa1": [False], + "outa2": [False], + "outb1": [False], + "outb2": [False], + "outc1": [False], + "outc2": [False], + "outd1": [False], + "outd2": [False], + "oute1": [False], + "oute2": [False], + "outf1": [False], + "outf2": [False], + "position": [0], + "repeats": [1], + "time1": [0], + "time2": [0], + "trigger": ["Immediate"], + }, + "seq.1.repeats": 0, + "seq.1.prescale": 0.0, + "seq.1.enable": "ZERO", + "seq.2.table": { + "outa1": [], + "outa2": [], + "outb1": [], + "outb2": [], + "outc1": [], + "outc2": [], + "outd1": [], + "outd2": [], + "oute1": [], + "oute2": [], + "outf1": [], + "outf2": [], + "position": [], + "repeats": [], + "time1": [], + "time2": [], + "trigger": [], + }, + "seq.2.active": False, + "seq.2.repeats": 0, + "seq.2.prescale": 0.0, + "seq.2.enable": "ZERO", + } diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index ad92683bbd..c5f5abb846 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -1,31 +1,226 @@ +from 
from functools import reduce

import numpy as np
import pytest
from pydantic import ValidationError

from ophyd_async.fastcs.panda import SeqTable
from ophyd_async.fastcs.panda._table import SeqTrigger


def test_seq_table_converts_lists():
    """List-valued fields are coerced to numpy arrays of the declared dtype."""
    seq_table_dict_with_lists = {field_name: [] for field_name, _ in SeqTable()}
    # Validation passes
    seq_table = SeqTable(**seq_table_dict_with_lists)
    assert isinstance(seq_table.trigger, np.ndarray)
    assert seq_table.trigger.dtype == np.dtype("U32")


def test_seq_table_validation_errors():
    """SeqTable rejects scalar fields, over-long tables and wrong dtypes."""
    # Scalars are not accepted: every one of the 17 columns must be array-like,
    # and each failed column reports multiple errors -> 81 in total.
    with pytest.raises(ValidationError, match="81 validation errors for SeqTable"):
        SeqTable(
            repeats=0,
            trigger="Immediate",
            position=0,
            time1=0,
            outa1=False,
            outb1=False,
            outc1=False,
            outd1=False,
            oute1=False,
            outf1=False,
            time2=0,
            outa2=False,
            outb2=False,
            outc2=False,
            outd2=False,
            oute2=False,
            outf2=False,
        )

    # 4095 rows is the largest legal table; adding one more row must fail.
    large_seq_table = SeqTable(
        repeats=np.zeros(4095, dtype=np.int32),
        trigger=np.array(["Immediate"] * 4095, dtype="U32"),
        position=np.zeros(4095, dtype=np.int32),
        time1=np.zeros(4095, dtype=np.int32),
        outa1=np.zeros(4095, dtype=np.bool_),
        outb1=np.zeros(4095, dtype=np.bool_),
        outc1=np.zeros(4095, dtype=np.bool_),
        outd1=np.zeros(4095, dtype=np.bool_),
        oute1=np.zeros(4095, dtype=np.bool_),
        outf1=np.zeros(4095, dtype=np.bool_),
        time2=np.zeros(4095, dtype=np.int32),
        outa2=np.zeros(4095, dtype=np.bool_),
        outb2=np.zeros(4095, dtype=np.bool_),
        outc2=np.zeros(4095, dtype=np.bool_),
        outd2=np.zeros(4095, dtype=np.bool_),
        oute2=np.zeros(4095, dtype=np.bool_),
        outf2=np.zeros(4095, dtype=np.bool_),
    )
    with pytest.raises(
        ValidationError,
        match=(
            "1 validation error for SeqTable\n "
            "Assertion failed, Length 4096 not in range."
        ),
    ):
        large_seq_table + SeqTable.row()

    # Casting every column to a string dtype breaks the 12 numeric/bool columns.
    with pytest.raises(
        ValidationError,
        match="12 validation errors for SeqTable",
    ):
        row_one = SeqTable.row()
        wrong_types = {
            # np.str_ replaces np.unicode_, which was removed in NumPy 2.0;
            # both are the same unicode-string scalar type on NumPy 1.20+.
            field_name: field_value.astype(np.str_)
            for field_name, field_value in row_one
        }
        SeqTable(**wrong_types)


def test_seq_table_pva_conversion():
    """Column-wise (PVA) and row-wise construction agree, and round-trip cleanly."""
    pva_dict = {
        "repeats": np.array([1, 2, 3, 4], dtype=np.int32),
        "trigger": np.array(
            ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype=np.dtype("U32")
        ),
        "position": np.array([1, 2, 3, 4], dtype=np.int32),
        "time1": np.array([1, 0, 1, 0], dtype=np.int32),
        "outa1": np.array([1, 0, 1, 0], dtype=np.bool_),
        "outb1": np.array([1, 0, 1, 0], dtype=np.bool_),
        "outc1": np.array([1, 0, 1, 0], dtype=np.bool_),
        "outd1": np.array([1, 0, 1, 0], dtype=np.bool_),
        "oute1": np.array([1, 0, 1, 0], dtype=np.bool_),
        "outf1": np.array([1, 0, 1, 0], dtype=np.bool_),
        "time2": np.array([1, 2, 3, 4], dtype=np.int32),
        "outa2": np.array([1, 0, 1, 0], dtype=np.bool_),
        "outb2": np.array([1, 0, 1, 0], dtype=np.bool_),
        "outc2": np.array([1, 0, 1, 0], dtype=np.bool_),
        "outd2": np.array([1, 0, 1, 0], dtype=np.bool_),
        "oute2": np.array([1, 0, 1, 0], dtype=np.bool_),
        "outf2": np.array([1, 0, 1, 0], dtype=np.bool_),
    }
    # The same four rows expressed row-wise, one kwargs dict per row.
    row_wise_dicts = [
        {
            "repeats": 1,
            "trigger": "Immediate",
            "position": 1,
            "time1": 1,
            "outa1": 1,
            "outb1": 1,
            "outc1": 1,
            "outd1": 1,
            "oute1": 1,
            "outf1": 1,
            "time2": 1,
            "outa2": 1,
            "outb2": 1,
            "outc2": 1,
            "outd2": 1,
            "oute2": 1,
            "outf2": 1,
        },
        {
            "repeats": 2,
            "trigger": "Immediate",
            "position": 2,
            "time1": 0,
            "outa1": 0,
            "outb1": 0,
            "outc1": 0,
            "outd1": 0,
            "oute1": 0,
            "outf1": 0,
            "time2": 2,
            "outa2": 0,
            "outb2": 0,
            "outc2": 0,
            "outd2": 0,
            "oute2": 0,
            "outf2": 0,
        },
        {
            "repeats": 3,
            "trigger": "BITC=0",
            "position": 3,
            "time1": 1,
            "outa1": 1,
            "outb1": 1,
            "outc1": 1,
            "outd1": 1,
            "oute1": 1,
            "outf1": 1,
            "time2": 3,
            "outa2": 1,
            "outb2": 1,
            "outc2": 1,
            "outd2": 1,
            "oute2": 1,
            "outf2": 1,
        },
        {
            "repeats": 4,
            "trigger": "Immediate",
            "position": 4,
            "time1": 0,
            "outa1": 0,
            "outb1": 0,
            "outc1": 0,
            "outd1": 0,
            "oute1": 0,
            "outf1": 0,
            "time2": 4,
            "outa2": 0,
            "outb2": 0,
            "outc2": 0,
            "outd2": 0,
            "oute2": 0,
            "outf2": 0,
        },
    ]

    # Column-wise construction preserves both values and dtypes.
    seq_table_from_pva_dict = SeqTable(**pva_dict)
    for (_, column1), column2 in zip(seq_table_from_pva_dict, pva_dict.values()):
        assert np.array_equal(column1, column2)
        assert column1.dtype == column2.dtype

    # Summing single-row tables reproduces the column-wise table exactly.
    seq_table_from_rows = reduce(
        lambda x, y: x + y,
        [SeqTable.row(**row_kwargs) for row_kwargs in row_wise_dicts],
    )
    for (_, column1), column2 in zip(seq_table_from_rows, pva_dict.values()):
        assert np.array_equal(column1, column2)
        assert column1.dtype == column2.dtype

    # Idempotency
    applied_twice_to_pva_dict = SeqTable(**pva_dict).model_dump(mode="python")
    for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()):
        assert np.array_equal(column1, column2)
        assert column1.dtype == column2.dtype
test_from_arrays_too_long(): - time2 = np.zeros(4097) - with pytest.raises(AssertionError, match="Length 4097 not in range"): - seq_table_from_arrays(time2=time2) +def test_seq_table_takes_trigger_enum_row(): + for trigger in (SeqTrigger.BITA_0, "BITA=0"): + table = SeqTable.row(trigger=trigger) + assert table.trigger[0] == "BITA=0" + assert np.issubdtype(table.trigger.dtype, np.dtype("