diff --git a/src/ophyd_async/core/_table.py b/src/ophyd_async/core/_table.py index bdb619a3b..3316ba84b 100644 --- a/src/ophyd_async/core/_table.py +++ b/src/ophyd_async/core/_table.py @@ -1,7 +1,17 @@ +from enum import Enum +from typing import get_args + import numpy as np from pydantic import BaseModel, ConfigDict, model_validator +def _concat(value1, value2): + if isinstance(value1, np.ndarray): + return np.concatenate((value1, value2)) + else: + return value1 + value2 + + class Table(BaseModel): """An abstraction of a Table of str to numpy array.""" @@ -9,34 +19,89 @@ class Table(BaseModel): @classmethod def row(cls, sub_cls, **kwargs) -> "Table": - arrayified_kwargs = { - field_name: np.concatenate( - ( - (default_arr := field_value.default_factory()), - np.array([kwargs[field_name]], dtype=default_arr.dtype), + arrayified_kwargs = {} + for field_name, field_value in sub_cls.model_fields.items(): + value = kwargs.pop(field_name) + if field_value.default_factory is None: + raise ValueError( + "`Table` models should have default factories for their " + "mutable empty columns." + ) + default_array = field_value.default_factory() + if isinstance(default_array, np.ndarray): + arrayified_kwargs[field_name] = np.array( + [value], dtype=default_array.dtype ) + elif issubclass(type(value), Enum) and isinstance(value, str): + arrayified_kwargs[field_name] = [value] + else: + raise TypeError( + "Row column should be numpy arrays or sequence of string `Enum`." + ) + if kwargs: + raise TypeError( + f"Unexpected keyword arguments {kwargs.keys()} for {sub_cls.__name__}." ) - for field_name, field_value in sub_cls.model_fields.items() - } + return sub_cls(**arrayified_kwargs) def __add__(self, right: "Table") -> "Table": """Concatenate the arrays in field values.""" - assert isinstance(right, type(self)), ( - f"{right} is not a `Table`, or is not the same " - f"type of `Table` as {self}." - ) + if not isinstance(right, type(self)): + raise RuntimeError( + f"{right} is not a `Table`, or is not the same " + f"type of `Table` as {self}." + ) return type(self)( **{ - field_name: np.concatenate( - (getattr(self, field_name), getattr(right, field_name)) + field_name: _concat( + getattr(self, field_name), getattr(right, field_name) ) for field_name in self.model_fields } ) + def numpy_dtype(self) -> np.dtype: + dtype = [] + for field_value in self.model_fields.values(): + if isinstance(field_value, np.ndarray): + dtype.append(field_value.dtype) + else: + enum_type = get_args(field_value.annotation)[0] + assert issubclass(enum_type, Enum) + enum_values = [element.value for element in enum_type] + max_length_in_enum = max(len(value) for value in enum_values) + dtype.append(np.dtype(f" list[np.ndarray]: + """Columns in the table can be lists of string enums or numpy arrays. + + This method returns the columns, converting the string enums to numpy arrays. + """ + + columns = [] + for field_value in self.model_fields.values(): + if isinstance(field_value, np.ndarray): + columns.append(field_value) + else: + enum_type = get_args(field_value.field_info.annotation)[0] + assert issubclass(enum_type, Enum) + enum_values = [element.value for element in enum_type] + max_length_in_enum = max(len(value) for value in enum_values) + dtype = np.dtype(f" "Table": first_length = len(next(iter(self))[1]) @@ -45,10 +110,15 @@ def validate_arrays(self) -> "Table": ), "Rows should all be of equal size." if not all( - np.issubdtype( - self.model_fields[field_name].default_factory().dtype, field_value.dtype + # Checks if the values are numpy subtypes if the array is a numpy array, + # or if the value is a string enum. + np.issubdtype(getattr(self, field_name).dtype, default_array.dtype) + if isinstance( + default_array := self.model_fields[field_name].default_factory(), + np.ndarray, ) - for field_name, field_value in self + else issubclass(get_args(field_value.annotation)[0], Enum) + for field_name, field_value in self.model_fields.items() ): raise ValueError( f"Cannot construct a `{type(self).__name__}`, " diff --git a/src/ophyd_async/fastcs/panda/_table.py b/src/ophyd_async/fastcs/panda/_table.py index ee6df7522..3b8a2a8e5 100644 --- a/src/ophyd_async/fastcs/panda/_table.py +++ b/src/ophyd_async/fastcs/panda/_table.py @@ -51,13 +51,7 @@ class SeqTrigger(str, Enum): ), Field(default_factory=lambda: np.array([], dtype=np.bool_)), ] -TriggerStr = Annotated[ - np.ndarray[tuple[int], np.unicode_], - NpArrayPydanticAnnotation.factory( - data_type=np.unicode_, dimensions=1, strict_data_typing=False - ), - Field(default_factory=lambda: np.array([], dtype=np.dtype(" "SeqTable": sig = inspect.signature(cls.row) kwargs = {k: v for k, v in locals().items() if k in sig.parameters} - - if isinstance(kwargs["trigger"], SeqTrigger): - kwargs["trigger"] = kwargs["trigger"].value - elif isinstance(kwargs["trigger"], str): - SeqTrigger(kwargs["trigger"]) - + if not isinstance(kwargs["trigger"], SeqTrigger): + if kwargs["trigger"] not in SeqTrigger.__members__.values(): + raise ValueError(f"'{kwargs['trigger']}' is not a valid trigger.") + kwargs["trigger"] = SeqTrigger(kwargs["trigger"]) return Table.row(cls, **kwargs) - @field_validator("trigger", mode="before") - @classmethod - def trigger_to_np_array(cls, trigger_column): - """ - The user can provide a list of SeqTrigger enum elements instead of a numpy str. - """ - if isinstance(trigger_column, Sequence) and all( - isinstance(trigger, SeqTrigger) for trigger in trigger_column - ): - trigger_column = np.array( - [trigger.value for trigger in trigger_column], dtype=np.dtype(" "SeqTable": """ diff --git a/tests/fastcs/panda/test_table.py b/tests/fastcs/panda/test_table.py index c5f5abb84..b2bc3d4cc 100644 --- a/tests/fastcs/panda/test_table.py +++ b/tests/fastcs/panda/test_table.py @@ -12,8 +12,11 @@ def test_seq_table_converts_lists(): seq_table_dict_with_lists = {field_name: [] for field_name, _ in SeqTable()} # Validation passes seq_table = SeqTable(**seq_table_dict_with_lists) - assert isinstance(seq_table.trigger, np.ndarray) - assert seq_table.trigger.dtype == np.dtype("U32") + for field_name, field_value in seq_table: + if field_name == "trigger": + assert field_value == [] + else: + assert np.array_equal(field_value, np.array([], dtype=field_value.dtype)) def test_seq_table_validation_errors(): @@ -73,6 +76,7 @@ def test_seq_table_validation_errors(): wrong_types = { field_name: field_value.astype(np.unicode_) for field_name, field_value in row_one + if isinstance(field_value, np.ndarray) } SeqTable(**wrong_types) @@ -80,9 +84,12 @@ def test_seq_table_validation_errors(): def test_seq_table_pva_conversion(): pva_dict = { "repeats": np.array([1, 2, 3, 4], dtype=np.int32), - "trigger": np.array( - ["Immediate", "Immediate", "BITC=0", "Immediate"], dtype=np.dtype("U32") - ), + "trigger": [ + SeqTrigger.IMMEDIATE, + SeqTrigger.IMMEDIATE, + SeqTrigger.BITC_0, + SeqTrigger.IMMEDIATE, + ], "position": np.array([1, 2, 3, 4], dtype=np.int32), "time1": np.array([1, 0, 1, 0], dtype=np.int32), "outa1": np.array([1, 0, 1, 0], dtype=np.bool_), @@ -178,31 +185,36 @@ def test_seq_table_pva_conversion(): }, ] + def _assert_col_equal(column1, column2): + if isinstance(column1, np.ndarray): + assert np.array_equal(column1, column2) + assert column1.dtype == column2.dtype + else: + assert column1 == column2 + assert all(isinstance(x, SeqTrigger) for x in column1) + assert all(isinstance(x, SeqTrigger) for x in column2) + seq_table_from_pva_dict = SeqTable(**pva_dict) for (_, column1), column2 in zip(seq_table_from_pva_dict, pva_dict.values()): - assert np.array_equal(column1, column2) - assert column1.dtype == column2.dtype + _assert_col_equal(column1, column2) seq_table_from_rows = reduce( lambda x, y: x + y, [SeqTable.row(**row_kwargs) for row_kwargs in row_wise_dicts], ) for (_, column1), column2 in zip(seq_table_from_rows, pva_dict.values()): - assert np.array_equal(column1, column2) - assert column1.dtype == column2.dtype + _assert_col_equal(column1, column2) # Idempotency applied_twice_to_pva_dict = SeqTable(**pva_dict).model_dump(mode="python") for column1, column2 in zip(applied_twice_to_pva_dict.values(), pva_dict.values()): - assert np.array_equal(column1, column2) - assert column1.dtype == column2.dtype + _assert_col_equal(column1, column2) def test_seq_table_takes_trigger_enum_row(): for trigger in (SeqTrigger.BITA_0, "BITA=0"): table = SeqTable.row(trigger=trigger) - assert table.trigger[0] == "BITA=0" - assert np.issubdtype(table.trigger.dtype, np.dtype("