diff --git a/src/pynwb/ndx_binned_spikes/__init__.py b/src/pynwb/ndx_binned_spikes/__init__.py index c232e06..3afa00d 100644 --- a/src/pynwb/ndx_binned_spikes/__init__.py +++ b/src/pynwb/ndx_binned_spikes/__init__.py @@ -1,5 +1,9 @@ import os +import numpy as np + from pynwb import load_namespaces, get_class +from pynwb.core import NWBDataInterface +from hdmf.utils import docval, popargs_to_dict, get_docval, popargs try: from importlib.resources import files @@ -18,7 +22,89 @@ # Load the namespace load_namespaces(str(__spec_path)) -BinnedAlignedSpikes = get_class("BinnedAlignedSpikes", "ndx-binned-spikes") +# BinnedAlignedSpikes = get_class("BinnedAlignedSpikes", "ndx-binned-spikes") + +from pynwb import register_class, docval + + +@register_class(neurodata_type="BinnedAlignedSpikes", namespace="ndx-binned-spikes") +class BinnedAlignedSpikes(NWBDataInterface): + __nwbfields__ = ( + "name", + "bin_width_in_milliseconds", + "milliseconds_from_event_to_first_bin", + "data", + "event_timestamps", + "units", + ) + + DEFAULT_NAME = "BinnedAlignedSpikes" + + @docval( + { + "name": "name", + "type": str, + "doc": "The name of this container", + "default": DEFAULT_NAME, + }, + { + "name": "bin_width_in_milliseconds", + "type": float, + "doc": "The length in milliseconds of the bins", + }, + { + "name": "milliseconds_from_event_to_first_bin", + "type": float, + "doc": ( + "The time in milliseconds from the event (e.g. a stimuli or the beginning of a trial)," + "to the first bin. Note that this is a negative number if the first bin is before the event." + ), + "default": 0.0, + }, + { + "name": "data", + "type": "array_data", + "shape": [(None, None, None), (None, None)], + "doc": "The source of the data", + }, + { + "name": "event_timestamps", + "type": "array_data", + "doc": "The timestamps at which the event occurred.", + "shape": (None,), + }, + { + "name": "units", + "type": ("DynamicTableRegion"), + "doc": "A reference to the Units table region that contains the units of the data.", + "default": None, + }, + ) + def __init__(self, **kwargs): + + keys_to_set = ("bin_width_in_milliseconds", "milliseconds_from_event_to_first_bin", "units") + args_to_set = popargs_to_dict(keys_to_set, kwargs) + + keys_to_process = ("data", "event_timestamps") # these are properties and cannot be set with setattr + args_to_process = popargs_to_dict(keys_to_process, kwargs) + super().__init__(**kwargs) + + # Set the values + for key, val in args_to_set.items(): + setattr(self, key, val) + + # Post-process / post_init + data = args_to_process["data"] + + data = data if data.ndim == 3 else data[np.newaxis, ...] + + event_timestamps = args_to_process["event_timestamps"] + + if data.shape[1] != event_timestamps.shape[0]: + raise ValueError("The number of event timestamps must match the number of event repetitions in the data.") + + self.fields["data"] = data + self.fields["event_timestamps"] = event_timestamps # Remove these functions from the package diff --git a/src/pynwb/tests/test_binned_aligned_spikes.py b/src/pynwb/tests/test_binned_aligned_spikes.py index 62b0869..3af8e25 100644 --- a/src/pynwb/tests/test_binned_aligned_spikes.py +++ b/src/pynwb/tests/test_binned_aligned_spikes.py @@ -24,9 +24,9 @@ def setUp(self): self.number_of_event_repetitions = 4 self.bin_width_in_milliseconds = 20.0 self.milliseconds_from_event_to_first_bin = -100.0 - rng = np.random.default_rng(seed=0) + self.rng = np.random.default_rng(seed=0) - self.data = rng.integers( + self.data = self.rng.integers( low=0, high=100, size=( @@ -99,6 +99,25 @@ def test_constructor_units_region(self): expected_names = [unit_name_a, unit_name_c] self.assertListEqual(unit_table_names, expected_names) + def test_accepting_input_with_no_number_of_units_dimension(self): + + data = self.rng.integers( + low=0, + high=100, + size=( + self.number_of_event_repetitions, + self.number_of_bins, + ), + ) + binned_aligned_spikes = BinnedAlignedSpikes( + bin_width_in_milliseconds=self.bin_width_in_milliseconds, + milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, + data=data, + event_timestamps=self.event_timestamps, + ) + + self.assertEqual(binned_aligned_spikes.data.shape, (1, self.number_of_event_repetitions, self.number_of_bins)) + class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase): """Simple roundtrip test for BinnedAlignedSpikes."""