Skip to content

Commit

Permalink
made suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Sep 11, 2024
1 parent a5f88a5 commit 1aafbc0
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 66 deletions.
16 changes: 16 additions & 0 deletions src/ophyd_async/basemodel_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod

from pydantic import BaseModel, Field


class X(ABC):
@abstractmethod
def foo():
pass

class Y(X):
def foo():
pass

class Model(BaseModel):
x: int = Field(ge=0)
2 changes: 2 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
)
from ._soft_signal_backend import SignalMetadata, SoftSignalBackend
from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
from ._table import Table
from ._utils import (
DEFAULT_TIMEOUT,
CalculatableTimeout,
Expand Down Expand Up @@ -156,6 +157,7 @@
"CalculateTimeout",
"NotConnected",
"ReadingValueCallback",
"Table",
"T",
"WatcherUpdate",
"get_dtype",
Expand Down
3 changes: 0 additions & 3 deletions src/ophyd_async/core/_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@ class SignalBackend(Generic[T]):
#: Datatype of the signal value
datatype: Optional[Type[T]] = None

_ALLOWED_DATATYPES: ClassVar[Tuple[Type]]

@classmethod
@abstractmethod
def datatype_allowed(cls, dtype: type):
"""Check if a given datatype is acceptable for this signal backend."""
pass

#: Like ca://PV_PREFIX:SIGNAL
@abstractmethod
Expand Down
22 changes: 5 additions & 17 deletions src/ophyd_async/core/_soft_signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABCMeta
from collections import abc
from enum import Enum
from typing import Any, Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin
from typing import Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin

import numpy as np
from bluesky.protocols import DataKey, Dtype, Reading
Expand Down Expand Up @@ -131,32 +131,22 @@ class SoftPydanticModelConverter(SoftConverter):
def __init__(self, datatype: Type[BaseModel]):
self.datatype = datatype

def reading(self, value: T, timestamp: float, severity: int) -> Reading:
value = self.value(value)
return super().reading(value, timestamp, severity)

def value(self, value: Any) -> Any:
if isinstance(value, dict):
value = self.datatype(**value)
return value

def write_value(self, value):
if isinstance(value, self.datatype):
return value.model_dump(mode="python")
if isinstance(value, dict):
return self.datatype(**value)
return value

def make_initial_value(self, datatype: Type | None) -> Any:
return super().make_initial_value(datatype)


def make_converter(datatype):
is_array = get_dtype(datatype) is not None
is_sequence = get_origin(datatype) == abc.Sequence
is_enum = inspect.isclass(datatype) and (
issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum)
)

is_pydantic_model = (
inspect.isclass(datatype)
# Necessary to avoid weirdness in ABCMeta.__subclasscheck__
and isinstance(datatype, ABCMeta)
and issubclass(datatype, BaseModel)
)
Expand All @@ -179,8 +169,6 @@ class SoftSignalBackend(SignalBackend[T]):
_timestamp: float
_severity: int

_ALLOWED_DATATYPES = (object,) # Any type is allowed

@classmethod
def datatype_allowed(cls, datatype: Type) -> bool:
return True # Any value allowed in a soft signal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from pydantic import BaseModel, ConfigDict, model_validator


class PvaTable(BaseModel):
"""An abstraction of a PVA Table of str to numpy array."""
class Table(BaseModel):
"""An abstraction of a Table of str to numpy array."""

model_config = ConfigDict(validate_assignment=True, strict=False)

@classmethod
def row(cls, sub_cls, **kwargs) -> "PvaTable":
def row(cls, sub_cls, **kwargs) -> "Table":
arrayified_kwargs = {
field_name: np.concatenate(
(
Expand All @@ -20,12 +20,12 @@ def row(cls, sub_cls, **kwargs) -> "PvaTable":
}
return sub_cls(**arrayified_kwargs)

def __add__(self, right: "PvaTable") -> "PvaTable":
def __add__(self, right: "Table") -> "Table":
"""Concatenate the arrays in field values."""

assert isinstance(right, type(self)), (
f"{right} is not a `PvaTable`, or is not the same "
f"type of `PvaTable` as {self}."
f"{right} is not a `Table`, or is not the same "
f"type of `Table` as {self}."
)

return type(self)(
Expand All @@ -38,14 +38,12 @@ def __add__(self, right: "PvaTable") -> "PvaTable":
)

@model_validator(mode="after")
def validate_arrays(self) -> "PvaTable":
def validate_arrays(self) -> "Table":
first_length = len(next(iter(self))[1])
assert all(
len(field_value) == first_length for _, field_value in self
), "Rows should all be of equal size."

assert 0 <= first_length < 4096, f"Length {first_length} not in range."

if not all(
np.issubdtype(
self.model_fields[field_name].default_factory().dtype, field_value.dtype
Expand Down
2 changes: 0 additions & 2 deletions src/ophyd_async/epics/signal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from ._common import LimitPair, Limits, get_supported_values
from ._p4p import PvaSignalBackend
from ._p4p_table_model import PvaTable
from ._signal import (
epics_signal_r,
epics_signal_rw,
Expand All @@ -14,7 +13,6 @@
"LimitPair",
"Limits",
"PvaSignalBackend",
"PvaTable",
"epics_signal_r",
"epics_signal_rw",
"epics_signal_rw_rbv",
Expand Down
10 changes: 9 additions & 1 deletion src/ophyd_async/epics/signal/_p4p.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import logging
import time
from abc import ABCMeta
from dataclasses import dataclass
from enum import Enum
from math import isnan, nan
Expand Down Expand Up @@ -363,7 +364,14 @@ def make_converter(datatype: Optional[Type], values: Dict[str, Any]) -> PvaConve
raise TypeError(f"{pv} has type {typ.__name__} not {datatype.__name__}")
return PvaConverter()
elif "NTTable" in typeid:
if datatype and inspect.isclass(datatype) and issubclass(datatype, BaseModel):
if (
datatype
and inspect.isclass(datatype)
and
# Necessary to avoid weirdness in ABCMeta.__subclasscheck__
isinstance(datatype, ABCMeta)
and issubclass(datatype, BaseModel)
):
return PvaPydanticModelConverter(datatype)
return PvaTableConverter()
elif "structure" in typeid:
Expand Down
66 changes: 38 additions & 28 deletions src/ophyd_async/fastcs/panda/_table.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import inspect
from enum import Enum
from typing import Annotated, Sequence

import numpy as np
import numpy.typing as npt
from pydantic import Field
from pydantic import Field, field_validator, model_validator
from pydantic_numpy.helper.annotation import NpArrayPydanticAnnotation
from typing_extensions import TypedDict

from ophyd_async.epics.signal import PvaTable
from ophyd_async.core import Table


class PandaHdf5DatasetType(str, Enum):
Expand Down Expand Up @@ -50,8 +51,7 @@ class SeqTrigger(str, Enum):
),
Field(default_factory=lambda: np.array([], dtype=np.bool_)),
]

PydanticNp1DArrayUnicodeString = Annotated[
TriggerStr = Annotated[
np.ndarray[tuple[int], np.unicode_],
NpArrayPydanticAnnotation.factory(
data_type=np.unicode_, dimensions=1, strict_data_typing=False
Expand All @@ -60,9 +60,9 @@ class SeqTrigger(str, Enum):
]


class SeqTable(PvaTable):
class SeqTable(Table):
repeats: PydanticNp1DArrayInt32
trigger: PydanticNp1DArrayUnicodeString
trigger: TriggerStr
position: PydanticNp1DArrayInt32
time1: PydanticNp1DArrayInt32
outa1: PydanticNp1DArrayBool
Expand All @@ -83,8 +83,8 @@ class SeqTable(PvaTable):
def row(
cls,
*,
repeats: int = 0,
trigger: str = "",
repeats: int = 1,
trigger: str = SeqTrigger.IMMEDIATE,
position: int = 0,
time1: int = 0,
outa1: bool = False,
Expand All @@ -101,23 +101,33 @@ def row(
oute2: bool = False,
outf2: bool = False,
) -> "SeqTable":
return PvaTable.row(
cls,
repeats=repeats,
trigger=trigger,
position=position,
time1=time1,
outa1=outa1,
outb1=outb1,
outc1=outc1,
outd1=outd1,
oute1=oute1,
outf1=outf1,
time2=time2,
outa2=outa2,
outb2=outb2,
outc2=outc2,
outd2=outd2,
oute2=oute2,
outf2=outf2,
)
sig = inspect.signature(cls.row)
kwargs = {k: v for k, v in locals().items() if k in sig.parameters}
if isinstance(kwargs["trigger"], SeqTrigger):
kwargs["trigger"] = kwargs["trigger"].value
return Table.row(cls, **kwargs)

@field_validator("trigger", mode="before")
@classmethod
def trigger_to_np_array(cls, trigger_column):
"""
The user can provide a list of SeqTrigger enum elements instead of a numpy str.
"""
if isinstance(trigger_column, Sequence) and all(
isinstance(trigger, SeqTrigger) for trigger in trigger_column
):
trigger_column = np.array(
[trigger.value for trigger in trigger_column], dtype=np.dtype("<U32")
)
return trigger_column

@model_validator(mode="after")
def validate_max_length(self) -> "SeqTable":
"""
Used to check max_length. Unfortunately trying the `max_length` arg in
the pydantic field doesn't work
"""

first_length = len(next(iter(self))[1])
assert 0 <= first_length < 4096, f"Length {first_length} not in range."
return self
7 changes: 3 additions & 4 deletions tests/fastcs/panda/test_panda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ophyd_async.fastcs.panda import (
CommonPandaBlocks,
DataBlock,
PcompDirectionOptions,
SeqTable,
TimeUnits,
phase_sorter,
Expand Down Expand Up @@ -78,13 +77,13 @@ def check_equal_with_seq_tables(actual, expected):
"data.hdf_file_name": "",
"data.num_capture": 0,
"pcap.arm": False,
"pcomp.1.dir": PcompDirectionOptions.positive,
"pcomp.1.dir": "Positive",
"pcomp.1.enable": "ZERO",
"pcomp.1.pulses": 0,
"pcomp.1.start": 0,
"pcomp.1.step": 0,
"pcomp.1.width": 0,
"pcomp.2.dir": PcompDirectionOptions.positive,
"pcomp.2.dir": "Positive",
"pcomp.2.enable": "ZERO",
"pcomp.2.pulses": 0,
"pcomp.2.start": 0,
Expand Down Expand Up @@ -112,7 +111,7 @@ def check_equal_with_seq_tables(actual, expected):
"repeats": [1],
"time1": [0],
"time2": [0],
"trigger": [""],
"trigger": ["Immediate"],
},
"seq.1.repeats": 0,
"seq.1.prescale": 0.0,
Expand Down
31 changes: 30 additions & 1 deletion tests/fastcs/panda/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import ValidationError

from ophyd_async.fastcs.panda import SeqTable
from ophyd_async.fastcs.panda._table import SeqTrigger


def test_seq_table_converts_lists():
Expand All @@ -16,7 +17,7 @@ def test_seq_table_converts_lists():


def test_seq_table_validation_errors():
with pytest.raises(ValidationError, match="81 validation errors for SeqTable"):
with pytest.raises(ValidationError, match="80 validation errors for SeqTable"):
SeqTable(
repeats=0,
trigger="",
Expand Down Expand Up @@ -195,3 +196,31 @@ def test_seq_table_pva_conversion():
for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()):
assert np.array_equal(column1, column2)
assert column1.dtype == column2.dtype


def test_seq_table_takes_trigger_enum_row():
for trigger in (SeqTrigger.BITA_0, "BITA=0"):
table = SeqTable.row(trigger=trigger)
assert table.trigger[0] == "BITA=0"
assert np.issubdtype(table.trigger.dtype, np.dtype("<U32"))
table = SeqTable(
repeats=np.array([1], dtype=np.int32),
trigger=[trigger],
position=np.array([1], dtype=np.int32),
time1=np.array([1], dtype=np.int32),
outa1=np.array([1], dtype=np.bool_),
outb1=np.array([1], dtype=np.bool_),
outc1=np.array([1], dtype=np.bool_),
outd1=np.array([1], dtype=np.bool_),
oute1=np.array([1], dtype=np.bool_),
outf1=np.array([1], dtype=np.bool_),
time2=np.array([1], dtype=np.int32),
outa2=np.array([1], dtype=np.bool_),
outb2=np.array([1], dtype=np.bool_),
outc2=np.array([1], dtype=np.bool_),
outd2=np.array([1], dtype=np.bool_),
oute2=np.array([1], dtype=np.bool_),
outf2=np.array([1], dtype=np.bool_),
)
assert table.trigger[0] == "BITA=0"
assert np.issubdtype(table.trigger.dtype, np.dtype("<U32"))
2 changes: 1 addition & 1 deletion tests/test_data/test_yaml_save.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@
pv_int: 1
pv_protocol_device_abstraction:
some_int_field: 1
some_pydantic_numpy_field_float: [1.0, 2.0, 3.0]
some_pydantic_numpy_field_float: [1, 2, 3]
some_pydantic_numpy_field_int: [1, 2, 3]
pv_str: test_string

0 comments on commit 1aafbc0

Please sign in to comment.