-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9aaff23
commit 9fdb677
Showing
2 changed files
with
159 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,154 +1,111 @@ | ||
import asyncio | ||
import dataclasses | ||
from collections.abc import Iterator | ||
from typing import Generic, TypeVar, get_type_hints | ||
from abc import abstractmethod, ABC | ||
from dataclasses import asdict | ||
from typing import ( | ||
Generic, | ||
TypeVar, | ||
Protocol, | ||
ClassVar, | ||
Generator, | ||
Any, | ||
) | ||
from enum import Enum | ||
from ophyd_async.core import AsyncStatus, SignalRW | ||
|
||
import numpy as np | ||
|
||
from ._device import Device | ||
from ._protocol import AsyncMovable | ||
from ._signal import SignalR, SignalRW, soft_signal_rw | ||
from ._signal_backend import Array1D, SignalBackend | ||
from ._status import AsyncStatus | ||
from ._utils import T, get_origin_class | ||
class DataclassInstance(Protocol): | ||
__dataclass_fields__: ClassVar[dict[str, SignalRW]] | ||
|
||
RawSignalsT = TypeVar("RawSignalsT") | ||
ParametersSignalsT = TypeVar("ParametersSignalsT") | ||
RawT = TypeVar("RawT") | ||
DerivedT = TypeVar("DerivedT") | ||
ParametersT = TypeVar("ParametersT") | ||
|
||
RawT = TypeVar("RawT", bound=DataclassInstance) | ||
DerivedT = TypeVar("DerivedT", bound=DataclassInstance) | ||
ParametersT = TypeVar("ParametersT", bound=DataclassInstance) | ||
|
||
class Transform(Generic[RawT, DerivedT, ParametersT]): | ||
def forward(self, raw: RawT, parameters: ParametersT) -> DerivedT: ... | ||
def inverse(self, derived: DerivedT, parameters: ParametersT) -> RawT: ... | ||
|
||
def _get_signals(signals: DataclassInstance) -> Generator[Any, SignalRW, None]: | ||
for signal in asdict(signals).values(): # type: ignore | ||
if not isinstance(signal, SignalRW): | ||
raise TypeError | ||
yield signal | ||
|
||
F_contra = TypeVar("F_contra", bound=float | Array1D[np.float64], contravariant=True) | ||
|
||
class VALUES_UPDATING(Enum): | ||
NEITHER = 0 | ||
RAW = 1 | ||
DERIVED = 0 | ||
|
||
# TODO: should this be a TypedDict? | ||
@dataclasses.dataclass | ||
class SlitsRaw(Generic[F_contra]): | ||
top: F_contra | ||
bottom: F_contra | ||
|
||
class Transform(ABC, Generic[RawT, DerivedT, ParametersT]): | ||
raw_signals: RawT | ||
derived_signals: DerivedT | ||
parameters: ParametersT | ||
values_updating: VALUES_UPDATING | ||
|
||
@dataclasses.dataclass | ||
class SlitsDerived(Generic[F_contra]): | ||
gap: F_contra | ||
centre: F_contra | ||
def __init__( | ||
self, | ||
raw_signals: RawT, | ||
derived_signals: DerivedT, | ||
parameters: ParametersT, | ||
) -> None: | ||
self.raw_signals = raw_signals | ||
self.derived_signals = derived_signals | ||
self.parameters = parameters | ||
|
||
for signal in _get_signals(raw_signals): | ||
self._set_backend_raw_callbacks(signal) | ||
for signal in _get_signals(derived_signals): | ||
self._set_backend_derived_callbacks(signal) | ||
|
||
super().__init__() | ||
|
||
self.values_updating = VALUES_UPDATING.NEITHER | ||
|
||
def _set_backend_raw_callbacks(self, raw_signal: SignalRW): | ||
old_set = raw_signal.set | ||
|
||
@AsyncStatus.wrap | ||
async def new_set(*args, **kwargs): | ||
await old_set(*args, **kwargs) | ||
await self.forward() | ||
|
||
raw_signal.set = new_set | ||
|
||
@dataclasses.dataclass | ||
class SlitsParameters: | ||
gap_offset: float | ||
old_get_value = raw_signal.get_value | ||
|
||
async def new_get_value(*args, **kwargs): | ||
await self.inverse() | ||
return await old_get_value(*args, **kwargs) | ||
|
||
class SlitsTransform(Transform[SlitsRaw, SlitsDerived, SlitsParameters]): | ||
def forward( | ||
self, raw: SlitsRaw[F_contra], parameters: SlitsParameters | ||
) -> SlitsDerived[F_contra]: | ||
return SlitsDerived( | ||
gap=raw.top - raw.bottom + parameters.gap_offset, | ||
centre=(raw.top + raw.bottom) / 2, | ||
) | ||
raw_signal.get_value = new_get_value | ||
|
||
def inverse( | ||
self, derived: SlitsDerived[F_contra], parameters: SlitsParameters | ||
) -> SlitsRaw[F_contra]: | ||
half_gap = (derived.gap - parameters.gap_offset) / 2 | ||
return SlitsRaw( | ||
top=derived.centre + half_gap, | ||
bottom=derived.centre - half_gap, | ||
) | ||
def _set_backend_derived_callbacks(self, derived_signal: SignalRW): | ||
old_set = derived_signal.set | ||
|
||
@AsyncStatus.wrap | ||
async def update_raw(*args, **kwargs): | ||
await self.inverse() | ||
await old_set(*args, **kwargs) | ||
|
||
def _get_dataclass_args(method) -> Iterator[type]: | ||
for k, v in get_type_hints(method): | ||
cls = get_origin_class(v) | ||
if k != "return" and dataclasses.is_dataclass(cls): | ||
yield cls | ||
derived_signal.set = update_raw | ||
|
||
old_get_value = derived_signal.get_value | ||
|
||
async def _get_dataclass_from_signals(cls: type[T], device: Device) -> T: | ||
coros = {} | ||
for field in dataclasses.fields(cls): | ||
sig = getattr(device, field.name) | ||
assert isinstance( | ||
sig, SignalR | ||
), f"{device.name}.{field.name} is {sig}, not a Signal" | ||
coros[field.name] = sig.get_value() | ||
results = await asyncio.gather(*coros.values()) | ||
kwargs = dict(zip(coros, results, strict=True)) | ||
return cls(**kwargs) | ||
async def new_get_value(*args, **kwargs): | ||
await self.forward() | ||
return await old_get_value(*args, **kwargs) | ||
|
||
derived_signal.get_value = new_get_value | ||
|
||
@abstractmethod | ||
async def forward(self): | ||
pass | ||
|
||
@abstractmethod | ||
async def inverse(self): | ||
pass | ||
|
||
|
||
class DerivedBackend(Generic[RawT, DerivedT, ParametersT]): | ||
def __init__( | ||
self, | ||
device: Device, | ||
transform: Transform[RawT, DerivedT, ParametersT], | ||
): | ||
self._device = device | ||
_transform: Transform[RawT, DerivedT, ParametersT] | ||
|
||
def __init__(self, transform: Transform[RawT, DerivedT, ParametersT]): | ||
self._transform = transform | ||
self._raw_cls, self._param_cls = _get_dataclass_args(self._transform.forward) | ||
|
||
async def get_parameters(self) -> ParametersT: | ||
return await _get_dataclass_from_signals(self._param_cls, self._device) | ||
|
||
async def get_raw_values(self) -> RawT: | ||
return await _get_dataclass_from_signals(self._raw_cls, self._device) | ||
|
||
async def get_derived_values(self) -> DerivedT: | ||
raw, parameters = await asyncio.gather( | ||
self.get_raw_values(), self.get_parameters() | ||
) | ||
return self._transform.forward(raw, parameters) | ||
|
||
async def set_derived_values(self, derived: DerivedT): | ||
assert isinstance(self._device, AsyncMovable) | ||
await self._device.set(derived) | ||
|
||
async def calculate_raw_values(self, derived: DerivedT) -> RawT: | ||
return self._transform.inverse(derived, await self.get_parameters()) | ||
|
||
def derived_signal(self, variable: str) -> SignalRW[float]: | ||
return SignalRW(DerivedSignalBackend(self, variable)) | ||
|
||
|
||
class DerivedSignalBackend(SignalBackend[float]): | ||
def __init__(self, backend: DerivedBackend, variable: str): | ||
self._backend = backend | ||
self._variable = variable | ||
super().__init__(float) | ||
|
||
async def get_value(self) -> float: | ||
derived = await self._backend.get_derived_values() | ||
return getattr(derived, self._variable) | ||
|
||
async def put(self, value: float | None, wait: bool): | ||
derived = await self._backend.get_derived_values() | ||
# TODO: we should be calling locate on these as we want to move relative to the | ||
# setpoint, not readback | ||
setattr(derived, self._variable, value) | ||
await self._backend.set_derived_values(derived) | ||
|
||
|
||
class Slits(Device): | ||
def __init__(self, name=""): | ||
self._backend = DerivedBackend(self, SlitsTransform()) | ||
# Raw signals | ||
self.top = soft_signal_rw(float) | ||
self.bottom = soft_signal_rw(float) | ||
# Parameter | ||
self.gap_offset = soft_signal_rw(float) | ||
# Derived signals | ||
self.gap = self._backend.derived_signal("gap") | ||
self.centre = self._backend.derived_signal("centre") | ||
super().__init__(name=name) | ||
|
||
@AsyncStatus.wrap | ||
async def set(self, derived: SlitsDerived[float]) -> None: | ||
raw: SlitsRaw[float] = await self._backend.calculate_raw_values(derived) | ||
await asyncio.gather(self.top.set(raw.top), self.bottom.set(raw.bottom)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from ophyd_async.core._derived_signal import Transform, DerivedBackend | ||
from dataclasses import dataclass | ||
import asyncio | ||
from ophyd_async.core import Device, SignalRW, soft_signal_rw | ||
|
||
|
||
@dataclass | ||
class SlitsRaw: | ||
top: SignalRW[float] | ||
bottom: SignalRW[float] | ||
|
||
|
||
@dataclass | ||
class SlitsDerived: | ||
gap: SignalRW[float] | ||
center: SignalRW[float] | ||
|
||
|
||
@dataclass | ||
class SlitsParameters: | ||
gap_offset: SignalRW[float] | ||
|
||
|
||
class SlitsTransform(Transform[SlitsRaw, SlitsDerived, SlitsParameters]): | ||
async def forward(self): | ||
print("DEBUG: FORWARD") | ||
top, bottom, gap_offset = await asyncio.gather( | ||
self.raw_signals.top.get_value(), | ||
self.raw_signals.bottom.get_value(), | ||
self.parameters.gap_offset.get_value(), | ||
) | ||
new_gap = top - bottom + gap_offset | ||
new_center = (top + bottom) / 2 | ||
asyncio.gather( | ||
self.derived_signals.gap.set(new_gap), | ||
self.derived_signals.center.set(new_center), | ||
) | ||
|
||
async def inverse(self): | ||
print("DEBUG: INVERSE") | ||
new_gap, new_center, gap_offset = await asyncio.gather( | ||
self.derived_signals.gap.get_value(), | ||
self.derived_signals.center.get_value(), | ||
self.parameters.gap_offset.get_value(), | ||
) | ||
half_gap = (new_gap - gap_offset) / 2 | ||
|
||
await asyncio.gather( | ||
self.raw_signals.top.set(new_center + half_gap), | ||
self.raw_signals.bottom.set(new_center - half_gap), | ||
) | ||
|
||
|
||
class Slits(Device): | ||
def __init__(self, name=""): | ||
self.raw = SlitsRaw(top=soft_signal_rw(float), bottom=soft_signal_rw(float)) | ||
self.derived = SlitsDerived( | ||
gap=soft_signal_rw(float), center=soft_signal_rw(float) | ||
) | ||
self.parameters = SlitsParameters(gap_offset=soft_signal_rw(float)) | ||
|
||
self._backend = DerivedBackend( | ||
SlitsTransform(self.raw, self.derived, self.parameters) | ||
) | ||
|
||
super().__init__(name=name) | ||
|
||
|
||
async def test_derived_signal(): | ||
slits = Slits(name="hello") | ||
|
||
assert (await slits.derived.center.get_value()) == 0.0 |