From 9fdb677600c3bcade7e8b0d14cef8332f667b043 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Fri, 22 Nov 2024 13:30:27 +0000 Subject: [PATCH] WIP --- src/ophyd_async/core/_derived_signal.py | 217 ++++++++++-------------- tests/core/test_derived_signal.py | 72 ++++++++ 2 files changed, 159 insertions(+), 130 deletions(-) create mode 100644 tests/core/test_derived_signal.py diff --git a/src/ophyd_async/core/_derived_signal.py b/src/ophyd_async/core/_derived_signal.py index be8b1aada..92d4b6e16 100644 --- a/src/ophyd_async/core/_derived_signal.py +++ b/src/ophyd_async/core/_derived_signal.py @@ -1,154 +1,111 @@ -import asyncio -import dataclasses -from collections.abc import Iterator -from typing import Generic, TypeVar, get_type_hints +from abc import abstractmethod, ABC +from dataclasses import asdict +from typing import ( + Generic, + TypeVar, + Protocol, + ClassVar, + Generator, + Any, +) +from enum import Enum +from ophyd_async.core import AsyncStatus, SignalRW -import numpy as np -from ._device import Device -from ._protocol import AsyncMovable -from ._signal import SignalR, SignalRW, soft_signal_rw -from ._signal_backend import Array1D, SignalBackend -from ._status import AsyncStatus -from ._utils import T, get_origin_class +class DataclassInstance(Protocol): + __dataclass_fields__: ClassVar[dict[str, SignalRW]] -RawSignalsT = TypeVar("RawSignalsT") -ParametersSignalsT = TypeVar("ParametersSignalsT") -RawT = TypeVar("RawT") -DerivedT = TypeVar("DerivedT") -ParametersT = TypeVar("ParametersT") +RawT = TypeVar("RawT", bound=DataclassInstance) +DerivedT = TypeVar("DerivedT", bound=DataclassInstance) +ParametersT = TypeVar("ParametersT", bound=DataclassInstance) -class Transform(Generic[RawT, DerivedT, ParametersT]): - def forward(self, raw: RawT, parameters: ParametersT) -> DerivedT: ... - def inverse(self, derived: DerivedT, parameters: ParametersT) -> RawT: ... +def _get_signals(signals: DataclassInstance) -> Generator[Any, SignalRW, None]: + for signal in asdict(signals).values(): # type: ignore + if not isinstance(signal, SignalRW): + raise TypeError + yield signal -F_contra = TypeVar("F_contra", bound=float | Array1D[np.float64], contravariant=True) +class VALUES_UPDATING(Enum): + NEITHER = 0 + RAW = 1 + DERIVED = 0 -# TODO: should this be a TypedDict? -@dataclasses.dataclass -class SlitsRaw(Generic[F_contra]): - top: F_contra - bottom: F_contra +class Transform(ABC, Generic[RawT, DerivedT, ParametersT]): + raw_signals: RawT + derived_signals: DerivedT + parameters: ParametersT + values_updating: VALUES_UPDATING -@dataclasses.dataclass -class SlitsDerived(Generic[F_contra]): - gap: F_contra - centre: F_contra + def __init__( + self, + raw_signals: RawT, + derived_signals: DerivedT, + parameters: ParametersT, + ) -> None: + self.raw_signals = raw_signals + self.derived_signals = derived_signals + self.parameters = parameters + + for signal in _get_signals(raw_signals): + self._set_backend_raw_callbacks(signal) + for signal in _get_signals(derived_signals): + self._set_backend_derived_callbacks(signal) + + super().__init__() + + self.values_updating = VALUES_UPDATING.NEITHER + + def _set_backend_raw_callbacks(self, raw_signal: SignalRW): + old_set = raw_signal.set + + @AsyncStatus.wrap + async def new_set(*args, **kwargs): + await old_set(*args, **kwargs) + await self.forward() + raw_signal.set = new_set -@dataclasses.dataclass -class SlitsParameters: - gap_offset: float + old_get_value = raw_signal.get_value + async def new_get_value(*args, **kwargs): + await self.inverse() + return await old_get_value(*args, **kwargs) -class SlitsTransform(Transform[SlitsRaw, SlitsDerived, SlitsParameters]): - def forward( - self, raw: SlitsRaw[F_contra], parameters: SlitsParameters - ) -> SlitsDerived[F_contra]: - return SlitsDerived( - gap=raw.top - raw.bottom + parameters.gap_offset, - centre=(raw.top + raw.bottom) / 2, - ) + raw_signal.get_value = new_get_value - def inverse( - self, derived: SlitsDerived[F_contra], parameters: SlitsParameters - ) -> SlitsRaw[F_contra]: - half_gap = (derived.gap - parameters.gap_offset) / 2 - return SlitsRaw( - top=derived.centre + half_gap, - bottom=derived.centre - half_gap, - ) + def _set_backend_derived_callbacks(self, derived_signal: SignalRW): + old_set = derived_signal.set + @AsyncStatus.wrap + async def update_raw(*args, **kwargs): + await self.inverse() + await old_set(*args, **kwargs) -def _get_dataclass_args(method) -> Iterator[type]: - for k, v in get_type_hints(method): - cls = get_origin_class(v) - if k != "return" and dataclasses.is_dataclass(cls): - yield cls + derived_signal.set = update_raw + old_get_value = derived_signal.get_value -async def _get_dataclass_from_signals(cls: type[T], device: Device) -> T: - coros = {} - for field in dataclasses.fields(cls): - sig = getattr(device, field.name) - assert isinstance( - sig, SignalR - ), f"{device.name}.{field.name} is {sig}, not a Signal" - coros[field.name] = sig.get_value() - results = await asyncio.gather(*coros.values()) - kwargs = dict(zip(coros, results, strict=True)) - return cls(**kwargs) + async def new_get_value(*args, **kwargs): + await self.forward() + return await old_get_value(*args, **kwargs) + + derived_signal.get_value = new_get_value + + @abstractmethod + async def forward(self): + pass + + @abstractmethod + async def inverse(self): + pass class DerivedBackend(Generic[RawT, DerivedT, ParametersT]): - def __init__( - self, - device: Device, - transform: Transform[RawT, DerivedT, ParametersT], - ): - self._device = device + _transform: Transform[RawT, DerivedT, ParametersT] + + def __init__(self, transform: Transform[RawT, DerivedT, ParametersT]): self._transform = transform - self._raw_cls, self._param_cls = _get_dataclass_args(self._transform.forward) - - async def get_parameters(self) -> ParametersT: - return await _get_dataclass_from_signals(self._param_cls, self._device) - - async def get_raw_values(self) -> RawT: - return await _get_dataclass_from_signals(self._raw_cls, self._device) - - async def get_derived_values(self) -> DerivedT: - raw, parameters = await asyncio.gather( - self.get_raw_values(), self.get_parameters() - ) - return self._transform.forward(raw, parameters) - - async def set_derived_values(self, derived: DerivedT): - assert isinstance(self._device, AsyncMovable) - await self._device.set(derived) - - async def calculate_raw_values(self, derived: DerivedT) -> RawT: - return self._transform.inverse(derived, await self.get_parameters()) - - def derived_signal(self, variable: str) -> SignalRW[float]: - return SignalRW(DerivedSignalBackend(self, variable)) - - -class DerivedSignalBackend(SignalBackend[float]): - def __init__(self, backend: DerivedBackend, variable: str): - self._backend = backend - self._variable = variable - super().__init__(float) - - async def get_value(self) -> float: - derived = await self._backend.get_derived_values() - return getattr(derived, self._variable) - - async def put(self, value: float | None, wait: bool): - derived = await self._backend.get_derived_values() - # TODO: we should be calling locate on these as we want to move relative to the - # setpoint, not readback - setattr(derived, self._variable, value) - await self._backend.set_derived_values(derived) - - -class Slits(Device): - def __init__(self, name=""): - self._backend = DerivedBackend(self, SlitsTransform()) - # Raw signals - self.top = soft_signal_rw(float) - self.bottom = soft_signal_rw(float) - # Parameter - self.gap_offset = soft_signal_rw(float) - # Derived signals - self.gap = self._backend.derived_signal("gap") - self.centre = self._backend.derived_signal("centre") - super().__init__(name=name) - - @AsyncStatus.wrap - async def set(self, derived: SlitsDerived[float]) -> None: - raw: SlitsRaw[float] = await self._backend.calculate_raw_values(derived) - await asyncio.gather(self.top.set(raw.top), self.bottom.set(raw.bottom)) diff --git a/tests/core/test_derived_signal.py b/tests/core/test_derived_signal.py new file mode 100644 index 000000000..904628649 --- /dev/null +++ b/tests/core/test_derived_signal.py @@ -0,0 +1,72 @@ +from ophyd_async.core._derived_signal import Transform, DerivedBackend +from dataclasses import dataclass +import asyncio +from ophyd_async.core import Device, SignalRW, soft_signal_rw + + +@dataclass +class SlitsRaw: + top: SignalRW[float] + bottom: SignalRW[float] + + +@dataclass +class SlitsDerived: + gap: SignalRW[float] + center: SignalRW[float] + + +@dataclass +class SlitsParameters: + gap_offset: SignalRW[float] + + +class SlitsTransform(Transform[SlitsRaw, SlitsDerived, SlitsParameters]): + async def forward(self): + print("DEBUG: FORWARD") + top, bottom, gap_offset = await asyncio.gather( + self.raw_signals.top.get_value(), + self.raw_signals.bottom.get_value(), + self.parameters.gap_offset.get_value(), + ) + new_gap = top - bottom + gap_offset + new_center = (top + bottom) / 2 + asyncio.gather( + self.derived_signals.gap.set(new_gap), + self.derived_signals.center.set(new_center), + ) + + async def inverse(self): + print("DEBUG: INVERSE") + new_gap, new_center, gap_offset = await asyncio.gather( + self.derived_signals.gap.get_value(), + self.derived_signals.center.get_value(), + self.parameters.gap_offset.get_value(), + ) + half_gap = (new_gap - gap_offset) / 2 + + await asyncio.gather( + self.raw_signals.top.set(new_center + half_gap), + self.raw_signals.bottom.set(new_center - half_gap), + ) + + +class Slits(Device): + def __init__(self, name=""): + self.raw = SlitsRaw(top=soft_signal_rw(float), bottom=soft_signal_rw(float)) + self.derived = SlitsDerived( + gap=soft_signal_rw(float), center=soft_signal_rw(float) + ) + self.parameters = SlitsParameters(gap_offset=soft_signal_rw(float)) + + self._backend = DerivedBackend( + SlitsTransform(self.raw, self.derived, self.parameters) + ) + + super().__init__(name=name) + + +async def test_derived_signal(): + slits = Slits(name="hello") + + assert (await slits.derived.center.get_value()) == 0.0