[Backend Configuration IIIc]: Improved testing and debugs (#703)
CodyCBakerPhD authored Jan 12, 2024
1 parent cdb323a commit 8ddcfa3
Showing 3 changed files with 66 additions and 21 deletions.
@@ -1,14 +1,14 @@
 """Base Pydantic models for DatasetInfo and DatasetConfiguration."""
 import math
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Literal, Tuple, Union
+from typing import Any, Dict, List, Literal, Tuple, Union

 import h5py
 import numcodecs
 import numpy as np
 import zarr
 from hdmf import Container
-from hdmf.data_utils import DataChunkIterator, GenericDataChunkIterator
+from hdmf.data_utils import GenericDataChunkIterator
 from hdmf.utils import get_data_shape
 from pydantic import BaseModel, Field, root_validator
 from pynwb import NWBFile
@@ -35,20 +35,35 @@ def _find_location_in_memory_nwbfile(current_location: str, neurodata_object: Co
     )


-def _infer_dtype_using_data_chunk_iterator(candidate_dataset: Union[h5py.Dataset, zarr.Array]):
+def _infer_dtype_of_list(list_: List[Union[int, float, list]]) -> np.dtype:
     """
-    The DataChunkIterator has one of the best generic dtype inference, though logic is hard to peel out of it.
+    Attempt to infer the dtype of values in an arbitrarily sized and nested list.

-    It can fail in rare cases but not essential to our default configuration
+    Relies on the ability of the numpy.array call to cast the list as an array so the 'dtype' attribute can be used.
     """
-    try:
-        data_type = DataChunkIterator(candidate_dataset).dtype
-        return data_type
-    except Exception as exception:
-        if str(exception) != "Data type could not be determined. Please specify dtype in DataChunkIterator init.":
-            raise exception
-        else:
-            return np.dtype("object")
+    for item in list_:
+        if isinstance(item, list):
+            dtype = _infer_dtype_of_list(list_=item)
+            if dtype is not None:
+                return dtype
+        else:
+            return np.array([item]).dtype
+
+    raise ValueError("Unable to determine the dtype of values in the list.")
+
+
+def _infer_dtype(dataset: Union[h5py.Dataset, zarr.Array]) -> np.dtype:
+    """Attempt to infer the dtype of the contained values of the dataset."""
+    if hasattr(dataset, "dtype"):
+        data_type = np.dtype(dataset.dtype)
+        return data_type
+
+    if isinstance(dataset, list):
+        return _infer_dtype_of_list(list_=dataset)
+
+    # Think more on if there is a better way to handle this fallback
+    data_type = np.dtype("object")
+    return data_type


 class DatasetInfo(BaseModel):
@@ -108,7 +123,7 @@ def from_neurodata_object(cls, neurodata_object: Container, field_name: str) ->
         candidate_dataset = getattr(neurodata_object, field_name)

         full_shape = get_data_shape(data=candidate_dataset)
-        dtype = _infer_dtype_using_data_chunk_iterator(candidate_dataset=candidate_dataset)
+        dtype = _infer_dtype(dataset=candidate_dataset)

         return cls(
             object_id=neurodata_object.object_id,
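The replacement swaps an exception-driven heuristic for a direct recursion: walk into nested lists until the first scalar is found, then let numpy infer the dtype from a one-element cast. A standalone sketch of that behavior (the function body is copied from the diff above, so it runs without the rest of neuroconv; the example inputs are illustrative):

    from typing import List, Union

    import numpy as np


    def _infer_dtype_of_list(list_: List[Union[int, float, list]]) -> np.dtype:
        # Descend into nested lists until a scalar is reached, then let numpy infer.
        for item in list_:
            if isinstance(item, list):
                dtype = _infer_dtype_of_list(list_=item)
                if dtype is not None:
                    return dtype
            else:
                return np.array([item]).dtype

        raise ValueError("Unable to determine the dtype of values in the list.")


    print(_infer_dtype_of_list(list_=[[1, 2], [3, 4]]))     # int64 on most platforms
    print(_infer_dtype_of_list(list_=[[0.5], [1.5, 2.5]]))  # float64

An empty list raises the ValueError rather than silently falling back, and unlike the removed DataChunkIterator approach, no unrelated exception is swallowed behind a string comparison on an error message.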
3 changes: 3 additions & 0 deletions src/neuroconv/tools/nwb_helpers/_dataset_configuration.py
@@ -30,6 +30,9 @@ def _is_dataset_written_to_file(
     This object should then be skipped by the `get_io_datasets` function when working in append mode.
     """
+    if existing_file is None:
+        return False
+
     return (
         isinstance(candidate_dataset, h5py.Dataset)  # If the source data is an HDF5 Dataset
         and backend == "hdf5"
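The new early return makes the append-mode check safe for fresh writes: with no `existing_file`, nothing can already be written, so the backend-specific identity comparisons below it never need to run. A minimal sketch of the resulting shape of the function (the signature is simplified, and the final clause is an assumed continuation, since the diff above is truncated after the `backend` check):

    from typing import Any, Optional

    import h5py


    def _is_dataset_written_to_file(
        candidate_dataset: Any, backend: str, existing_file: Optional[h5py.File]
    ) -> bool:
        # Fresh write: no existing file means no dataset can already be written.
        if existing_file is None:
            return False

        # Append mode: skip an HDF5 dataset that already lives in the open file.
        return (
            isinstance(candidate_dataset, h5py.Dataset)
            and backend == "hdf5"
            and candidate_dataset.file == existing_file  # assumed continuation
        )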
@@ -1,12 +1,13 @@
 """Unit tests for `get_default_dataset_configurations`."""
 from pathlib import Path
-from typing import Literal
+from typing import Literal, Tuple

 import numcodecs
 import numpy as np
 import pytest
 from hdmf.common import DynamicTable, VectorData
 from hdmf.data_utils import DataChunkIterator
+from numpy.testing import assert_array_equal
 from pynwb.testing.mock.base import mock_TimeSeries
 from pynwb.testing.mock.file import mock_NWBFile
@@ -18,6 +19,24 @@
 )


+@pytest.fixture(scope="session")
+def integer_array(
+    seed: int = 0,
+    dtype: np.dtype = np.dtype("int16"),
+    shape: Tuple[int, int] = (30_000 * 5, 384),
+):
+    """
+    Generate an array of integers.
+
+    Default values are chosen to be similar to 5 seconds of v1 NeuroPixel data.
+    """
+    random_number_generator = np.random.default_rng(seed=seed)
+
+    low = np.iinfo(dtype).min
+    high = np.iinfo(dtype).max
+    return random_number_generator.integers(low=low, high=high, size=shape, dtype=dtype)
+
+
 @pytest.mark.parametrize(
     "case_name,iterator,iterator_options",
     [
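One detail worth noting about the fixture above: pytest only resolves parameters without default values as fixture requests, so `seed`, `dtype`, and `shape` act as documented constants rather than as other fixtures, and `scope="session"` means the roughly 110 MB array (150,000 x 384 int16 values) is generated once per test run instead of once per test. A minimal self-contained illustration of that default-argument behavior (fixture and test names here are invented for the example):

    import pytest


    @pytest.fixture(scope="session")
    def numbers(count: int = 3):
        # `count` has a default, so pytest does not try to resolve it as a
        # fixture; the default is simply used. Session scope builds the value once.
        return list(range(count))


    def test_numbers(numbers):
        assert numbers == [0, 1, 2]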
@@ -29,10 +48,14 @@
 )
 @pytest.mark.parametrize("backend", ["hdf5", "zarr"])
 def test_simple_time_series(
-    tmpdir: Path, case_name: str, iterator: callable, iterator_options: dict, backend: Literal["hdf5", "zarr"]
+    tmpdir: Path,
+    integer_array: np.ndarray,
+    case_name: str,
+    iterator: callable,
+    iterator_options: dict,
+    backend: Literal["hdf5", "zarr"],
 ):
-    array = np.zeros(shape=(30_000 * 5, 384), dtype="int16")
-    data = iterator(array, **iterator_options)
+    data = iterator(integer_array, **iterator_options)

     nwbfile = mock_NWBFile()
     time_series = mock_TimeSeries(name="TestTimeSeries", data=data)
@@ -57,14 +80,16 @@ def test_simple_time_series(
     elif backend == "zarr":
         assert written_data.compressor == numcodecs.GZip(level=1)

+    assert_array_equal(x=integer_array, y=written_data[:])
+

 @pytest.mark.parametrize("backend", ["hdf5", "zarr"])
-def test_simple_dynamic_table(tmpdir: Path, backend: Literal["hdf5", "zarr"]):
-    data = np.zeros(shape=(30_000 * 5, 384), dtype="int16")
-
+def test_simple_dynamic_table(tmpdir: Path, integer_array: np.ndarray, backend: Literal["hdf5", "zarr"]):
     nwbfile = mock_NWBFile()
     dynamic_table = DynamicTable(
-        name="TestDynamicTable", description="", columns=[VectorData(name="TestColumn", description="", data=data)]
+        name="TestDynamicTable",
+        description="",
+        columns=[VectorData(name="TestColumn", description="", data=integer_array)],
     )
     nwbfile.add_acquisition(dynamic_table)
@@ -87,3 +112,5 @@ def test_simple_dynamic_table(tmpdir: Path, backend: Literal["hdf5", "zarr"]):
         assert written_data.compression == "gzip"
     elif backend == "zarr":
         assert written_data.compressor == numcodecs.GZip(level=1)
+
+    assert_array_equal(x=integer_array, y=written_data[:])

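Replacing the previous all-zero arrays with seeded random data is what gives the new `assert_array_equal` checks teeth: a round-trip of zeros would pass even if a writer silently zero-filled or truncated the dataset. A condensed sketch of the write/read/assert skeleton these tests now follow (assuming only pynwb and numpy; the path and series name are illustrative, and neuroconv's backend-configuration step that the real tests exercise is omitted):

    import numpy as np
    from numpy.testing import assert_array_equal
    from pynwb import NWBHDF5IO
    from pynwb.testing.mock.base import mock_TimeSeries
    from pynwb.testing.mock.file import mock_NWBFile

    # Seeded random int16 data, mirroring the `integer_array` fixture above.
    dtype = np.dtype("int16")
    rng = np.random.default_rng(seed=0)
    integer_array = rng.integers(
        low=np.iinfo(dtype).min, high=np.iinfo(dtype).max, size=(30_000 * 5, 384), dtype=dtype
    )

    nwbfile = mock_NWBFile()
    nwbfile.add_acquisition(mock_TimeSeries(name="TestTimeSeries", data=integer_array))

    with NWBHDF5IO(path="test.nwb", mode="w") as io:
        io.write(nwbfile)

    with NWBHDF5IO(path="test.nwb", mode="r") as io:
        written_data = io.read().acquisition["TestTimeSeries"].data
        # Random data makes this equality check meaningful end-to-end.
        assert_array_equal(x=integer_array, y=written_data[:])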