Skip to content

Commit

Permalink
361 type conversions for python builtins (#364)
Browse files Browse the repository at this point in the history
* (#361) Implement conversions to python builtins for str, int and float in CA backend

* (#361) Remove redundant representers from device_save_loader

* (#361) Changes from PR review comments

* (#361) Add more comprehensive test for saving to yaml

* (#361) Additional response to PR comments

* (#361) Remove spurious add_representer line that was missed

---------

Co-authored-by: Tom C (DLS) <101418278+coretl@users.noreply.github.com>
  • Loading branch information
rtuck99 and coretl committed Jun 11, 2024
1 parent cbbb295 commit b59d1e2
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 68 deletions.
24 changes: 1 addition & 23 deletions src/ophyd_async/core/device_save_loader.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,23 @@
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence

import numpy as np
import numpy.typing as npt
import yaml
from bluesky.plan_stubs import abs_set, wait
from bluesky.protocols import Location
from bluesky.utils import Msg
from epicscorelibs.ca.dbr import ca_array, ca_float, ca_int, ca_str

from .device import Device
from .signal import SignalRW

CaType = Union[ca_float, ca_int, ca_str, ca_array]


def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node:
    """Represent a numpy array as a flow-style YAML sequence.

    tolist() converts numpy scalars to python builtins so the default
    yaml representers can handle the elements.
    """
    items = array.tolist()
    return dumper.represent_sequence("tag:yaml.org,2002:seq", items, flow_style=True)


def ca_dbr_representer(dumper: yaml.Dumper, value: CaType) -> yaml.Node:
    """Represent a channel-access dbr value with the matching YAML node type.

    Dispatches on the exact augmented type; ca_array falls through to the
    generic ndarray representer.
    """
    value_type = type(value)
    if value_type is ca_float:
        return dumper.represent_float(value)
    if value_type is ca_int:
        return dumper.represent_int(value)
    if value_type is ca_str:
        return dumper.represent_str(value)
    if value_type is ca_array:
        return ndarray_representer(dumper, value)
    # mirror the original dict-lookup failure mode for unexpected types
    raise KeyError(value_type)


class OphydDumper(yaml.Dumper):
def represent_data(self, data: Any) -> Any:
if isinstance(data, Enum):
Expand Down Expand Up @@ -152,11 +135,6 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:

yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper)

yaml.add_representer(ca_float, ca_dbr_representer, Dumper=yaml.Dumper)
yaml.add_representer(ca_int, ca_dbr_representer, Dumper=yaml.Dumper)
yaml.add_representer(ca_str, ca_dbr_representer, Dumper=yaml.Dumper)
yaml.add_representer(ca_array, ca_dbr_representer, Dumper=yaml.Dumper)

with open(save_path, "w") as file:
yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False)

Expand Down
17 changes: 13 additions & 4 deletions src/ophyd_async/epics/_backend/_aioca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import sys
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Sequence, Type, Union
from typing import Any, Dict, Optional, Type, Union

import numpy as np
from aioca import (
FORMAT_CTRL,
FORMAT_RAW,
Expand Down Expand Up @@ -49,7 +50,10 @@ def write_value(self, value) -> Any:
return value

def value(self, value: AugmentedValue):
    """Convert an aioca AugmentedValue to a plain python builtin."""
    # For channel access ca_xxx classes, unary plus invokes the __pos__
    # operator, which returns an instance of the builtin base class
    # (e.g. float/int/str), stripping the channel-access augmentation.
    return +value

def reading(self, value: AugmentedValue):
return {
Expand All @@ -76,6 +80,9 @@ class CaArrayConverter(CaConverter):
def get_datakey(self, source: str, value: AugmentedValue) -> DataKey:
    """Describe an array PV: dtype is always "array", shape is its length."""
    shape = [len(value)]
    return {"source": source, "dtype": "array", "shape": shape}

def value(self, value: AugmentedValue):
    """Return the array PV value as a numpy array, avoiding a copy when possible."""
    # np.asarray is equivalent to np.array(..., copy=False) on numpy 1.x,
    # but unlike copy=False it keeps working on numpy >= 2.0, where
    # copy=False now means "never copy" and raises if a copy is required.
    return np.asarray(value)


@dataclass
class CaEnumConverter(CaConverter):
Expand Down Expand Up @@ -115,8 +122,10 @@ def make_converter(
return CaLongStrConverter()
elif is_array and pv_dbr == dbr.DBR_STRING:
# Waveform of strings, check we wanted this
if datatype and datatype != Sequence[str]:
raise TypeError(f"{pv} has type [str] not {datatype.__name__}")
if datatype:
datatype_dtype = get_dtype(datatype)
if not datatype_dtype or not np.can_cast(datatype_dtype, np.str_):
raise TypeError(f"{pv} has type [str] not {datatype.__name__}")
return CaArrayConverter(pv_dbr, None)
elif is_array:
pv_dtype = get_unique({k: v.dtype for k, v in values.items()}, "dtypes")
Expand Down
100 changes: 99 additions & 1 deletion tests/core/test_device_save_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from os import path
from typing import Any, Dict, List
from typing import Any, Dict, List, Sequence
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -48,13 +48,47 @@ def __init__(self, name: str):
self.position: npt.NDArray[np.int32]


class MyEnum(str, Enum):
    """String-valued enum used to exercise enum signal saving in the tests."""

    one = "one"
    two = "two"
    three = "three"


class DummyDeviceGroupAllTypes(Device):
    """Device with one writable signal per supported datatype.

    Used to exercise saving every scalar, enum and array type to yaml.
    """

    def __init__(self, name: str):
        self.pv_int: SignalRW = epics_signal_rw(int, "PV1")
        self.pv_float: SignalRW = epics_signal_rw(float, "PV2")
        # Fix: was "PV2", accidentally duplicating pv_float's PV in an
        # otherwise sequential PV1..PV16 numbering; give it its own PV.
        self.pv_str: SignalRW = epics_signal_rw(str, "PV17")
        self.pv_enum_str: SignalRW = epics_signal_rw(MyEnum, "PV3")
        self.pv_enum: SignalRW = epics_signal_rw(MyEnum, "PV4")
        self.pv_array_int8 = epics_signal_rw(npt.NDArray[np.int8], "PV5")
        self.pv_array_uint8 = epics_signal_rw(npt.NDArray[np.uint8], "PV6")
        self.pv_array_int16 = epics_signal_rw(npt.NDArray[np.int16], "PV7")
        self.pv_array_uint16 = epics_signal_rw(npt.NDArray[np.uint16], "PV8")
        self.pv_array_int32 = epics_signal_rw(npt.NDArray[np.int32], "PV9")
        self.pv_array_uint32 = epics_signal_rw(npt.NDArray[np.uint32], "PV10")
        self.pv_array_int64 = epics_signal_rw(npt.NDArray[np.int64], "PV11")
        self.pv_array_uint64 = epics_signal_rw(npt.NDArray[np.uint64], "PV12")
        self.pv_array_float32 = epics_signal_rw(npt.NDArray[np.float32], "PV13")
        self.pv_array_float64 = epics_signal_rw(npt.NDArray[np.float64], "PV14")
        self.pv_array_npstr = epics_signal_rw(npt.NDArray[np.str_], "PV15")
        self.pv_array_str = epics_signal_rw(Sequence[str], "PV16")


@pytest.fixture
async def device() -> DummyDeviceGroup:
    """A DummyDeviceGroup connected in mock mode."""
    # avoid shadowing the fixture name with the local variable
    group = DummyDeviceGroup("parent")
    await group.connect(mock=True)
    return group


@pytest.fixture
async def device_all_types() -> DummyDeviceGroupAllTypes:
    """A DummyDeviceGroupAllTypes connected in mock mode."""
    all_types = DummyDeviceGroupAllTypes("parent")
    await all_types.connect(mock=True)
    return all_types


# Dummy function to check different phases save properly
def sort_signal_by_phase(values: Dict[str, Any]) -> List[Dict[str, Any]]:
phase_1 = {"child1.sig1": values["child1.sig1"]}
Expand All @@ -73,6 +107,70 @@ async def test_enum_yaml_formatting(tmp_path):
assert saved_enums == enums


async def test_save_device_all_types(RE: RunEngine, device_all_types, tmp_path):
    """Save one signal of every supported type and compare against a golden file."""
    # Populate fake device with PV's...
    await device_all_types.pv_int.set(1)
    await device_all_types.pv_float.set(1.234)
    await device_all_types.pv_str.set("test_string")
    await device_all_types.pv_enum_str.set("two")
    await device_all_types.pv_enum.set(MyEnum.two)
    # Integer arrays: include the extreme values of each dtype plus small values
    for pv, dtype in {
        device_all_types.pv_array_int8: np.int8,
        device_all_types.pv_array_uint8: np.uint8,
        device_all_types.pv_array_int16: np.int16,
        device_all_types.pv_array_uint16: np.uint16,
        device_all_types.pv_array_int32: np.int32,
        device_all_types.pv_array_uint32: np.uint32,
        device_all_types.pv_array_int64: np.int64,
        device_all_types.pv_array_uint64: np.uint64,
    }.items():
        await pv.set(
            np.array(
                [np.iinfo(dtype).min, np.iinfo(dtype).max, 0, 1, 2, 3, 4], dtype=dtype
            )
        )
    # Float arrays: extremes, smallest (sub)normals and representative values
    for pv, dtype in {
        device_all_types.pv_array_float32: np.float32,
        device_all_types.pv_array_float64: np.float64,
    }.items():
        finfo = np.finfo(dtype)
        data = np.array(
            [
                finfo.min,
                finfo.max,
                finfo.smallest_normal,
                finfo.smallest_subnormal,
                0,
                1.234,
                2.34e5,
                3.45e-6,
            ],
            dtype=dtype,
        )

        await pv.set(data)
    # String arrays: both as a numpy str_ array and as a plain sequence
    await device_all_types.pv_array_npstr.set(
        np.array(["one", "two", "three"], dtype=np.str_),
    )
    await device_all_types.pv_array_str.set(
        ["one", "two", "three"],
    )

    # Create save plan from utility functions
    def save_my_device():
        signalRWs = walk_rw_signals(device_all_types)
        values = yield from get_signal_values(signalRWs)

        save_to_yaml([values], path.join(tmp_path, "test_file.yaml"))

    RE(save_my_device())

    # Byte-for-byte comparison against the checked-in expected yaml output;
    # any change to the representers will show up as a diff here
    actual_file_path = path.join(tmp_path, "test_file.yaml")
    with open(actual_file_path, "r") as actual_file:
        with open("tests/test_data/test_yaml_save.yml") as expected_file:
            assert actual_file.read() == expected_file.read()


async def test_save_device(RE: RunEngine, device, tmp_path):
# Populate fake device with PV's...
await device.child1.sig1.set("test_string")
Expand Down
Loading

0 comments on commit b59d1e2

Please sign in to comment.