Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Nov 22, 2024
1 parent 9aaff23 commit 9fdb677
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 130 deletions.
217 changes: 87 additions & 130 deletions src/ophyd_async/core/_derived_signal.py
Original file line number Diff line number Diff line change
@@ -1,154 +1,111 @@
import asyncio
import dataclasses
from collections.abc import Iterator
from typing import Generic, TypeVar, get_type_hints
from abc import abstractmethod, ABC
from dataclasses import asdict
from typing import (
Generic,
TypeVar,
Protocol,
ClassVar,
Generator,
Any,
)
from enum import Enum
from ophyd_async.core import AsyncStatus, SignalRW

import numpy as np

from ._device import Device
from ._protocol import AsyncMovable
from ._signal import SignalR, SignalRW, soft_signal_rw
from ._signal_backend import Array1D, SignalBackend
from ._status import AsyncStatus
from ._utils import T, get_origin_class
class DataclassInstance(Protocol):
__dataclass_fields__: ClassVar[dict[str, SignalRW]]

RawSignalsT = TypeVar("RawSignalsT")
ParametersSignalsT = TypeVar("ParametersSignalsT")
RawT = TypeVar("RawT")
DerivedT = TypeVar("DerivedT")
ParametersT = TypeVar("ParametersT")

RawT = TypeVar("RawT", bound=DataclassInstance)
DerivedT = TypeVar("DerivedT", bound=DataclassInstance)
ParametersT = TypeVar("ParametersT", bound=DataclassInstance)

class Transform(Generic[RawT, DerivedT, ParametersT]):
def forward(self, raw: RawT, parameters: ParametersT) -> DerivedT: ...
def inverse(self, derived: DerivedT, parameters: ParametersT) -> RawT: ...

def _get_signals(signals: DataclassInstance) -> Generator[Any, SignalRW, None]:
for signal in asdict(signals).values(): # type: ignore
if not isinstance(signal, SignalRW):
raise TypeError
yield signal

F_contra = TypeVar("F_contra", bound=float | Array1D[np.float64], contravariant=True)

class VALUES_UPDATING(Enum):
NEITHER = 0
RAW = 1
DERIVED = 0

# TODO: should this be a TypedDict?
@dataclasses.dataclass
class SlitsRaw(Generic[F_contra]):
top: F_contra
bottom: F_contra

class Transform(ABC, Generic[RawT, DerivedT, ParametersT]):
raw_signals: RawT
derived_signals: DerivedT
parameters: ParametersT
values_updating: VALUES_UPDATING

@dataclasses.dataclass
class SlitsDerived(Generic[F_contra]):
gap: F_contra
centre: F_contra
def __init__(
self,
raw_signals: RawT,
derived_signals: DerivedT,
parameters: ParametersT,
) -> None:
self.raw_signals = raw_signals
self.derived_signals = derived_signals
self.parameters = parameters

for signal in _get_signals(raw_signals):
self._set_backend_raw_callbacks(signal)
for signal in _get_signals(derived_signals):
self._set_backend_derived_callbacks(signal)

super().__init__()

self.values_updating = VALUES_UPDATING.NEITHER

def _set_backend_raw_callbacks(self, raw_signal: SignalRW):
old_set = raw_signal.set

@AsyncStatus.wrap
async def new_set(*args, **kwargs):
await old_set(*args, **kwargs)
await self.forward()

raw_signal.set = new_set

@dataclasses.dataclass
class SlitsParameters:
gap_offset: float
old_get_value = raw_signal.get_value

async def new_get_value(*args, **kwargs):
await self.inverse()
return await old_get_value(*args, **kwargs)

class SlitsTransform(Transform[SlitsRaw, SlitsDerived, SlitsParameters]):
def forward(
self, raw: SlitsRaw[F_contra], parameters: SlitsParameters
) -> SlitsDerived[F_contra]:
return SlitsDerived(
gap=raw.top - raw.bottom + parameters.gap_offset,
centre=(raw.top + raw.bottom) / 2,
)
raw_signal.get_value = new_get_value

def inverse(
self, derived: SlitsDerived[F_contra], parameters: SlitsParameters
) -> SlitsRaw[F_contra]:
half_gap = (derived.gap - parameters.gap_offset) / 2
return SlitsRaw(
top=derived.centre + half_gap,
bottom=derived.centre - half_gap,
)
def _set_backend_derived_callbacks(self, derived_signal: SignalRW):
old_set = derived_signal.set

@AsyncStatus.wrap
async def update_raw(*args, **kwargs):
await self.inverse()
await old_set(*args, **kwargs)

def _get_dataclass_args(method) -> Iterator[type]:
for k, v in get_type_hints(method):
cls = get_origin_class(v)
if k != "return" and dataclasses.is_dataclass(cls):
yield cls
derived_signal.set = update_raw

old_get_value = derived_signal.get_value

async def _get_dataclass_from_signals(cls: type[T], device: Device) -> T:
coros = {}
for field in dataclasses.fields(cls):
sig = getattr(device, field.name)
assert isinstance(
sig, SignalR
), f"{device.name}.{field.name} is {sig}, not a Signal"
coros[field.name] = sig.get_value()
results = await asyncio.gather(*coros.values())
kwargs = dict(zip(coros, results, strict=True))
return cls(**kwargs)
async def new_get_value(*args, **kwargs):
await self.forward()
return await old_get_value(*args, **kwargs)

derived_signal.get_value = new_get_value

@abstractmethod
async def forward(self):
pass

@abstractmethod
async def inverse(self):
pass


class DerivedBackend(Generic[RawT, DerivedT, ParametersT]):
def __init__(
self,
device: Device,
transform: Transform[RawT, DerivedT, ParametersT],
):
self._device = device
_transform: Transform[RawT, DerivedT, ParametersT]

def __init__(self, transform: Transform[RawT, DerivedT, ParametersT]):
self._transform = transform
self._raw_cls, self._param_cls = _get_dataclass_args(self._transform.forward)

async def get_parameters(self) -> ParametersT:
return await _get_dataclass_from_signals(self._param_cls, self._device)

async def get_raw_values(self) -> RawT:
return await _get_dataclass_from_signals(self._raw_cls, self._device)

async def get_derived_values(self) -> DerivedT:
raw, parameters = await asyncio.gather(
self.get_raw_values(), self.get_parameters()
)
return self._transform.forward(raw, parameters)

async def set_derived_values(self, derived: DerivedT):
assert isinstance(self._device, AsyncMovable)
await self._device.set(derived)

async def calculate_raw_values(self, derived: DerivedT) -> RawT:
return self._transform.inverse(derived, await self.get_parameters())

def derived_signal(self, variable: str) -> SignalRW[float]:
return SignalRW(DerivedSignalBackend(self, variable))


class DerivedSignalBackend(SignalBackend[float]):
def __init__(self, backend: DerivedBackend, variable: str):
self._backend = backend
self._variable = variable
super().__init__(float)

async def get_value(self) -> float:
derived = await self._backend.get_derived_values()
return getattr(derived, self._variable)

async def put(self, value: float | None, wait: bool):
derived = await self._backend.get_derived_values()
# TODO: we should be calling locate on these as we want to move relative to the
# setpoint, not readback
setattr(derived, self._variable, value)
await self._backend.set_derived_values(derived)


class Slits(Device):
def __init__(self, name=""):
self._backend = DerivedBackend(self, SlitsTransform())
# Raw signals
self.top = soft_signal_rw(float)
self.bottom = soft_signal_rw(float)
# Parameter
self.gap_offset = soft_signal_rw(float)
# Derived signals
self.gap = self._backend.derived_signal("gap")
self.centre = self._backend.derived_signal("centre")
super().__init__(name=name)

@AsyncStatus.wrap
async def set(self, derived: SlitsDerived[float]) -> None:
raw: SlitsRaw[float] = await self._backend.calculate_raw_values(derived)
await asyncio.gather(self.top.set(raw.top), self.bottom.set(raw.bottom))
72 changes: 72 additions & 0 deletions tests/core/test_derived_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from ophyd_async.core._derived_signal import Transform, DerivedBackend
from dataclasses import dataclass
import asyncio
from ophyd_async.core import Device, SignalRW, soft_signal_rw


@dataclass
class SlitsRaw:
top: SignalRW[float]
bottom: SignalRW[float]


@dataclass
class SlitsDerived:
gap: SignalRW[float]
center: SignalRW[float]


@dataclass
class SlitsParameters:
gap_offset: SignalRW[float]


class SlitsTransform(Transform[SlitsRaw, SlitsDerived, SlitsParameters]):
async def forward(self):
print("DEBUG: FORWARD")
top, bottom, gap_offset = await asyncio.gather(
self.raw_signals.top.get_value(),
self.raw_signals.bottom.get_value(),
self.parameters.gap_offset.get_value(),
)
new_gap = top - bottom + gap_offset
new_center = (top + bottom) / 2
asyncio.gather(
self.derived_signals.gap.set(new_gap),
self.derived_signals.center.set(new_center),
)

async def inverse(self):
print("DEBUG: INVERSE")
new_gap, new_center, gap_offset = await asyncio.gather(
self.derived_signals.gap.get_value(),
self.derived_signals.center.get_value(),
self.parameters.gap_offset.get_value(),
)
half_gap = (new_gap - gap_offset) / 2

await asyncio.gather(
self.raw_signals.top.set(new_center + half_gap),
self.raw_signals.bottom.set(new_center - half_gap),
)


class Slits(Device):
def __init__(self, name=""):
self.raw = SlitsRaw(top=soft_signal_rw(float), bottom=soft_signal_rw(float))
self.derived = SlitsDerived(
gap=soft_signal_rw(float), center=soft_signal_rw(float)
)
self.parameters = SlitsParameters(gap_offset=soft_signal_rw(float))

self._backend = DerivedBackend(
SlitsTransform(self.raw, self.derived, self.parameters)
)

super().__init__(name=name)


async def test_derived_signal():
slits = Slits(name="hello")

assert (await slits.derived.center.get_value()) == 0.0

0 comments on commit 9fdb677

Please sign in to comment.