diff --git a/src/ophyd_async/core/device_save_loader.py b/src/ophyd_async/core/device_save_loader.py index 88168b462d..77db3d37d6 100644 --- a/src/ophyd_async/core/device_save_loader.py +++ b/src/ophyd_async/core/device_save_loader.py @@ -1,6 +1,5 @@ from enum import Enum -from functools import partial -from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Sequence import numpy as np import numpy.typing as npt @@ -8,13 +7,10 @@ from bluesky.plan_stubs import abs_set, wait from bluesky.protocols import Location from bluesky.utils import Msg -from epicscorelibs.ca.dbr import ca_array, ca_float, ca_int, ca_str from .device import Device from .signal import SignalRW -CaType = Union[ca_float, ca_int, ca_str, ca_array] - def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node: return dumper.represent_sequence( @@ -22,19 +18,6 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No ) -def ca_dbr_representer(dumper: yaml.Dumper, value: CaType) -> yaml.Node: - # if it's an array, just call ndarray_representer... - represent_array = partial(ndarray_representer, dumper) - - representers: Dict[CaType, Callable[[CaType], yaml.Node]] = { - ca_float: dumper.represent_float, - ca_int: dumper.represent_int, - ca_str: dumper.represent_str, - ca_array: represent_array, - } - return representers[type(value)](value) - - class OphydDumper(yaml.Dumper): def represent_data(self, data: Any) -> Any: if isinstance(data, Enum): @@ -152,11 +135,6 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None: yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper) - yaml.add_representer(ca_float, ca_dbr_representer, Dumper=yaml.Dumper) - yaml.add_representer(ca_int, ca_dbr_representer, Dumper=yaml.Dumper) - yaml.add_representer(ca_str, ca_dbr_representer, Dumper=yaml.Dumper) - yaml.add_representer(ca_array, ca_dbr_representer, Dumper=yaml.Dumper) - with open(save_path, "w") as file: yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False) diff --git a/src/ophyd_async/epics/_backend/_aioca.py b/src/ophyd_async/epics/_backend/_aioca.py index 87e3395dc9..89bf6d256b 100644 --- a/src/ophyd_async/epics/_backend/_aioca.py +++ b/src/ophyd_async/epics/_backend/_aioca.py @@ -2,8 +2,9 @@ import sys from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import Any, Dict, Optional, Type, Union +import numpy as np from aioca import ( FORMAT_CTRL, FORMAT_RAW, @@ -49,7 +50,10 @@ def write_value(self, value) -> Any: return value def value(self, value: AugmentedValue): - return value + # for channel access ca_xxx classes, this + # invokes __pos__ operator to return an instance of + # the builtin base class + return +value def reading(self, value: AugmentedValue): return { @@ -76,6 +80,9 @@ class CaArrayConverter(CaConverter): def get_datakey(self, source: str, value: AugmentedValue) -> DataKey: return {"source": source, "dtype": "array", "shape": [len(value)]} + def value(self, value: AugmentedValue): + return np.array(value, copy=False) + @dataclass class CaEnumConverter(CaConverter): @@ -115,8 +122,10 @@ def make_converter( return CaLongStrConverter() elif is_array and pv_dbr == dbr.DBR_STRING: # Waveform of strings, check we wanted this - if datatype and datatype != Sequence[str]: - raise TypeError(f"{pv} has type [str] not {datatype.__name__}") + if datatype: + datatype_dtype = get_dtype(datatype) + if not datatype_dtype or not np.can_cast(datatype_dtype, np.str_): + raise TypeError(f"{pv} has type [str] not {datatype.__name__}") return CaArrayConverter(pv_dbr, None) elif is_array: pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes") diff --git a/tests/core/test_device_save_loader.py b/tests/core/test_device_save_loader.py index 688c0cdfb5..b0e1da0c4f 100644 --- a/tests/core/test_device_save_loader.py +++ b/tests/core/test_device_save_loader.py @@ -1,6 +1,6 @@ from enum import Enum from os import path -from typing import Any, Dict, List +from typing import Any, Dict, List, Sequence from unittest.mock import patch import numpy as np @@ -48,6 +48,33 @@ def __init__(self, name: str): self.position: npt.NDArray[np.int32] +class MyEnum(str, Enum): + one = "one" + two = "two" + three = "three" + + +class DummyDeviceGroupAllTypes(Device): + def __init__(self, name: str): + self.pv_int: SignalRW = epics_signal_rw(int, "PV1") + self.pv_float: SignalRW = epics_signal_rw(float, "PV2") + self.pv_str: SignalRW = epics_signal_rw(str, "PV2") + self.pv_enum_str: SignalRW = epics_signal_rw(MyEnum, "PV3") + self.pv_enum: SignalRW = epics_signal_rw(MyEnum, "PV4") + self.pv_array_int8 = epics_signal_rw(npt.NDArray[np.int8], "PV5") + self.pv_array_uint8 = epics_signal_rw(npt.NDArray[np.uint8], "PV6") + self.pv_array_int16 = epics_signal_rw(npt.NDArray[np.int16], "PV7") + self.pv_array_uint16 = epics_signal_rw(npt.NDArray[np.uint16], "PV8") + self.pv_array_int32 = epics_signal_rw(npt.NDArray[np.int32], "PV9") + self.pv_array_uint32 = epics_signal_rw(npt.NDArray[np.uint32], "PV10") + self.pv_array_int64 = epics_signal_rw(npt.NDArray[np.int64], "PV11") + self.pv_array_uint64 = epics_signal_rw(npt.NDArray[np.uint64], "PV12") + self.pv_array_float32 = epics_signal_rw(npt.NDArray[np.float32], "PV13") + self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14") + self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15") + self.pv_array_str = epics_signal_rw(Sequence[str], "PV16") + + @pytest.fixture async def device() -> DummyDeviceGroup: device = DummyDeviceGroup("parent") @@ -55,6 +82,13 @@ async def device() -> DummyDeviceGroup: return device +@pytest.fixture +async def device_all_types() -> DummyDeviceGroupAllTypes: + device = DummyDeviceGroupAllTypes("parent") + await device.connect(mock=True) + return device + + # Dummy function to check different phases save properly def sort_signal_by_phase(values: Dict[str, Any]) -> List[Dict[str, Any]]: phase_1 = {"child1.sig1": values["child1.sig1"]} @@ -73,6 +107,70 @@ async def test_enum_yaml_formatting(tmp_path): assert saved_enums == enums +async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path): + # Populate fake device with PV's... + await device_all_types.pv_int.set(1) + await device_all_types.pv_float.set(1.234) + await device_all_types.pv_str.set("test_string") + await device_all_types.pv_enum_str.set("two") + await device_all_types.pv_enum.set(MyEnum.two) + for pv, dtype in { + device_all_types.pv_array_int8: np.int8, + device_all_types.pv_array_uint8: np.uint8, + device_all_types.pv_array_int16: np.int16, + device_all_types.pv_array_uint16: np.uint16, + device_all_types.pv_array_int32: np.int32, + device_all_types.pv_array_uint32: np.uint32, + device_all_types.pv_array_int64: np.int64, + device_all_types.pv_array_uint64: np.uint64, + }.items(): + await pv.set( + np.array( + [np.iinfo(dtype).min, np.iinfo(dtype).max, 0, 1, 2, 3, 4], dtype=dtype + ) + ) + for pv, dtype in { + device_all_types.pv_array_float32: np.float32, + device_all_types.pv_array_float64: np.float64, + }.items(): + finfo = np.finfo(dtype) + data = np.array( + [ + finfo.min, + finfo.max, + finfo.smallest_normal, + finfo.smallest_subnormal, + 0, + 1.234, + 2.34e5, + 3.45e-6, + ], + dtype=dtype, + ) + + await pv.set(data) + await device_all_types.pv_array_npstr.set( + np.array(["one", "two", "three"], dtype=np.str_), + ) + await device_all_types.pv_array_str.set( + ["one", "two", "three"], + ) + + # Create save plan from utility functions + def save_my_device(): + signalRWs = walk_rw_signals(device_all_types) + values = yield from get_signal_values(signalRWs) + + save_to_yaml([values], path.join(tmp_path, "test_file.yaml")) + + RE(save_my_device()) + + actual_file_path = path.join(tmp_path, "test_file.yaml") + with open(actual_file_path, "r") as actual_file: + with open("tests/test_data/test_yaml_save.yml") as expected_file: + assert actual_file.read() == expected_file.read() + + async def test_save_device(RE: RunEngine, device, tmp_path): # Populate fake device with PV's... await device.child1.sig1.set("test_string") diff --git a/tests/epics/test_signals.py b/tests/epics/test_signals.py index ff6fce24f6..7e80545da1 100644 --- a/tests/epics/test_signals.py +++ b/tests/epics/test_signals.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path +from types import GenericAlias from typing import ( Any, Callable, @@ -28,7 +29,7 @@ from aioca import CANothing, purge_channel_caches from bluesky.protocols import Reading -from ophyd_async.core import SignalBackend, T, get_dtype, load_from_yaml, save_to_yaml +from ophyd_async.core import SignalBackend, T, load_from_yaml, save_to_yaml from ophyd_async.core.utils import NotConnected from ophyd_async.epics.signal._epics_transport import EpicsTransport from ophyd_async.epics.signal.signal import ( @@ -102,6 +103,28 @@ def ioc(request): pass +def assert_types_are_equal(t_actual, t_expected, actual_value): + expected_plain_type = getattr(t_expected, "__origin__", t_expected) + if issubclass(expected_plain_type, np.ndarray): + actual_plain_type = getattr(t_actual, "__origin__", t_actual) + assert actual_plain_type == expected_plain_type + actual_dtype_type = actual_value.dtype.type + expected_dtype_type = t_expected.__args__[1].__args__[0] + assert actual_dtype_type == expected_dtype_type + elif ( + expected_plain_type is not str + and not issubclass(expected_plain_type, Enum) + and issubclass(expected_plain_type, Sequence) + ): + actual_plain_type = getattr(t_actual, "__origin__", t_actual) + assert issubclass(actual_plain_type, expected_plain_type) + assert len(actual_value) == 0 or isinstance( + actual_value[0], t_expected.__args__[0] + ) + else: + assert t_actual == t_expected + + class MonitorQueue: def __init__(self, backend: SignalBackend): self.backend = backend @@ -111,7 +134,7 @@ def __init__(self, backend: SignalBackend): def add_reading_value(self, reading: Reading, value): self.updates.put_nowait((reading, value)) - async def assert_updates(self, expected_value): + async def assert_updates(self, expected_value, expected_type=None): expected_reading = { "value": expected_value, "timestamp": pytest.approx(time.time(), rel=0.1), @@ -122,12 +145,22 @@ async def assert_updates(self, expected_value): backend_value = await asyncio.wait_for(self.backend.get_value(), timeout=5) assert value == expected_value == backend_value + if expected_type: + assert_types_are_equal(type(value), expected_type, value) + assert_types_are_equal(type(backend_value), expected_type, backend_value) assert reading == expected_reading == backend_reading def close(self): self.backend.set_callback(None) +def _is_numpy_subclass(t): + if t is None: + return False + plain_type = t.__origin__ if isinstance(t, GenericAlias) else t + return issubclass(plain_type, np.ndarray) + + async def assert_monitor_then_put( ioc: IOC, suffix: str, @@ -135,6 +168,7 @@ async def assert_monitor_then_put( initial_value: T, put_value: T, datatype: Optional[Type[T]] = None, + check_type: Optional[bool] = True, ): backend = await ioc.make_backend(datatype, suffix) # Make a monitor queue that will monitor for updates @@ -144,10 +178,15 @@ async def assert_monitor_then_put( pv_name = f"{ioc.protocol}://{PV_PREFIX}:{ioc.protocol}:{suffix}" assert dict(source=pv_name, **descriptor) == await backend.get_datakey(pv_name) # Check initial value - await q.assert_updates(pytest.approx(initial_value)) + await q.assert_updates( + pytest.approx(initial_value), + datatype if check_type else None, + ) # Put to new value and check that await backend.put(put_value) - await q.assert_updates(pytest.approx(put_value)) + await q.assert_updates( + pytest.approx(put_value), datatype if check_type else None + ) finally: q.close() @@ -195,34 +234,113 @@ def waveform_d(value): ls1 = "a string that is just longer than forty characters" ls2 = "another string that is just longer than forty characters" -ca_dtype_mapping = { - np.int8: np.uint8, - np.uint16: np.int32, - np.uint32: np.float64, - np.int64: np.float64, - np.uint64: np.float64, -} - @pytest.mark.parametrize( - "datatype, suffix, initial_value, put_value, descriptor", + "datatype, suffix, initial_value, put_value, descriptor, supported_backends", [ - (int, "int", 42, 43, integer_d), - (float, "float", 3.141, 43.5, number_d), - (str, "str", "hello", "goodbye", string_d), - (MyEnum, "enum", MyEnum.b, MyEnum.c, enum_d), - (str, "enum", "Bbb", "Ccc", enum_d), - (npt.NDArray[np.int8], "int8a", [-128, 127], [-8, 3, 44], waveform_d), - (npt.NDArray[np.uint8], "uint8a", [0, 255], [218], waveform_d), - (npt.NDArray[np.int16], "int16a", [-32768, 32767], [-855], waveform_d), - (npt.NDArray[np.uint16], "uint16a", [0, 65535], [5666], waveform_d), - (npt.NDArray[np.int32], "int32a", [-2147483648, 2147483647], [-2], waveform_d), - (npt.NDArray[np.uint32], "uint32a", [0, 4294967295], [1022233], waveform_d), - (npt.NDArray[np.int64], "int64a", [-2147483649, 2147483648], [-3], waveform_d), - (npt.NDArray[np.uint64], "uint64a", [0, 4294967297], [995444], waveform_d), - (npt.NDArray[np.float32], "float32a", [0.000002, -123.123], [1.0], waveform_d), - (npt.NDArray[np.float64], "float64a", [0.1, -12345678.123], [0.2], waveform_d), - (Sequence[str], "stra", ["five", "six", "seven"], ["nine", "ten"], waveform_d), + # python builtin scalars + (int, "int", 42, 43, integer_d, {"ca", "pva"}), + (float, "float", 3.141, 43.5, number_d, {"ca", "pva"}), + (str, "str", "hello", "goodbye", string_d, {"ca", "pva"}), + (MyEnum, "enum", MyEnum.b, MyEnum.c, enum_d, {"ca", "pva"}), + (str, "enum", "Bbb", "Ccc", enum_d, {"ca", "pva"}), + # numpy arrays of numpy types + ( + npt.NDArray[np.int8], + "int8a", + [-128, 127], + [-8, 3, 44], + waveform_d, + {"pva"}, + ), + ( + npt.NDArray[np.uint8], + "uint8a", + [0, 255], + [218], + waveform_d, + {"ca", "pva"}, + ), + ( + npt.NDArray[np.int16], + "int16a", + [-32768, 32767], + [-855], + waveform_d, + {"ca", "pva"}, + ), + ( + npt.NDArray[np.uint16], + "uint16a", + [0, 65535], + [5666], + waveform_d, + {"pva"}, + ), + ( + npt.NDArray[np.int32], + "int32a", + [-2147483648, 2147483647], + [-2], + waveform_d, + {"ca", "pva"}, + ), + ( + npt.NDArray[np.uint32], + "uint32a", + [0, 4294967295], + [1022233], + waveform_d, + {"pva"}, + ), + ( + npt.NDArray[np.int64], + "int64a", + [-2147483649, 2147483648], + [-3], + waveform_d, + {"pva"}, + ), + ( + npt.NDArray[np.uint64], + "uint64a", + [0, 4294967297], + [995444], + waveform_d, + {"pva"}, + ), + ( + npt.NDArray[np.float32], + "float32a", + [0.000002, -123.123], + [1.0], + waveform_d, + {"ca", "pva"}, + ), + ( + npt.NDArray[np.float64], + "float64a", + [0.1, -12345678.123], + [0.2], + waveform_d, + {"ca", "pva"}, + ), + ( + Sequence[str], + "stra", + ["five", "six", "seven"], + ["nine", "ten"], + waveform_d, + {"pva"}, + ), + ( + npt.NDArray[np.str_], + "stra", + ["five", "six", "seven"], + ["nine", "ten"], + waveform_d, + {"ca"}, + ), # Can't do long strings until https://github.com/epics-base/pva2pva/issues/17 # (str, "longstr", ls1, ls2, string_d), # (str, "longstr2.VAL$", ls1, ls2, string_d), @@ -236,25 +354,31 @@ async def test_backend_get_put_monitor( put_value: T, descriptor: Callable[[Any], dict], tmp_path, + supported_backends: set[str], ): # ca can't support all the types - dtype = get_dtype(datatype) - if ioc.protocol == "ca" and dtype and dtype.type in ca_dtype_mapping: - if dtype == np.int8: - # CA maps uint8 onto int8 rather than upcasting, so we need to change - # initial array - initial_value, put_value = [ # type: ignore - np.array(x).astype(np.uint8) for x in (initial_value, put_value) - ] - datatype = npt.NDArray[ca_dtype_mapping[dtype.type]] # type: ignore + for backend in supported_backends: + assert backend in ["ca", "pva"] + if ioc.protocol not in supported_backends: + return # With the given datatype, check we have the correct initial value and putting # works await assert_monitor_then_put( - ioc, suffix, descriptor(initial_value), initial_value, put_value, datatype + ioc, + suffix, + descriptor(initial_value), + initial_value, + put_value, + datatype, ) # With datatype guessed from CA/PVA, check we can set it back to the initial value await assert_monitor_then_put( - ioc, suffix, descriptor(put_value), put_value, initial_value, datatype=None + ioc, + suffix, + descriptor(put_value), + put_value, + initial_value, + datatype=None, ) yaml_path = tmp_path / "test.yaml" @@ -272,6 +396,7 @@ async def test_bool_conversion_of_enum(ioc: IOC, suffix: str) -> None: initial_value=True, put_value=False, datatype=bool, + check_type=False, ) diff --git a/tests/test_data/test_yaml_save.yml b/tests/test_data/test_yaml_save.yml new file mode 100644 index 0000000000..fc3e1ebd95 --- /dev/null +++ b/tests/test_data/test_yaml_save.yml @@ -0,0 +1,22 @@ +- pv_array_float32: [-3.4028234663852886e+38, 3.4028234663852886e+38, 1.1754943508222875e-38, + 1.401298464324817e-45, 0.0, 1.2339999675750732, 234000.0, 3.4499998946557753e-06] + pv_array_float64: [-1.7976931348623157e+308, 1.7976931348623157e+308, 2.2250738585072014e-308, + 5.0e-324, 0.0, 1.234, 234000.0, 3.45e-06] + pv_array_int16: [-32768, 32767, 0, 1, 2, 3, 4] + pv_array_int32: [-2147483648, 2147483647, 0, 1, 2, 3, 4] + pv_array_int64: [-9223372036854775808, 9223372036854775807, 0, 1, 2, 3, 4] + pv_array_int8: [-128, 127, 0, 1, 2, 3, 4] + pv_array_npstr: [one, two, three] + pv_array_str: + - one + - two + - three + pv_array_uint16: [0, 65535, 0, 1, 2, 3, 4] + pv_array_uint32: [0, 4294967295, 0, 1, 2, 3, 4] + pv_array_uint64: [0, 18446744073709551615, 0, 1, 2, 3, 4] + pv_array_uint8: [0, 255, 0, 1, 2, 3, 4] + pv_enum: two + pv_enum_str: two + pv_float: 1.234 + pv_int: 1 + pv_str: test_string