Skip to content

Commit

Permalink
allowed backends to take pydantic models
Browse files Browse the repository at this point in the history
Also finished a `PvaTable` type.
  • Loading branch information
evalott100 committed Aug 21, 2024
1 parent 1b203e9 commit d976d02
Show file tree
Hide file tree
Showing 16 changed files with 507 additions and 410 deletions.
2 changes: 0 additions & 2 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
)
from ._signal_backend import (
BackendConverterFactory,
ProtocolDatatypeAbstraction,
RuntimeSubsetEnum,
SignalBackend,
SubsetEnum,
Expand Down Expand Up @@ -124,7 +123,6 @@
"NameProvider",
"PathInfo",
"PathProvider",
"ProtocolDatatypeAbstraction",
"ShapeProvider",
"StaticFilenameProvider",
"StaticPathProvider",
Expand Down
14 changes: 6 additions & 8 deletions src/ophyd_async/core/_device_save_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from bluesky.plan_stubs import abs_set, wait
from bluesky.protocols import Location
from bluesky.utils import Msg
from pydantic import BaseModel

from ._device import Device
from ._signal import SignalRW
from ._signal_backend import ProtocolDatatypeAbstraction


def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node:
Expand All @@ -19,14 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.No
)


def protocol_datatype_abstraction_representer(
dumper: yaml.Dumper, protocol_datatype_abstraction: ProtocolDatatypeAbstraction
def pydantic_model_abstraction_representer(
dumper: yaml.Dumper, model: BaseModel
) -> yaml.Node:
"""Uses the protocol datatype since it has to be serializable."""

return dumper.represent_data(
protocol_datatype_abstraction.convert_to_protocol_datatype()
)
return dumper.represent_data(model.model_dump(mode="python"))


class OphydDumper(yaml.Dumper):
Expand Down Expand Up @@ -146,8 +144,8 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:

yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper)
yaml.add_multi_representer(
ProtocolDatatypeAbstraction,
protocol_datatype_abstraction_representer,
BaseModel,
pydantic_model_abstraction_representer,
Dumper=yaml.Dumper,
)

Expand Down
23 changes: 0 additions & 23 deletions src/ophyd_async/core/_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,6 @@
from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T


class ProtocolDatatypeAbstraction(ABC, Generic[T]):
    """Abstract base for a high-level datatype that round-trips through a wire protocol.

    Subclasses pair a rich Python representation with the serializable form
    (``T``) that the underlying protocol actually transports.  Removed in this
    commit in favour of plain pydantic ``BaseModel`` subclasses.
    """

    @abstractmethod
    def __init__(self):
        """The abstract datatype must be able to be initialized with no arguments."""

    @abstractmethod
    def convert_to_protocol_datatype(self) -> T:
        """
        Convert the abstract datatype to a form which can be sent
        over whichever protocol.
        This output will be used when the device is serialized.
        """

    @classmethod
    @abstractmethod
    def convert_from_protocol_datatype(cls, value: T) -> "ProtocolDatatypeAbstraction":
        """
        Convert the datatype received from the protocol to a
        higher level abstract datatype.
        """


class BackendConverterFactory(ABC):
"""Convert between the signal backend and the signal type"""

Expand Down
26 changes: 16 additions & 10 deletions src/ophyd_async/core/_soft_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

import numpy as np
from bluesky.protocols import DataKey, Dtype, Reading
from pydantic import BaseModel
from typing_extensions import TypedDict

from ._signal_backend import (
BackendConverterFactory,
ProtocolDatatypeAbstraction,
RuntimeSubsetEnum,
SignalBackend,
)
Expand Down Expand Up @@ -127,21 +127,27 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
return cast(T, self.choices[0])


class SoftProtocolDatatypeAbstractionConverter(SoftConverter):
def __init__(self, datatype: Type[ProtocolDatatypeAbstraction]):
class SoftPydanticModelConverter(SoftConverter):
"""Necessary for serializing soft signals."""

def __init__(self, datatype: Type[BaseModel]):
self.datatype = datatype

def reading(self, value: T, timestamp: float, severity: int) -> Reading:
value = self.value(value)
return super().reading(value, timestamp, severity)

def value(self, value: Any) -> Any:
if not isinstance(value, self.datatype):
# For the case where we
value = self.datatype.convert_from_protocol_datatype(value)
if isinstance(value, dict):
value = self.datatype(**value)
return value

def write_value(self, value):
if isinstance(value, dict):
# If the device is being deserialized
return self.datatype(**value).model_dump(mode="python")
if isinstance(value, self.datatype):
return value.model_dump(mode="python")
return value

def make_initial_value(self, datatype: Type | None) -> Any:
Expand All @@ -162,16 +168,16 @@ def make_converter(cls, datatype):
is_enum = inspect.isclass(datatype) and (
issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum)
)
is_convertable_abstract_datatype = inspect.isclass(datatype) and issubclass(
datatype, ProtocolDatatypeAbstraction
is_pydantic_model = inspect.isclass(datatype) and issubclass(
datatype, BaseModel
)

if is_array or is_sequence:
return SoftArrayConverter()
if is_enum:
return SoftEnumConverter(datatype)
if is_convertable_abstract_datatype:
return SoftProtocolDatatypeAbstractionConverter(datatype)
if is_pydantic_model:
return SoftPydanticModelConverter(datatype)

return SoftConverter()

Expand Down
4 changes: 3 additions & 1 deletion src/ophyd_async/epics/signal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._common import LimitPair, Limits, get_supported_values
from ._p4p import PvaSignalBackend, PvaTableAbstraction
from ._p4p import PvaSignalBackend
from ._p4p_table_abstraction import PvaTable
from ._signal import (
epics_signal_r,
epics_signal_rw,
Expand All @@ -13,6 +14,7 @@
"LimitPair",
"Limits",
"PvaSignalBackend",
"PvaTable",
"PvaTableAbstraction",
"epics_signal_r",
"epics_signal_rw",
Expand Down
36 changes: 14 additions & 22 deletions src/ophyd_async/epics/signal/_p4p.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import inspect
import logging
import time
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import isnan, nan
Expand All @@ -13,12 +12,12 @@
from bluesky.protocols import DataKey, Dtype, Reading
from p4p import Value
from p4p.client.asyncio import Context, Subscription
from pydantic import BaseModel

from ophyd_async.core import (
DEFAULT_TIMEOUT,
BackendConverterFactory,
NotConnected,
ProtocolDatatypeAbstraction,
ReadingValueCallback,
RuntimeSubsetEnum,
SignalBackend,
Expand Down Expand Up @@ -288,32 +287,25 @@ def __getattribute__(self, __name: str) -> Any:
raise NotImplementedError("No PV has been set as connect() has not been called")


class PvaTableAbstraction(ProtocolDatatypeAbstraction[Dict]):
@abstractmethod
def convert_to_protocol_datatype(self) -> Dict:
"""Converts the object to a pva table (dictionary)."""

@classmethod
@abstractmethod
def convert_from_protocol_datatype(cls, value: Dict) -> "PvaTableAbstraction":
"""Converts from a pva table (dictionary) to a Python datatype."""


class PvaTableAbtractionConverter(PvaConverter):
def __init__(self, datatype: PvaTableAbstraction):
class PvaPydanticModelConverter(PvaConverter):
def __init__(self, datatype: BaseModel):
self.datatype = datatype

def reading(self, value: Value):
ts = time.time()
value = self.datatype.convert_from_protocol_datatype(value.todict())
value = self.value(value)
return {"value": value, "timestamp": ts, "alarm_severity": 0}

def value(self, value: Value):
return self.datatype.convert_from_protocol_datatype(value.todict())
return self.datatype(**value.todict())

def write_value(self, value):
def write_value(self, value: Union[BaseModel, Dict[str, Any]]):
"""
A user can put whichever form to the signal.
This is required for yaml deserialization.
"""
if isinstance(value, self.datatype):
return value.convert_to_protocol_datatype()
return value.model_dump(mode="python")
return value


Expand All @@ -327,8 +319,8 @@ class PvaConverterFactory(BackendConverterFactory):
np.ndarray,
Enum,
RuntimeSubsetEnum,
BaseModel,
dict,
PvaTableAbstraction,
)

@classmethod
Expand Down Expand Up @@ -411,9 +403,9 @@ def make_converter(
if (
datatype
and inspect.isclass(datatype)
and issubclass(datatype, PvaTableAbstraction)
and issubclass(datatype, BaseModel)
):
return PvaTableAbtractionConverter(datatype)
return PvaPydanticModelConverter(datatype)
return PvaDictConverter()
else:
raise TypeError(f"{pv}: Unsupported typeid {typeid}")
Expand Down
70 changes: 70 additions & 0 deletions src/ophyd_async/epics/signal/_p4p_table_abstraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Dict

import numpy as np
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic_numpy.typing import NpNDArray


class PvaTable(BaseModel):
    """An abstraction of a PVA Table of str to python array.

    Each pydantic field on a subclass is expected to be a numpy array column
    (declared with a ``default_factory`` producing an empty array of the
    column's dtype — the validator and ``row`` both rely on that factory).
    All columns must have equal length, so the model behaves like a table of
    rows.
    """

    # validate_assignment=True re-runs validate_arrays when a field is reassigned;
    # strict=False allows coercion (e.g. building fields from plain lists/dicts).
    model_config = ConfigDict(validate_assignment=True, strict=False)

    @classmethod
    def row(cls, sub_cls, **kwargs) -> "PvaTable":
        """Build a single-row ``sub_cls`` table from scalar keyword values.

        Each scalar in ``kwargs`` is appended to the (empty) default array for
        its field, using the default array's dtype so the column dtype is
        preserved.  NOTE(review): every field of ``sub_cls`` must appear in
        ``kwargs`` and must have a ``default_factory`` — otherwise this raises
        KeyError/TypeError; confirm that is the intended contract.
        """
        arrayified_kwargs = {
            field_name: np.concatenate(
                (
                    # walrus-bind the default array so its dtype can be reused below
                    (default_arr := field_value.default_factory()),
                    np.array([kwargs[field_name]], dtype=default_arr.dtype),
                )
            )
            for field_name, field_value in sub_cls.model_fields.items()
        }
        return sub_cls(**arrayified_kwargs)

    def __add__(self, right: "PvaTable") -> "PvaTable":
        """Concatenate the arrays in field values."""

        # Both operands must be the exact same table subclass so the columns line up.
        assert isinstance(right, type(self)), (
            f"{right} is not a `PvaTable`, or is not the same "
            f"type of `PvaTable` as {self}."
        )

        # Constructing a new instance re-triggers validate_arrays on the result.
        return type(self)(
            **{
                field_name: np.concatenate(
                    (getattr(self, field_name), getattr(right, field_name))
                )
                for field_name in self.model_fields
            }
        )

    @model_validator(mode="after")
    def validate_arrays(self) -> "PvaTable":
        """Check that all columns share one length and match their default dtypes.

        Runs after construction and (via validate_assignment) after field
        reassignment.  NOTE(review): ``next(iter(self))`` raises StopIteration
        on a subclass with no fields — presumably subclasses always declare at
        least one column; confirm.
        """
        # Length of the first column; all others must match it.
        first_length = len(next(iter(self))[1])
        assert all(
            len(field_value) == first_length for _, field_value in self
        ), "Rows should all be of equal size."

        # 4096 is an upper bound on table length — TODO confirm where this
        # limit comes from (protocol/hardware constraint is not visible here).
        assert 0 <= first_length < 4096, f"Length {first_length} not in range."

        # Each column's dtype must be compatible with the dtype of the field's
        # default (empty) array.  NOTE(review): np.issubdtype(default, actual)
        # checks the default is a subtype of the actual value's dtype — verify
        # the argument order is intended and not reversed.
        if not all(
            np.issubdtype(
                self.model_fields[field_name].default_factory().dtype, field_value.dtype
            )
            for field_name, field_value in self
        ):
            raise ValueError(
                f"Cannot construct a `{type(self).__name__}`, "
                "some rows have incorrect types."
            )

        return self

    def convert_to_pva_datatype(self) -> Dict[str, NpNDArray]:
        """Dump the table to the plain dict-of-arrays form PVA transports."""
        return self.model_dump(mode="python")

    @classmethod
    def convert_from_pva_datatype(cls, pva_table: Dict[str, NpNDArray]) -> "PvaTable":
        """Rebuild a table from the dict-of-arrays form received over PVA."""
        return cls(**pva_table)
4 changes: 0 additions & 4 deletions src/ophyd_async/fastcs/panda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
DatasetTable,
PandaHdf5DatasetType,
SeqTable,
SeqTableRowType,
SeqTrigger,
seq_table_row,
)
from ._trigger import (
PcompInfo,
Expand All @@ -44,9 +42,7 @@
"DatasetTable",
"PandaHdf5DatasetType",
"SeqTable",
"SeqTableRowType",
"SeqTrigger",
"seq_table_row",
"PcompInfo",
"SeqTableInfo",
"StaticPcompTriggerLogic",
Expand Down
Loading

0 comments on commit d976d02

Please sign in to comment.