Skip to content

Commit

Permalink
Slight improvements to the util functions
Browse files Browse the repository at this point in the history
  • Loading branch information
evalott100 committed Oct 10, 2023
1 parent 2cb2a9a commit 6ccaa8f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 94 deletions.
35 changes: 16 additions & 19 deletions pandablocks/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Sequence, Union, cast
from typing import Dict, Iterable, List, Union, cast

import numpy as np
import numpy.typing as npt
Expand All @@ -8,7 +8,7 @@
UnpackedArray = Union[
npt.NDArray[np.int32],
npt.NDArray[np.uint32],
Sequence[str],
List[str],
]


Expand All @@ -25,9 +25,8 @@ def words_to_table(
expected to be the string representation of a uint32.
table_fields_info: The info for tables, containing the number of words per row,
and the bit information for fields.
convert_enum_indices: If True, converts enum indices to labels, the packed
value will be a list of strings. If False the packed value will be a
numpy array of the indices the labels correspond to.
convert_enum_indices: If True, convert all enum values to their string
representation. Otherwise return enums as integer values
Returns:
unpacked: A dict containing record information, where keys are field names
and values are numpy arrays or a sequence of strings of record values
Expand Down Expand Up @@ -61,23 +60,19 @@ def words_to_table(
if field_info.subtype == "int":
# First convert from 2's complement to offset, then add in offset.
packing_value = (value ^ (1 << (bit_length - 1))) + (-1 << (bit_length - 1))
elif convert_enum_indices and field_info.labels:
elif field_info.subtype == "enum" and convert_enum_indices:
assert field_info.labels, f"Enum field {field_name} has no labels"
packing_value = [field_info.labels[x] for x in value]
else:
if bit_length <= 8:
packing_value = value.astype(np.uint8)
elif bit_length <= 16:
packing_value = value.astype(np.uint16)
else:
packing_value = value.astype(np.uint32)
packing_value = value.astype(np.uint32)

unpacked.update({field_name: packing_value})

return unpacked


def table_to_words(
table: Dict[str, Union[np.ndarray, List]], table_field_info: TableFieldInfo
table: Dict[str, UnpackedArray], table_field_info: TableFieldInfo
) -> List[str]:
"""Convert records based on the field definitions into the format PandA expects
for table writes.
Expand All @@ -100,16 +95,17 @@ def table_to_words(
field_details = table_field_info.fields[column_name]
if field_details.labels and len(column) and isinstance(column[0], str):
# Must convert the list of strings to list of ints
column = [field_details.labels.index(x) for x in column]

# PandA always handles tables in uint32 format
column_value = np.array(column, dtype=np.uint32)
column_value = np.array(
[field_details.labels.index(x) for x in column], dtype=np.uint32
)
else:
# PandA always handles tables in uint32 format
column_value = np.array(column, dtype=np.uint32)

if packed is None:
# Create 1-D array sufficiently long to exactly hold the entire table, cast
# to prevent type error, this will still work if column is another iterable
# e.g numpy array
column = cast(List, column)
packed = np.zeros((len(column), row_words), dtype=np.uint32)
else:
assert len(packed) == len(column), (
Expand All @@ -127,7 +123,8 @@ def table_to_words(

# Slice to get the column to apply the values to.
# bit shift the value to the relevant bits of the word
packed[:, word_offset] |= column_value << bit_offset

packed[:, word_offset] |= cast(np.unsignedinteger, column_value) << bit_offset

assert isinstance(packed, np.ndarray), "Table has no columns" # Squash mypy warning

Expand Down
97 changes: 22 additions & 75 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, OrderedDict, Union
from typing import Dict, List, OrderedDict

import numpy as np
import pytest
Expand Down Expand Up @@ -157,36 +157,13 @@ def table_field_info(table_fields) -> TableFieldInfo:


@pytest.fixture
def table_1() -> OrderedDict[str, Union[List, np.ndarray]]:
return OrderedDict(
{
"REPEATS": [5, 0, 50000],
"TRIGGER": ["Immediate", "BITC=1", "Immediate"],
"POSITION": [-5, 678, 0],
"TIME1": [100, 0, 9],
"OUTA1": [0, 1, 1],
"OUTB1": [0, 0, 1],
"OUTC1": [0, 0, 1],
"OUTD1": [1, 0, 1],
"OUTE1": [0, 0, 1],
"OUTF1": [1, 0, 1],
"TIME2": [0, 55, 9999],
"OUTA2": [0, 0, 1],
"OUTB2": [0, 0, 1],
"OUTC2": [1, 1, 1],
"OUTD2": [0, 0, 1],
"OUTE2": [0, 0, 1],
"OUTF2": [1, 0, 1],
}
)


@pytest.fixture
def table_1_np_arrays() -> OrderedDict[str, Union[List, np.ndarray]]:
def table_1_np_arrays() -> OrderedDict[str, UnpackedArray]:
# Intentionally not in panda order. Whatever types the np arrays are,
# the outputs from words_to_table will be uint32 or int32.
return OrderedDict(
{
"REPEATS": np.array([5, 0, 50000], dtype=np.uint32),
"TRIGGER": ["Immediate", "BITC=1", "Immediate"],
"POSITION": np.array([-5, 678, 0], dtype=np.int32),
"TIME1": np.array([100, 0, 9], dtype=np.uint32),
"OUTA1": np.array([0, 1, 1], dtype=np.uint8),
Expand All @@ -198,18 +175,16 @@ def table_1_np_arrays() -> OrderedDict[str, Union[List, np.ndarray]]:
"TIME2": np.array([0, 55, 9999], dtype=np.uint32),
"OUTA2": np.array([0, 0, 1], dtype=np.uint8),
"OUTB2": np.array([0, 0, 1], dtype=np.uint8),
"REPEATS": np.array([5, 0, 50000], dtype=np.uint32),
"OUTC2": np.array([1, 1, 1], dtype=np.uint8),
"OUTD2": np.array([0, 0, 1], dtype=np.uint8),
"OUTE2": np.array([0, 0, 1], dtype=np.uint8),
"OUTF2": np.array([1, 0, 1], dtype=np.uint8),
"TRIGGER": np.array(["Immediate", "BITC=1", "Immediate"], dtype="<U9"),
}
)


@pytest.fixture
def table_1_np_arrays_int_enums() -> OrderedDict[str, Union[List, np.ndarray]]:
def table_1_np_arrays_int_enums() -> OrderedDict[str, UnpackedArray]:
# Intentionally not in panda order. Whatever types the np arrays are,
# the outputs from words_to_table will be uint32 or int32.
return OrderedDict(
Expand All @@ -235,31 +210,6 @@ def table_1_np_arrays_int_enums() -> OrderedDict[str, Union[List, np.ndarray]]:
)


@pytest.fixture
def table_1_not_in_panda_order() -> OrderedDict[str, Union[List, np.ndarray]]:
return OrderedDict(
{
"REPEATS": [5, 0, 50000],
"TRIGGER": ["Immediate", "BITC=1", "Immediate"],
"POSITION": [-5, 678, 0],
"TIME1": [100, 0, 9],
"OUTA1": [0, 1, 1],
"OUTB1": [0, 0, 1],
"OUTC1": [0, 0, 1],
"OUTD1": [1, 0, 1],
"OUTF1": [1, 0, 1],
"OUTE1": [0, 0, 1],
"TIME2": [0, 55, 9999],
"OUTA2": [0, 0, 1],
"OUTC2": [1, 1, 1],
"OUTB2": [0, 0, 1],
"OUTD2": [0, 0, 1],
"OUTE2": [0, 0, 1],
"OUTF2": [1, 0, 1],
}
)


@pytest.fixture
def table_data_1() -> List[str]:
return [
Expand All @@ -279,19 +229,19 @@ def table_data_1() -> List[str]:


@pytest.fixture
def table_2() -> Dict[str, Union[List, np.ndarray]]:
table: Dict[str, Union[List, np.ndarray]] = dict(
REPEATS=[1, 0],
def table_2_np_arrays() -> Dict[str, UnpackedArray]:
table: Dict[str, UnpackedArray] = dict(
REPEATS=np.array([1, 0], dtype=np.uint32),
TRIGGER=["Immediate", "Immediate"],
POSITION=[-20, 2**31 - 1],
TIME1=[12, 2**32 - 1],
TIME2=[32, 1],
POSITION=np.array([-20, 2**31 - 1], dtype=np.int32),
TIME1=np.array([12, 2**32 - 1], dtype=np.uint32),
TIME2=np.array([32, 1], dtype=np.uint32),
)

table["OUTA1"] = [False, True]
table["OUTA2"] = [True, False]
table["OUTA1"] = np.array([0, 1], dtype=np.uint8)
table["OUTA2"] = np.array([1, 0], dtype=np.uint8)
for key in "BCDEF":
table[f"OUT{key}1"] = table[f"OUT{key}2"] = [False, False]
table[f"OUT{key}1"] = table[f"OUT{key}2"] = np.array([0, 0], dtype=np.uint8)

return table

Expand All @@ -311,25 +261,24 @@ def table_data_2() -> List[str]:


def test_table_packing_pack_length_mismatched(
table_1: OrderedDict[str, Union[List, np.ndarray]],
table_1_np_arrays: OrderedDict[str, UnpackedArray],
table_field_info: TableFieldInfo,
):
assert table_field_info.row_words

# Adjust one of the record lengths so it mismatches
field_info = table_field_info.fields[("OUTC1")]
assert field_info
table_1["OUTC1"] = np.array([1, 2, 3, 4, 5, 6, 7, 8])
table_1_np_arrays["OUTC1"] = np.array([1, 2, 3, 4, 5, 6, 7, 8])

with pytest.raises(AssertionError):
table_to_words(table_1, table_field_info)
table_to_words(table_1_np_arrays, table_field_info)


@pytest.mark.parametrize(
"table_fixture_name,table_data_fixture_name",
[
("table_1_not_in_panda_order", "table_data_1"),
("table_2", "table_data_2"),
("table_2_np_arrays", "table_data_2"),
("table_1_np_arrays", "table_data_1"),
],
)
Expand All @@ -339,9 +288,7 @@ def test_table_to_words_and_words_to_table(
table_field_info: TableFieldInfo,
request,
):
table: Dict[str, Union[List, np.ndarray]] = request.getfixturevalue(
table_fixture_name
)
table: Dict[str, UnpackedArray] = request.getfixturevalue(table_fixture_name)
table_data: List[str] = request.getfixturevalue(table_data_fixture_name)

output_data = table_to_words(table, table_field_info)
Expand Down Expand Up @@ -379,7 +326,7 @@ def test_table_packing_unpack(


def test_table_packing_unpack_no_convert_enum(
table_1_np_arrays_int_enums: OrderedDict[str, np.ndarray],
table_1_np_arrays_int_enums: OrderedDict[str, UnpackedArray],
table_field_info: TableFieldInfo,
table_data_1: List[str],
):
Expand All @@ -393,12 +340,12 @@ def test_table_packing_unpack_no_convert_enum(


def test_table_packing_pack(
table_1: Dict[str, Union[List, np.ndarray]],
table_1_np_arrays: Dict[str, UnpackedArray],
table_field_info: TableFieldInfo,
table_data_1: List[str],
):
assert table_field_info.row_words
unpacked = table_to_words(table_1, table_field_info)
unpacked = table_to_words(table_1_np_arrays, table_field_info)

for actual, expected in zip(unpacked, table_data_1):
assert actual == expected

0 comments on commit 6ccaa8f

Please sign in to comment.