From 618adf0ef8a129279972063649fced41c8adc67c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Mar 2024 16:31:34 -0600 Subject: [PATCH] make units a region --- spec/ndx-binned-spikes.extensions.yaml | 6 +- src/pynwb/ndx_binned_spikes/testing/mock.py | 13 ++- src/pynwb/tests/test_binned_aligned_spikes.py | 104 ++++++++++++------ src/spec/create_extension_spec.py | 12 +- 4 files changed, 86 insertions(+), 49 deletions(-) diff --git a/spec/ndx-binned-spikes.extensions.yaml b/spec/ndx-binned-spikes.extensions.yaml index 48ad4b9..0ce14a0 100644 --- a/spec/ndx-binned-spikes.extensions.yaml +++ b/spec/ndx-binned-spikes.extensions.yaml @@ -17,9 +17,9 @@ groups: required: false - name: units dtype: - target_type: Units - reftype: object - doc: A link to the units Table that contains the units of the data. + target_type: DynamicTableRegion + reftype: region + doc: A reference to the Units table region that contains the units of the data. required: false datasets: - name: data diff --git a/src/pynwb/ndx_binned_spikes/testing/mock.py b/src/pynwb/ndx_binned_spikes/testing/mock.py index a522424..5db1da9 100644 --- a/src/pynwb/ndx_binned_spikes/testing/mock.py +++ b/src/pynwb/ndx_binned_spikes/testing/mock.py @@ -3,6 +3,7 @@ from ndx_binned_spikes import BinnedAlignedSpikes import numpy as np + def mock_BinnedAlignedSpikes( number_of_units: int = 2, number_of_event_repetitions: int = 4, @@ -12,7 +13,7 @@ def mock_BinnedAlignedSpikes( seed: int = 0, event_timestamps: Optional[np.ndarray] = None, data: Optional[np.ndarray] = None, -) -> 'BinnedAlignedSpikes': +) -> "BinnedAlignedSpikes": """ Generate a mock BinnedAlignedSpikes object with specified parameters or from given data. @@ -65,13 +66,13 @@ def mock_BinnedAlignedSpikes( else: rng = np.random.default_rng(seed=seed) data = rng.integers(low=0, high=100, size=(number_of_units, number_of_event_repetitions, number_of_bins)) - + if event_timestamps is None: event_timestamps = np.arange(number_of_event_repetitions, dtype="float64") else: - assert event_timestamps.shape[0] == number_of_event_repetitions, ( - "The shape of `event_timestamps` does not match `number_of_event_repetitions`." - ) + assert ( + event_timestamps.shape[0] == number_of_event_repetitions + ), "The shape of `event_timestamps` does not match `number_of_event_repetitions`." event_timestamps = np.array(event_timestamps, dtype="float64") if event_timestamps.shape[0] != data.shape[1]: @@ -83,4 +84,4 @@ def mock_BinnedAlignedSpikes( data=data, event_timestamps=event_timestamps, ) - return binned_aligned_spikes \ No newline at end of file + return binned_aligned_spikes diff --git a/src/pynwb/tests/test_binned_aligned_spikes.py b/src/pynwb/tests/test_binned_aligned_spikes.py index a9d5001..62b0869 100644 --- a/src/pynwb/tests/test_binned_aligned_spikes.py +++ b/src/pynwb/tests/test_binned_aligned_spikes.py @@ -18,46 +18,90 @@ class TestBinnedAlignedSpikesConstructor(TestCase): def setUp(self): """Set up an NWB file. Necessary because BinnedAlignedSpikes requires references to electrodes.""" + + self.number_of_units = 2 + self.number_of_bins = 3 + self.number_of_event_repetitions = 4 + self.bin_width_in_milliseconds = 20.0 + self.milliseconds_from_event_to_first_bin = -100.0 + rng = np.random.default_rng(seed=0) + + self.data = rng.integers( + low=0, + high=100, + size=( + self.number_of_units, + self.number_of_event_repetitions, + self.number_of_bins, + ), + ) + + self.event_timestamps = np.arange(self.number_of_event_repetitions, dtype="float64") + self.nwbfile = mock_NWBFile() def test_constructor(self): """Test that the constructor for BinnedAlignedSpikes sets values as expected.""" - number_of_units = 2 - number_of_bins = 3 - number_of_event_repetitions = 4 - bin_width_in_milliseconds = 20.0 - milliseconds_from_event_to_first_bin = 1.0 - - rng = np.random.default_rng(seed=0) - data = rng.integers(low=0, high=100, size=(number_of_units, number_of_event_repetitions, number_of_bins)) - event_timestamps = np.arange(number_of_event_repetitions, dtype="float64") - binned_aligned_spikes = BinnedAlignedSpikes( - bin_width_in_milliseconds=bin_width_in_milliseconds, - milliseconds_from_event_to_first_bin=milliseconds_from_event_to_first_bin, - data=data, - event_timestamps=event_timestamps + bin_width_in_milliseconds=self.bin_width_in_milliseconds, + milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, + data=self.data, + event_timestamps=self.event_timestamps, ) - - np.testing.assert_array_equal(binned_aligned_spikes.data, data) - np.testing.assert_array_equal(binned_aligned_spikes.event_timestamps, event_timestamps) - self.assertEqual(binned_aligned_spikes.bin_width_in_milliseconds, bin_width_in_milliseconds) + + np.testing.assert_array_equal(binned_aligned_spikes.data, self.data) + np.testing.assert_array_equal(binned_aligned_spikes.event_timestamps, self.event_timestamps) + self.assertEqual(binned_aligned_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds) self.assertEqual( - binned_aligned_spikes.milliseconds_from_event_to_first_bin, milliseconds_from_event_to_first_bin + binned_aligned_spikes.milliseconds_from_event_to_first_bin, self.milliseconds_from_event_to_first_bin ) - - self.assertEqual(binned_aligned_spikes.data.shape[0], number_of_units) - self.assertEqual(binned_aligned_spikes.data.shape[1], number_of_event_repetitions) - self.assertEqual(binned_aligned_spikes.data.shape[2], number_of_bins) + self.assertEqual(binned_aligned_spikes.data.shape[0], self.number_of_units) + self.assertEqual(binned_aligned_spikes.data.shape[1], self.number_of_event_repetitions) + self.assertEqual(binned_aligned_spikes.data.shape[2], self.number_of_bins) + def test_constructor_units_region(self): + from pynwb.misc import Units + from hdmf.common import DynamicTableRegion + units_table = Units() + units_table.add_column(name="unit_name", description="a readable identifier for the units") -class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase): - """Simple roundtrip test for BinnedAlignedSpikes.""" + unit_name_a = "a" + spike_times_a = [1.1, 2.2, 3.3] + units_table.add_row(spike_times=spike_times_a, unit_name=unit_name_a) + + unit_name_b = "b" + spike_times_b = [4.4, 5.5, 6.6] + units_table.add_row(spike_times=spike_times_b, unit_name=unit_name_b) + unit_name_c = "c" + spike_times_c = [7.7, 8.8, 9.9] + units_table.add_row(spike_times=spike_times_c, unit_name=unit_name_c) + region_indices = [0, 2] + units_region = DynamicTableRegion( + data=region_indices, table=units_table, description="region of units table", name="units_region" + ) + + binned_aligned_spikes = BinnedAlignedSpikes( + bin_width_in_milliseconds=self.bin_width_in_milliseconds, + milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, + data=self.data, + event_timestamps=self.event_timestamps, + units=units_region, + ) + + unit_table_indices = binned_aligned_spikes.units.data + unit_table_names = binned_aligned_spikes.units.table["unit_name"][unit_table_indices] + + expected_names = [unit_name_a, unit_name_c] + self.assertListEqual(unit_table_names, expected_names) + + +class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase): + """Simple roundtrip test for BinnedAlignedSpikes.""" def setUp(self): self.nwbfile = mock_NWBFile() @@ -85,17 +129,13 @@ def test_roundtrip_acquisition(self): self.assertContainerEqual(self.binned_aligned_spikes, read_nwbfile.acquisition["BinnedAlignedSpikes"]) def test_roundtrip_processing_module(self): - - - ecephys_processinng_module = self.nwbfile.create_processing_module( - name="ecephys", description="a description" - ) + ecephys_processinng_module = self.nwbfile.create_processing_module(name="ecephys", description="a description") ecephys_processinng_module.add(self.binned_aligned_spikes) - + with NWBHDF5IO(self.path, mode="w") as io: io.write(self.nwbfile) with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: read_nwbfile = io.read() read_container = read_nwbfile.processing["ecephys"]["BinnedAlignedSpikes"] - self.assertContainerEqual(self.binned_aligned_spikes, read_container) \ No newline at end of file + self.assertContainerEqual(self.binned_aligned_spikes, read_container) diff --git a/src/spec/create_extension_spec.py b/src/spec/create_extension_spec.py index 0dbddac..7996065 100644 --- a/src/spec/create_extension_spec.py +++ b/src/spec/create_extension_spec.py @@ -3,6 +3,7 @@ from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, NWBAttributeSpec, NWBRefSpec, NWBDatasetSpec + def main(): # these arguments were auto-generated from your cookiecutter inputs ns_builder = NWBNamespaceBuilder( @@ -24,11 +25,6 @@ def main(): # of the other extension below # ns_builder.include_namespace("ndx-other-extension") - # TODO: define your new data types - # see https://pynwb.readthedocs.io/en/stable/tutorials/general/extensions.html - # for more information - - binned_aligned_spikes_data = NWBDatasetSpec( name="data", doc="TODO", @@ -36,7 +32,7 @@ def main(): shape=[(None, None, None)], dims=[("num_units", "number_of_event_repetitions", "number_of_bins")], ) - + event_timestamps = NWBDatasetSpec( name="event_timestamps", doc="The timestamps at which the event occurred.", @@ -68,9 +64,9 @@ def main(): ), NWBAttributeSpec( name="units", - doc="A link to the Units table that contains the units of the data.", + doc="A reference to the Units table region that contains the units of the data.", required=False, - dtype=NWBRefSpec(target_type="Units", reftype="object"), + dtype=NWBRefSpec(target_type="DynamicTableRegion", reftype="region"), ), ], )