From bc7b1723fbf2103d306b2ac1526dd8fea06552df Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 30 Oct 2024 09:51:25 -0600 Subject: [PATCH] add support for nans in the data --- spec/ndx-binned-spikes.extensions.yaml | 2 +- src/pynwb/ndx_binned_spikes/__init__.py | 1 + src/pynwb/ndx_binned_spikes/testing/mock.py | 9 +++++- src/pynwb/tests/test_binned_aligned_spikes.py | 31 +++++++++++++------ src/spec/create_extension_spec.py | 2 +- 5 files changed, 33 insertions(+), 12 deletions(-) diff --git a/spec/ndx-binned-spikes.extensions.yaml b/spec/ndx-binned-spikes.extensions.yaml index 3486762..c347891 100644 --- a/spec/ndx-binned-spikes.extensions.yaml +++ b/spec/ndx-binned-spikes.extensions.yaml @@ -26,7 +26,7 @@ groups: required: false datasets: - name: data - dtype: uint64 + dtype: numeric dims: - num_units - number_of_events diff --git a/src/pynwb/ndx_binned_spikes/__init__.py b/src/pynwb/ndx_binned_spikes/__init__.py index 0f169fa..6a65e0c 100644 --- a/src/pynwb/ndx_binned_spikes/__init__.py +++ b/src/pynwb/ndx_binned_spikes/__init__.py @@ -34,6 +34,7 @@ class BinnedAlignedSpikes(NWBDataInterface): "data", "timestamps", "condition_indices", + "condition_labels", {"name": "units_region", "child": True}, # TODO, I forgot why this is included ) diff --git a/src/pynwb/ndx_binned_spikes/testing/mock.py b/src/pynwb/ndx_binned_spikes/testing/mock.py index e2c8d07..3f8f3c1 100644 --- a/src/pynwb/ndx_binned_spikes/testing/mock.py +++ b/src/pynwb/ndx_binned_spikes/testing/mock.py @@ -19,6 +19,7 @@ def mock_BinnedAlignedSpikes( condition_labels: Optional[np.ndarray] = None, units_region: Optional[DynamicTableRegion] = None, sort_data: bool = True, + add_random_nans: bool = False, ) -> BinnedAlignedSpikes: """ Generate a mock BinnedAlignedSpikes object with specified parameters or from given data. @@ -62,7 +63,7 @@ def mock_BinnedAlignedSpikes( BinnedAlignedSpikes A mock BinnedAlignedSpikes object populated with the provided or generated data and parameters. """ - + if data is not None: number_of_units, number_of_events, number_of_bins = data.shape else: @@ -118,6 +119,12 @@ def mock_BinnedAlignedSpikes( if condition_indices is not None: condition_indices = condition_indices[sorted_indices] + # Add random nans over all the data + if add_random_nans: + data = data.astype("float32") + nan_mask = rng.choice([True, False], size=data.shape, p=[0.1, 0.9]) + data[nan_mask] = np.nan + binned_aligned_spikes = BinnedAlignedSpikes( bin_width_in_milliseconds=bin_width_in_milliseconds, milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin, diff --git a/src/pynwb/tests/test_binned_aligned_spikes.py b/src/pynwb/tests/test_binned_aligned_spikes.py index 6c75c61..be56648 100644 --- a/src/pynwb/tests/test_binned_aligned_spikes.py +++ b/src/pynwb/tests/test_binned_aligned_spikes.py @@ -273,7 +273,7 @@ def test_roundtrip_acquisition(self): number_of_conditions = 3 condition_labels = ["a", "b", "c"] - self.binned_aligned_spikes = mock_BinnedAlignedSpikes( + binned_aligned_spikes = mock_BinnedAlignedSpikes( number_of_units=number_of_units, number_of_bins=number_of_bins, number_of_events=number_of_events, @@ -281,7 +281,7 @@ def test_roundtrip_acquisition(self): condition_labels=condition_labels, ) - self.nwbfile.add_acquisition(self.binned_aligned_spikes) + self.nwbfile.add_acquisition(binned_aligned_spikes) with NWBHDF5IO(self.path, mode="w") as io: io.write(self.nwbfile) @@ -289,23 +289,23 @@ def test_roundtrip_acquisition(self): with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: read_nwbfile = io.read() read_binned_aligned_spikes = read_nwbfile.acquisition["BinnedAlignedSpikes"] - self.assertContainerEqual(self.binned_aligned_spikes, read_binned_aligned_spikes) + self.assertContainerEqual(binned_aligned_spikes, read_binned_aligned_spikes) assert read_binned_aligned_spikes.number_of_units == number_of_units assert read_binned_aligned_spikes.number_of_bins == number_of_bins assert read_binned_aligned_spikes.number_of_events == number_of_events assert read_binned_aligned_spikes.number_of_conditions == number_of_conditions - expected_data_condition1 = self.binned_aligned_spikes.get_data_for_condition(condition_index=2) + expected_data_condition1 = binned_aligned_spikes.get_data_for_condition(condition_index=2) data_condition1 = read_binned_aligned_spikes.get_data_for_condition(condition_index=2) np.testing.assert_equal(data_condition1, expected_data_condition1) def test_roundtrip_processing_module(self): - self.binned_aligned_spikes = mock_BinnedAlignedSpikes() + binned_aligned_spikes = mock_BinnedAlignedSpikes() ecephys_processinng_module = self.nwbfile.create_processing_module(name="ecephys", description="a description") - ecephys_processinng_module.add(self.binned_aligned_spikes) + ecephys_processinng_module.add(binned_aligned_spikes) with NWBHDF5IO(self.path, mode="w") as io: io.write(self.nwbfile) @@ -313,7 +313,7 @@ def test_roundtrip_processing_module(self): with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: read_nwbfile = io.read() read_container = read_nwbfile.processing["ecephys"]["BinnedAlignedSpikes"] - self.assertContainerEqual(self.binned_aligned_spikes, read_container) + self.assertContainerEqual(binned_aligned_spikes, read_container) def test_roundtrip_with_units_table(self): @@ -332,7 +332,20 @@ def test_roundtrip_with_units_table(self): with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: read_nwbfile = io.read() - read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"] - self.assertContainerEqual(binned_aligned_spikes_with_region, read_container) + read_binned_aligned_spikes = read_nwbfile.acquisition["BinnedAlignedSpikes"] + self.assertContainerEqual(binned_aligned_spikes_with_region, read_binned_aligned_spikes) + + + def test_data_with_nans(self): + + binned_aligned_spikes = mock_BinnedAlignedSpikes(add_random_nans=True) + self.nwbfile.add_acquisition(binned_aligned_spikes) + with NWBHDF5IO(self.path, mode="w") as io: + io.write(self.nwbfile) + + with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: + read_nwbfile = io.read() + read_binned_aligned_spikes = read_nwbfile.acquisition["BinnedAlignedSpikes"] + self.assertContainerEqual(binned_aligned_spikes, read_binned_aligned_spikes) diff --git a/src/spec/create_extension_spec.py b/src/spec/create_extension_spec.py index cfd38c2..f8a23f1 100644 --- a/src/spec/create_extension_spec.py +++ b/src/spec/create_extension_spec.py @@ -29,7 +29,7 @@ def main(): "The binned data. It should be an array whose first dimension is the number of units, the second dimension " "is the number of events, and the third dimension is the number of bins." ), - dtype="uint64", + dtype="numeric", shape=[None, None, None], dims=["num_units", "number_of_events", "number_of_bins"], )