Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PerSignalConfig dataclass for StandardReadable devices to be used with a common implementation of prepare. #543

Closed
wants to merge 13 commits into from
Closed
3 changes: 2 additions & 1 deletion src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
UUIDFilenameProvider,
YMDPathProvider,
)
from ._readable import ConfigSignal, HintedSignal, StandardReadable
from ._readable import ConfigSignal, HintedSignal, PerSignalConfig, StandardReadable
from ._signal import (
Signal,
SignalR,
Expand Down Expand Up @@ -167,4 +167,5 @@
"is_pydantic_model",
"wait_for_connection",
"completed_status",
"PerSignalConfig",
]
28 changes: 25 additions & 3 deletions src/ophyd_async/core/_readable.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import asyncio
import warnings
from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager
from typing import (
Any,
TypeVar,
)

from bluesky.protocols import HasHints, Hints, Reading
from bluesky.protocols import HasHints, Hints, Preparable, Reading
from event_model import DataKey

from ._device import Device, DeviceVector
from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable
from ._signal import SignalR
from ._signal import SignalR, SignalW
from ._status import AsyncStatus
from ._utils import merge_gathered_dicts

Expand All @@ -18,9 +23,19 @@
| type["HintedSignal"]
)

T = TypeVar("T")


class PerSignalConfig(dict[SignalW, Any]):
def __setitem__(self, signal: SignalW[T], value: T):
super().__setitem__(signal, value)

def __getitem__(self, signal: SignalW[T]) -> T:
return super().__getitem__(signal)


class StandardReadable(
Device, AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints
Device, AsyncReadable, AsyncConfigurable, AsyncStageable, HasHints, Preparable
):
"""Device that owns its children and provides useful default behavior.

Expand Down Expand Up @@ -214,6 +229,13 @@ def add_readables(
if isinstance(obj, HasHints):
self._has_hints += (obj,)

@AsyncStatus.wrap
async def prepare(self, value: PerSignalConfig) -> None:
tasks = []
for signal, new_value in value.items():
tasks.append(signal.set(new_value))
await asyncio.gather(*tasks)


class ConfigSignal(AsyncConfigurable):
def __init__(self, signal: ReadableChild) -> None:
Expand Down
77 changes: 77 additions & 0 deletions tests/core/test_readable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import get_type_hints
from unittest.mock import MagicMock

import numpy as np
import pytest
from bluesky.protocols import HasHints

Expand All @@ -14,7 +15,11 @@
DeviceVector,
HintedSignal,
MockSignalBackend,
PerSignalConfig,
SignalR,
SignalRW,
SignalW,
SoftSignalBackend,
StandardReadable,
soft_signal_r_and_setter,
)
Expand Down Expand Up @@ -238,3 +243,75 @@ def test_standard_readable_add_children_multi_nested():
with outer.add_children_as_readables():
outer.inner = inner
assert outer


@pytest.fixture
def standard_readable_config():
return PerSignalConfig()


test_data = [
("test_int", int, 42),
("test_float", float, 3.14),
("test_str", str, "hello"),
("test_bool", bool, True),
("test_list", list, [1, 2, 3]),
("test_tuple", tuple, (1, 2, 3)),
("test_dict", dict, {"key": "value"}),
("test_set", set, {1, 2, 3}),
("test_frozenset", frozenset, frozenset([1, 2, 3])),
("test_bytes", bytes, b"hello"),
("test_bytearray", bytearray, bytearray(b"hello")),
("test_complex", complex, 1 + 2j),
("test_nonetype", type(None), None),
("test_ndarray", np.ndarray, np.array([1, 2, 3])),
]


@pytest.mark.parametrize("name, type_, value", test_data)
def test_config_set_get_item(standard_readable_config, name, type_, value):
mock_signal = MagicMock(spec=SignalW)
standard_readable_config[mock_signal] = value
if type_ is np.ndarray:
assert np.array_equal(standard_readable_config[mock_signal], value)
else:
assert standard_readable_config[mock_signal] == value


@pytest.mark.parametrize("name, type_, value", test_data)
def test_config_del_item(standard_readable_config, name, type_, value):
mock_signal = MagicMock(spec=SignalW)
standard_readable_config[mock_signal] = value
del standard_readable_config[mock_signal]
with pytest.raises(KeyError):
_ = standard_readable_config[mock_signal]


@pytest.mark.asyncio
@pytest.mark.parametrize("name, type_, value", test_data)
async def test_config_prepare(standard_readable_config, name, type_, value):
readable = StandardReadable()
if type_ is np.ndarray:
readable.mock_signal1 = SignalRW(
name="mock_signal1",
backend=SoftSignalBackend(
datatype=type_, initial_value=np.ndarray([0, 0, 0])
),
)
else:
readable.mock_signal1 = SignalRW(
name="mock_signal1", backend=SoftSignalBackend(datatype=type_)
)

readable.add_readables([readable.mock_signal1])

config = PerSignalConfig()
config[readable.mock_signal1] = value

await readable.prepare(config)
val = await readable.mock_signal1.get_value()

if type_ is np.ndarray:
assert np.array_equal(val, value)
else:
assert await readable.mock_signal1.get_value() == value
Loading