Skip to content

Commit

Permalink
wip: pushing because my computer broke
Browse files Browse the repository at this point in the history
  • Loading branch information
Eva Lott committed Sep 16, 2024
1 parent 0cbfbe6 commit 357990c
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 71 deletions.
102 changes: 86 additions & 16 deletions src/ophyd_async/core/_table.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,107 @@
from enum import Enum
from typing import get_args

import numpy as np
from pydantic import BaseModel, ConfigDict, model_validator


def _concat(value1, value2):
if isinstance(value1, np.ndarray):
return np.concatenate((value1, value2))
else:
return value1 + value2


class Table(BaseModel):
"""An abstraction of a Table of str to numpy array."""

model_config = ConfigDict(validate_assignment=True, strict=False)

@classmethod
def row(cls, sub_cls, **kwargs) -> "Table":
arrayified_kwargs = {
field_name: np.concatenate(
(
(default_arr := field_value.default_factory()),
np.array([kwargs[field_name]], dtype=default_arr.dtype),
arrayified_kwargs = {}
for field_name, field_value in sub_cls.model_fields.items():
value = kwargs.pop(field_name)
if field_value.default_factory is None:
raise ValueError(
"`Table` models should have default factories for their "
"mutable empty columns."
)
default_array = field_value.default_factory()
if isinstance(default_array, np.ndarray):
arrayified_kwargs[field_name] = np.array(
[value], dtype=default_array.dtype
)
elif issubclass(type(value), Enum) and isinstance(value, str):
arrayified_kwargs[field_name] = [value]
else:
raise TypeError(
"Row column should be numpy arrays or sequence of string `Enum`."
)
if kwargs:
raise TypeError(
f"Unexpected keyword arguments {kwargs.keys()} for {sub_cls.__name__}."
)
for field_name, field_value in sub_cls.model_fields.items()
}

return sub_cls(**arrayified_kwargs)

def __add__(self, right: "Table") -> "Table":
"""Concatenate the arrays in field values."""

assert isinstance(right, type(self)), (
f"{right} is not a `Table`, or is not the same "
f"type of `Table` as {self}."
)
if not isinstance(right, type(self)):
raise RuntimeError(
f"{right} is not a `Table`, or is not the same "
f"type of `Table` as {self}."
)

return type(self)(
**{
field_name: np.concatenate(
(getattr(self, field_name), getattr(right, field_name))
field_name: _concat(
getattr(self, field_name), getattr(right, field_name)
)
for field_name in self.model_fields
}
)

def numpy_dtype(self) -> np.dtype:
    """Return the structured numpy dtype describing one row of the table.

    Each model field contributes one named entry, in declaration order:
    numpy-array columns keep the dtype of the stored array, while string
    ``Enum`` list columns use a unicode dtype wide enough for the longest
    value declared on the enum.

    Raises:
        TypeError: If a non-array column is not annotated as a sequence of
            ``Enum`` members.
    """
    dtype = []
    # `model_fields` holds the field metadata, not the column data, so the
    # stored column has to be fetched from the instance by name.
    for field_name, field_info in self.model_fields.items():
        column = getattr(self, field_name)
        if isinstance(column, np.ndarray):
            dtype.append((field_name, column.dtype))
            continue

        type_args = get_args(field_info.annotation)
        enum_type = type_args[0] if type_args else None
        if not (isinstance(enum_type, type) and issubclass(enum_type, Enum)):
            raise TypeError(
                f"Column `{field_name}` should be a numpy array or a sequence "
                "of string `Enum`."
            )
        max_length_in_enum = max(
            (len(element.value) for element in enum_type), default=1
        )
        dtype.append((field_name, np.dtype(f"<U{max_length_in_enum}")))

    return np.dtype(dtype)


def numpy_table(self) -> np.ndarray:
    """Return the table as a 1-D structured array with one record per row.

    The record layout is given by ``numpy_dtype()``, so every column keeps
    its own dtype and can be read back by field name.
    """
    # Stacking the columns into a plain 2-D array would coerce every column
    # to one common dtype, so the records are assembled row by row instead.
    rows = list(zip(*self.numpy_columns()))
    return np.array(rows, dtype=self.numpy_dtype())


def numpy_columns(self) -> list[np.ndarray]:
    """Return every column of the table as a numpy array.

    Columns in the table can be lists of string enums or numpy arrays.
    Numpy columns are returned as stored; enum columns are converted to
    unicode arrays holding the values of the enum members in the column.
    """
    table_dtype = self.numpy_dtype()
    columns = []
    for field_name in self.model_fields:
        column = getattr(self, field_name)
        if isinstance(column, np.ndarray):
            columns.append(column)
        else:
            # Store the underlying string of each member, using the width
            # already computed for this field in the table dtype.
            values = [
                item.value if isinstance(item, Enum) else item for item in column
            ]
            columns.append(np.array(values, dtype=table_dtype[field_name]))

    return columns

@model_validator(mode="after")
def validate_arrays(self) -> "Table":
first_length = len(next(iter(self))[1])
Expand All @@ -45,10 +110,15 @@ def validate_arrays(self) -> "Table":
), "Rows should all be of equal size."

if not all(
np.issubdtype(
self.model_fields[field_name].default_factory().dtype, field_value.dtype
# Checks if the values are numpy subtypes if the array is a numpy array,
# or if the value is a string enum.
np.issubdtype(getattr(self, field_name).dtype, default_array.dtype)
if isinstance(
default_array := self.model_fields[field_name].default_factory(),
np.ndarray,
)
for field_name, field_value in self
else issubclass(get_args(field_value.annotation)[0], Enum)
for field_name, field_value in self.model_fields.items()
):
raise ValueError(
f"Cannot construct a `{type(self).__name__}`, "
Expand Down
44 changes: 5 additions & 39 deletions src/ophyd_async/fastcs/panda/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,7 @@ class SeqTrigger(str, Enum):
),
Field(default_factory=lambda: np.array([], dtype=np.bool_)),
]
TriggerStr = Annotated[
np.ndarray[tuple[int], np.unicode_],
NpArrayPydanticAnnotation.factory(
data_type=np.unicode_, dimensions=1, strict_data_typing=False
),
Field(default_factory=lambda: np.array([], dtype=np.dtype("<U32"))),
]
TriggerStr = Annotated[list[SeqTrigger], Field(default_factory=list)]


class SeqTable(Table):
Expand Down Expand Up @@ -103,40 +97,12 @@ def row(
) -> "SeqTable":
sig = inspect.signature(cls.row)
kwargs = {k: v for k, v in locals().items() if k in sig.parameters}

if isinstance(kwargs["trigger"], SeqTrigger):
kwargs["trigger"] = kwargs["trigger"].value
elif isinstance(kwargs["trigger"], str):
SeqTrigger(kwargs["trigger"])

if not isinstance(kwargs["trigger"], SeqTrigger):
if kwargs["trigger"] not in SeqTrigger.__members__.values():
raise ValueError(f"'{kwargs['trigger']}' is not a valid trigger.")
kwargs["trigger"] = SeqTrigger(kwargs["trigger"])
return Table.row(cls, **kwargs)

@field_validator("trigger", mode="before")
@classmethod
def trigger_to_np_array(cls, trigger_column):
"""
The user can provide a list of SeqTrigger enum elements instead of a numpy str.
"""
if isinstance(trigger_column, Sequence) and all(
isinstance(trigger, SeqTrigger) for trigger in trigger_column
):
trigger_column = np.array(
[trigger.value for trigger in trigger_column], dtype=np.dtype("<U32")
)
elif isinstance(trigger_column, Sequence) or isinstance(
trigger_column, np.ndarray
):
for trigger in trigger_column:
SeqTrigger(
trigger
) # To check all the given strings are actually `SeqTrigger`s
else:
raise ValueError(
"Expected a numpy array or a sequence of `SeqTrigger`, got "
f"{type(trigger_column)}."
)
return trigger_column

@model_validator(mode="after")
def validate_max_length(self) -> "SeqTable":
"""
Expand Down
41 changes: 26 additions & 15 deletions tests/fastcs/panda/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ def test_seq_table_converts_lists():
seq_table_dict_with_lists = {field_name: [] for field_name, _ in SeqTable()}
# Validation passes
seq_table = SeqTable(**seq_table_dict_with_lists)
assert isinstance(seq_table.trigger, np.ndarray)
assert seq_table.trigger.dtype == np.dtype("U32")
for field_name, field_value in seq_table:
if field_name == "trigger":
assert field_value == []
else:
assert np.array_equal(field_value, np.array([], dtype=field_value.dtype))


def test_seq_table_validation_errors():
Expand Down Expand Up @@ -73,16 +76,20 @@ def test_seq_table_validation_errors():
wrong_types = {
field_name: field_value.astype(np.unicode_)
for field_name, field_value in row_one
if isinstance(field_value, np.ndarray)
}
SeqTable(**wrong_types)


def test_seq_table_pva_conversion():
pva_dict = {
"repeats": np.array([1, 2, 3, 4], dtype=np.int32),
"trigger": np.array(
["Immediate", "Immediate", "BITC=0", "Immediate"], dtype=np.dtype("U32")
),
"trigger": [
SeqTrigger.IMMEDIATE,
SeqTrigger.IMMEDIATE,
SeqTrigger.BITC_0,
SeqTrigger.IMMEDIATE,
],
"position": np.array([1, 2, 3, 4], dtype=np.int32),
"time1": np.array([1, 0, 1, 0], dtype=np.int32),
"outa1": np.array([1, 0, 1, 0], dtype=np.bool_),
Expand Down Expand Up @@ -178,31 +185,36 @@ def test_seq_table_pva_conversion():
},
]

def _assert_col_equal(column1, column2):
if isinstance(column1, np.ndarray):
assert np.array_equal(column1, column2)
assert column1.dtype == column2.dtype
else:
assert column1 == column2
assert all(isinstance(x, SeqTrigger) for x in column1)
assert all(isinstance(x, SeqTrigger) for x in column2)

seq_table_from_pva_dict = SeqTable(**pva_dict)
for (_, column1), column2 in zip(seq_table_from_pva_dict, pva_dict.values()):
assert np.array_equal(column1, column2)
assert column1.dtype == column2.dtype
_assert_col_equal(column1, column2)

seq_table_from_rows = reduce(
lambda x, y: x + y,
[SeqTable.row(**row_kwargs) for row_kwargs in row_wise_dicts],
)
for (_, column1), column2 in zip(seq_table_from_rows, pva_dict.values()):
assert np.array_equal(column1, column2)
assert column1.dtype == column2.dtype
_assert_col_equal(column1, column2)

# Idempotency
applied_twice_to_pva_dict = SeqTable(**pva_dict).model_dump(mode="python")
for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()):
assert np.array_equal(column1, column2)
assert column1.dtype == column2.dtype
_assert_col_equal(column1, column2)


def test_seq_table_takes_trigger_enum_row():
for trigger in (SeqTrigger.BITA_0, "BITA=0"):
table = SeqTable.row(trigger=trigger)
assert table.trigger[0] == "BITA=0"
assert np.issubdtype(table.trigger.dtype, np.dtype("<U32"))
assert table.trigger[0] == SeqTrigger.BITA_0
table = SeqTable(
repeats=np.array([1], dtype=np.int32),
trigger=[trigger],
Expand All @@ -222,5 +234,4 @@ def test_seq_table_takes_trigger_enum_row():
oute2=np.array([1], dtype=np.bool_),
outf2=np.array([1], dtype=np.bool_),
)
assert table.trigger[0] == "BITA=0"
assert np.issubdtype(table.trigger.dtype, np.dtype("<U32"))
assert table.trigger[0] == SeqTrigger.BITA_0
2 changes: 1 addition & 1 deletion tests/fastcs/panda/test_trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,5 @@ def full_seq_table(trigger):
full_seq_table(["A"])
assert "Value error, 'A' is not a valid SeqTrigger" in str(exc)
with pytest.raises(ValidationError) as exc:
full_seq_table({"A"})
full_seq_table({"Immediate"})
assert "Expected a numpy array or a sequence of `SeqTrigger`, got" in str(exc)

0 comments on commit 357990c

Please sign in to comment.